async def test_push_add(self):
    job_name = "my-job"
    p = pusher.Pusher(job_name, self.server.url)
    registry = Registry()
    counter = Counter("counter_test", "A counter.", {"type": "counter"})
    registry.register(counter)
    counter_data = (({"c_sample": "1", "c_subsample": "b"}, 400),)
    [counter.set(c[0], c[1]) for c in counter_data]

    # TextFormatter expected result
    valid_result = (
        b"# HELP counter_test A counter.\n"
        b"# TYPE counter_test counter\n"
        b'counter_test{c_sample="1",c_subsample="b",type="counter"} 400\n'
    )
    # BinaryFormatter expected result
    # valid_result = (b'[\n\x0ccounter_test\x12\nA counter.\x18\x00"=\n\r'
    #                 b'\n\x08c_sample\x12\x011\n\x10\n\x0bc_subsample\x12'
    #                 b'\x01b\n\x0f\n\x04type\x12\x07counter\x1a\t\t\x00'
    #                 b'\x00\x00\x00\x00\x00y@')

    # Push to the pushgateway
    resp = await p.add(registry)
    self.assertEqual(resp.status, 200)
    self.assertEqual(expected_job_path(job_name), self.server.test_results["path"])
    self.assertEqual("POST", self.server.test_results["method"])
    self.assertEqual(valid_result, self.server.test_results["body"])
def test_registry_marshall_summary(self):
    metric_name = "summary_test"
    metric_help = "A summary."
    # metric_data = (
    #     ({'s_sample': '1', 's_subsample': 'b'},
    #      {0.5: 4235.0, 0.9: 4470.0, 0.99: 4517.0, 'count': 22, 'sum': 98857.0}),
    # )
    summary_data = (({"s_sample": "1", "s_subsample": "b"}, range(4000, 5000, 47)),)

    summary = Summary(metric_name, metric_help, const_labels={"type": "summary"})
    for labels, values in summary_data:
        for v in values:
            summary.add(labels, v)

    registry = Registry()
    registry.register(summary)

    valid_result = (
        b"\x99\x01\n\x0csummary_test\x12\nA summary."
        b'\x18\x02"{\n\r\n\x08s_sample\x12\x011\n\x10\n'
        b"\x0bs_subsample\x12\x01b\n\x0f\n\x04type\x12\x07"
        b'summary"G\x08\x16\x11\x00\x00\x00\x00\x90"\xf8@'
        b"\x1a\x12\t\x00\x00\x00\x00\x00\x00\xe0?\x11\x00"
        b"\x00\x00\x00\x00\x8b\xb0@\x1a\x12\t\xcd\xcc\xcc"
        b"\xcc\xcc\xcc\xec?\x11\x00\x00\x00\x00\x00v\xb1@"
        b"\x1a\x12\t\xaeG\xe1z\x14\xae\xef?\x11\x00\x00\x00"
        b"\x00\x00\xa5\xb1@"
    )

    f = binary.BinaryFormatter()
    self.assertEqual(valid_result, f.marshall(registry))
async def test_grouping_key_with_empty_value(self):
    # See https://github.com/prometheus/pushgateway/blob/master/README.md#url
    # for encoding rules.
    job_name = "example"
    p = pusher.Pusher(
        job_name,
        self.server.url,
        grouping_key={"first": "", "second": "foo"},
    )
    registry = Registry()
    c = Counter("example_total", "Total examples", {})
    registry.register(c)
    c.inc({})

    # Push to the pushgateway
    resp = await p.replace(registry)
    self.assertEqual(resp.status, 200)
    self.assertEqual(
        "/metrics/job/example/first@base64/=/second/foo",
        self.server.test_results["path"],
    )
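# The expected path above exercises the Pushgateway URL rule referenced in
# the test: a grouping-key value that is empty or contains "/" is encoded
# with the URL-safe base64 alphabet and the label name gets an "@base64"
# suffix; an empty value encodes to a bare "=" padding character. A minimal
# sketch of that rule (encode_grouping_pair is an illustrative helper, not
# part of the pusher/aioprometheus API):
import base64

def encode_grouping_pair(label: str, value: str) -> str:
    # Illustrative only: apply the Pushgateway base64 escape when needed.
    if value == "" or "/" in value:
        # urlsafe_b64encode(b"") yields b"", which the Pushgateway
        # represents as a single "=" padding character.
        encoded = base64.urlsafe_b64encode(value.encode()).decode() or "="
        return f"{label}@base64/{encoded}"
    return f"{label}/{value}"

assert encode_grouping_pair("first", "") == "first@base64/="
assert encode_grouping_pair("second", "foo") == "second/foo"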
async def test_push_add(self):
    job_name = "my-job"
    p = Pusher(job_name, TEST_URL)
    registry = Registry()
    counter = Counter("counter_test", "A counter.", {'type': "counter"})
    registry.register(counter)
    counter_data = (({'c_sample': '1', 'c_subsample': 'b'}, 400),)
    [counter.set(c[0], c[1]) for c in counter_data]

    valid_result = (b'[\n\x0ccounter_test\x12\nA counter.\x18\x00"=\n\r'
                    b'\n\x08c_sample\x12\x011\n\x10\n\x0bc_subsample\x12'
                    b'\x01b\n\x0f\n\x04type\x12\x07counter\x1a\t\t\x00'
                    b'\x00\x00\x00\x00\x00y@')

    # Push to the pushgateway
    resp = await p.add(registry)
    self.assertEqual(resp.status, 200)
    self.assertEqual(expected_job_path(job_name), self.server.test_results['path'])
    self.assertEqual("POST", self.server.test_results['method'])
    self.assertEqual(valid_result, self.server.test_results['body'])
def test_registry_marshall_summary(self):
    format_times = 10

    summary_data = (({'s_sample': '1', 's_subsample': 'b'}, range(4000, 5000, 47)),)

    registry = Registry()
    summary = Summary("summary_test", "A summary.", {'type': "summary"})

    # Add data
    [summary.add(i[0], s) for i in summary_data for s in i[1]]

    registry.register(summary)

    valid_result = (b'\x99\x01\n\x0csummary_test\x12\nA summary.'
                    b'\x18\x02"{\n\r\n\x08s_sample\x12\x011\n\x10\n'
                    b'\x0bs_subsample\x12\x01b\n\x0f\n\x04type\x12\x07'
                    b'summary"G\x08\x16\x11\x00\x00\x00\x00\x90"\xf8@'
                    b'\x1a\x12\t\x00\x00\x00\x00\x00\x00\xe0?\x11\x00'
                    b'\x00\x00\x00\x00\x8b\xb0@\x1a\x12\t\xcd\xcc\xcc'
                    b'\xcc\xcc\xcc\xec?\x11\x00\x00\x00\x00\x00v\xb1@'
                    b'\x1a\x12\t\xaeG\xe1z\x14\xae\xef?\x11\x00\x00\x00'
                    b'\x00\x00\xa5\xb1@')

    f = BinaryFormatter()

    # Check multiple times to ensure multiple marshalling requests
    for i in range(format_times):
        self.assertEqual(valid_result, f.marshall(registry))
async def test_push_job_ping(self):
    job_name = "my-job"
    p = pusher.Pusher(job_name, self.server.url, loop=self.loop)
    registry = Registry()
    c = Counter("total_requests", "Total requests.", {})
    registry.register(c)
    c.inc({"url": "/p/user"})

    # Push to the pushgateway
    resp = await p.replace(registry)
    self.assertEqual(resp.status, 200)
    self.assertEqual(expected_job_path(job_name), self.server.test_results["path"])
def test_registry_marshall_histogram(self):
    """ check encode of histogram matches expected output """
    metric_name = "histogram_test"
    metric_help = "A histogram."
    metric_data = (
        (
            {"h_sample": "1", "h_subsample": "b"},
            {5.0: 3, 10.0: 2, 15.0: 1, "count": 6, "sum": 46.0},
        ),
    )
    histogram_data = (
        ({"h_sample": "1", "h_subsample": "b"}, (4.5, 5.0, 4.0, 9.6, 9.0, 13.9)),
    )
    POS_INF = float("inf")

    histogram = Histogram(
        metric_name,
        metric_help,
        const_labels={"type": "histogram"},
        buckets=(5.0, 10.0, 15.0, POS_INF),
    )
    for labels, values in histogram_data:
        for v in values:
            histogram.add(labels, v)

    registry = Registry()
    registry.register(histogram)

    valid_result = (
        b"\x97\x01\n\x0ehistogram_test\x12\x0cA histogram."
        b'\x18\x04"u\n\r\n\x08h_sample\x12\x011\n\x10\n'
        b"\x0bh_subsample\x12\x01b\n\x11\n\x04type\x12\t"
        b"histogram:?\x08\x06\x11\x00\x00\x00\x00\x00\x00G@"
        b"\x1a\x0b\x08\x03\x11\x00\x00\x00\x00\x00\x00\x14@"
        b"\x1a\x0b\x08\x05\x11\x00\x00\x00\x00\x00\x00$@\x1a"
        b"\x0b\x08\x06\x11\x00\x00\x00\x00\x00\x00.@\x1a\x0b"
        b"\x08\x06\x11\x00\x00\x00\x00\x00\x00\xf0\x7f"
    )

    f = binary.BinaryFormatter()
    self.assertEqual(valid_result, f.marshall(registry))
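# The unused metric_data tuple above documents the per-bucket counts the six
# observations should produce (non-cumulative counts, matching the library's
# internal representation). A quick standalone sanity check of that
# arithmetic, in plain Python with no library dependency:
from bisect import bisect_left

buckets = (5.0, 10.0, 15.0, float("inf"))
observations = (4.5, 5.0, 4.0, 9.6, 9.0, 13.9)

counts = {b: 0 for b in buckets}
for v in observations:
    # Prometheus buckets are upper-bound inclusive ("le"), so a value equal
    # to a boundary lands in that boundary's bucket; bisect_left finds the
    # first boundary >= v.
    counts[buckets[bisect_left(buckets, v)]] += 1

assert counts == {5.0: 3, 10.0: 2, 15.0: 1, float("inf"): 0}
assert abs(sum(observations) - 46.0) < 1e-9  # the "sum" field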
def configure_prometheus_metrics_exporter(app: Starlette):
    app.add_middleware(MetricsMiddleware, webapp=app)
    app.registry = Registry()
    const_labels = {
        "host": socket.gethostname(),
        "name": "service1",
        "version": "1",
    }
    app.counter_gauge = Gauge("counter", "Current count.", const_labels=const_labels)
    app.registry.register(app.counter_gauge)
    app.svc_requests_total = Counter(
        "svc_requests_total",
        "Count of service HTTP requests",
        const_labels=const_labels,
    )
    app.registry.register(app.svc_requests_total)
    app.svc_responses_total = Counter(
        "svc_responses_total",
        "Count of service HTTP responses",
        const_labels=const_labels,
    )
    app.registry.register(app.svc_responses_total)
    app.svc_internal_error_total = Counter(
        "svc_internal_error_total",
        "Count of internal errors by method, path and type of error",
        const_labels=const_labels,
    )
    app.registry.register(app.svc_internal_error_total)
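# The config above assumes a MetricsMiddleware class defined elsewhere. A
# minimal sketch of what such a middleware could look like with Starlette's
# BaseHTTPMiddleware; the dispatch body is illustrative, not the original
# implementation:
from starlette.middleware.base import BaseHTTPMiddleware

class MetricsMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, webapp):
        super().__init__(app)
        self.webapp = webapp  # the Starlette app carrying the counters

    async def dispatch(self, request, call_next):
        # Count the request, forward it, then count the response.
        self.webapp.svc_requests_total.inc(
            {"method": request.method, "path": request.url.path}
        )
        response = await call_next(request)
        self.webapp.svc_responses_total.inc({"code": str(response.status_code)})
        return response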
async def test_invalid_registry(self):
    """ check only valid registry can be provided """
    for invalid_registry in ["nope", dict(), list()]:
        with self.assertRaises(Exception) as cm:
            Service(registry=invalid_registry)
        self.assertIn("registry must be a Registry, got:", str(cm.exception))

    Service(registry=Registry())
async def collect(self, target):
    self.registry = Registry()
    await asyncio.gather(*[
        self._collect_strategy(target, strategy)
        for strategy in self.STRATEGIES
    ])
    return self.registry
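# collect() assumes a surrounding collector with a STRATEGIES iterable and a
# _collect_strategy coroutine that registers metrics on self.registry. A
# minimal, self-contained sketch of that shape; the strategy names and the
# _collect_strategy body are invented for illustration:
import asyncio
from aioprometheus import Gauge, Registry

class ExampleCollector:
    STRATEGIES = ("cpu", "memory")  # illustrative strategy names

    async def _collect_strategy(self, target, strategy):
        # A real collector would query `target` here; this just registers
        # a placeholder gauge so the sketch runs end to end.
        g = Gauge(f"{strategy}_usage", f"{strategy} usage for the target.")
        self.registry.register(g)
        g.set({"target": target}, 0.0)

    async def collect(self, target):
        # Rebuild the registry, then run all strategies concurrently.
        self.registry = Registry()
        await asyncio.gather(
            *[self._collect_strategy(target, s) for s in self.STRATEGIES]
        )
        return self.registry

registry = asyncio.run(ExampleCollector().collect("localhost"))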
def test_registry_marshall_gauge(self):
    gauge_data = (({"g_sample": "1", "g_subsample": "b"}, 800),)

    gauge = Gauge("gauge_test", "A gauge.", const_labels={"type": "gauge"})
    for labels, value in gauge_data:
        gauge.set(labels, value)

    registry = Registry()
    registry.register(gauge)

    valid_result = (
        b'U\n\ngauge_test\x12\x08A gauge.\x18\x01";'
        b"\n\r\n\x08g_sample\x12\x011\n\x10\n\x0bg_subsample"
        b"\x12\x01b\n\r\n\x04type\x12\x05gauge\x12\t\t\x00"
        b"\x00\x00\x00\x00\x00\x89@"
    )

    f = binary.BinaryFormatter()
    self.assertEqual(valid_result, f.marshall(registry))
def test_registry_marshall_counter(self):
    counter_data = (({"c_sample": "1", "c_subsample": "b"}, 400),)

    counter = Counter("counter_test", "A counter.", const_labels={"type": "counter"})
    for labels, value in counter_data:
        counter.set(labels, value)

    registry = Registry()
    registry.register(counter)

    valid_result = (
        b'[\n\x0ccounter_test\x12\nA counter.\x18\x00"=\n\r'
        b"\n\x08c_sample\x12\x011\n\x10\n\x0bc_subsample\x12"
        b"\x01b\n\x0f\n\x04type\x12\x07counter\x1a\t\t\x00\x00"
        b"\x00\x00\x00\x00y@"
    )

    f = binary.BinaryFormatter()
    self.assertEqual(valid_result, f.marshall(registry))
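# The tail of the expected protobuf payload holds the counter value 400 as a
# little-endian IEEE-754 double; the b"y@" at the end is the 0x79 0x40 of
# that encoding. The float portion is easy to verify by hand:
import struct

# 400.0 packed as a little-endian double matches the last nine bytes of
# the \x1a\t\t... field in valid_result above.
assert struct.pack("<d", 400.0) == b"\x00\x00\x00\x00\x00\x00y@"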
async def test_grouping_key(self):
    # See https://github.com/prometheus/pushgateway/blob/master/README.md#url
    # for encoding rules.
    job_name = "my-job"
    p = pusher.Pusher(
        job_name,
        self.server.url,
        grouping_key={"instance": "127.0.0.1:1234"},
    )
    registry = Registry()
    c = Counter("total_requests", "Total requests.", {})
    registry.register(c)
    c.inc({})

    # Push to the pushgateway
    resp = await p.replace(registry)
    self.assertEqual(resp.status, 200)
    self.assertEqual(
        "/metrics/job/my-job/instance/127.0.0.1:1234",
        self.server.test_results["path"],
    )
def test_registry_marshall_counter(self):
    format_times = 10

    counter_data = (({'c_sample': '1', 'c_subsample': 'b'}, 400),)

    registry = Registry()
    counter = Counter("counter_test", "A counter.", {'type': "counter"})

    # Add data
    [counter.set(c[0], c[1]) for c in counter_data]

    registry.register(counter)

    valid_result = (b'[\n\x0ccounter_test\x12\nA counter.\x18\x00"=\n\r'
                    b'\n\x08c_sample\x12\x011\n\x10\n\x0bc_subsample\x12'
                    b'\x01b\n\x0f\n\x04type\x12\x07counter\x1a\t\t\x00\x00'
                    b'\x00\x00\x00\x00y@')

    f = BinaryFormatter()

    # Check multiple times to ensure multiple marshalling requests
    for i in range(format_times):
        self.assertEqual(valid_result, f.marshall(registry))
def test_registry_marshall_gauge(self):
    format_times = 10

    gauge_data = (({'g_sample': '1', 'g_subsample': 'b'}, 800),)

    registry = Registry()
    gauge = Gauge("gauge_test", "A gauge.", {'type': "gauge"})

    # Add data
    [gauge.set(g[0], g[1]) for g in gauge_data]

    registry.register(gauge)

    valid_result = (b'U\n\ngauge_test\x12\x08A gauge.\x18\x01";'
                    b'\n\r\n\x08g_sample\x12\x011\n\x10\n\x0bg_subsample'
                    b'\x12\x01b\n\r\n\x04type\x12\x05gauge\x12\t\t\x00'
                    b'\x00\x00\x00\x00\x00\x89@')

    f = BinaryFormatter()

    # Check multiple times to ensure multiple marshalling requests
    for i in range(format_times):
        self.assertEqual(valid_result, f.marshall(registry))
async def test_grouping_key_with_value_containing_slash(self):
    # See https://github.com/prometheus/pushgateway/blob/master/README.md#url
    # for encoding rules.
    job_name = "directory_cleaner"
    p = pusher.Pusher(
        job_name,
        self.server.url,
        grouping_key={"path": "/var/tmp"},
    )
    registry = Registry()
    c = Counter("exec_total", "Total executions", {})
    registry.register(c)
    c.inc({})

    # Push to the pushgateway
    resp = await p.replace(registry)
    self.assertEqual(resp.status, 200)

    # Generated base64 content include '=' as padding.
    self.assertEqual(
        "/metrics/job/directory_cleaner/path@base64/L3Zhci90bXA=",
        self.server.test_results["path"],
    )
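# The base64 segment in the expected path is just the URL-safe encoding of
# "/var/tmp", padding included, which can be checked directly:
import base64

# "/var/tmp" contains a slash, so the Pushgateway path encodes it with the
# URL-safe base64 alphabet; the trailing "=" is ordinary padding.
assert base64.urlsafe_b64encode(b"/var/tmp").decode() == "L3Zhci90bXA="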
class Metrics(commands.Cog):
    def __init__(self, bot: commands.Bot):
        self.bot = bot

        self.registry = Registry()
        self.service = Service(self.registry)

        self.events = Counter("events", "Discord API event counts.")
        self.registry.register(self.events)

        self.latency = Histogram("latency", "Discord API latency.")
        self.registry.register(self.latency)

        self.gc_started: typing.Optional[float] = None
        self.gc_latency = Histogram(
            "gc_latency", "CPython garbage collector execution times."
        )
        self.registry.register(self.gc_latency)
        self.gc_stats = Counter("gc_stats", "CPython garbage collector stats.")
        self.registry.register(self.gc_stats)

        self.process = psutil.Process(os.getpid())
        self.resources = Gauge("resources", "Process resource usage gauges.")
        self.registry.register(self.resources)

        self.hook_gc()
        self.update_gc_and_resource_stats.start()  # pylint: disable=no-member
        self.serve.start()  # pylint: disable=no-member
        self.update_latency.start()  # pylint: disable=no-member

    def gc_callback(self, phase: str, info: typing.Mapping[str, int]):
        if phase == "start":
            self.gc_started = time.time()
        else:
            self.gc_latency.observe(
                {"generation": info["generation"]}, time.time() - self.gc_started
            )

    def hook_gc(self):
        gc.callbacks.append(self.gc_callback)

    def unhook_gc(self):
        gc.callbacks.remove(self.gc_callback)

    @tasks.loop(minutes=1)
    async def update_gc_and_resource_stats(self):
        # gc stats
        for gen, stats in zip(itertools.count(), gc.get_stats()):
            for stat, value in stats.items():
                self.gc_stats.set({"generation": gen, "type": stat}, value)
        # process resource usage
        for key, value in self.process.cpu_times()._asdict().items():
            self.resources.set({"type": f"cpu_{key}"}, value)
        for key, value in self.process.memory_info()._asdict().items():
            self.resources.set({"type": f"mem_{key}"}, value)
        for key, value in self.process.io_counters()._asdict().items():
            self.resources.set({"type": f"io_{key}"}, value)
        self.resources.set({"type": "num_threads"}, self.process.num_threads())
        self.resources.set({"type": "num_fds"}, self.process.num_fds())

    @tasks.loop(count=1, reconnect=False)
    async def serve(self):
        await self.service.start(port=9100)
        logging.info("Serving Prometheus metrics on: %s", self.service.metrics_url)

    @tasks.loop(minutes=1)
    async def update_latency(self):
        self.latency.observe({"type": "seconds"}, self.bot.latency)

    @update_latency.before_loop
    async def before_update_latency(self):
        await self.bot.wait_until_ready()

    def cog_unload(self):
        self.unhook_gc()
        self.update_gc_and_resource_stats.cancel()  # pylint: disable=no-member
        self.serve.cancel()  # pylint: disable=no-member
        self.update_latency.cancel()  # pylint: disable=no-member

    @commands.Cog.listener()
    async def on_connect(self):
        self.events.inc({"type": "connect"})

    @commands.Cog.listener()
    async def on_shard_connect(self, shard_id):
        self.events.inc({"type": f"shard_connect_{shard_id}"})

    @commands.Cog.listener()
    async def on_disconnect(self):
        self.events.inc({"type": "disconnect"})

    @commands.Cog.listener()
    async def on_shard_disconnect(self, shard_id):
        self.events.inc({"type": f"shard_disconnect_{shard_id}"})

    @commands.Cog.listener()
    async def on_ready(self):
        self.events.inc({"type": "ready"})

    @commands.Cog.listener()
    async def on_shard_ready(self, shard_id):
        self.events.inc({"type": f"shard_ready_{shard_id}"})

    @commands.Cog.listener()
    async def on_resumed(self):
        self.events.inc({"type": "resumed"})

    @commands.Cog.listener()
    async def on_shard_resumed(self, shard_id):
        self.events.inc({"type": f"shard_resumed_{shard_id}"})

    @commands.Cog.listener()
    async def on_error(self, event, *_):
        self.events.inc({"type": f"error_{event}"})

    @commands.Cog.listener()
    async def on_socket_raw_receive(self, *_):
        self.events.inc({"type": "socket_raw_receive"})

    @commands.Cog.listener()
    async def on_socket_raw_send(self, *_):
        self.events.inc({"type": "socket_raw_send"})

    @commands.Cog.listener()
    async def on_typing(self, *_):
        self.events.inc({"type": "typing"})

    @commands.Cog.listener()
    async def on_message(self, *_):
        self.events.inc({"type": "message"})

    @commands.Cog.listener()
    async def on_message_delete(self, *_):
        self.events.inc({"type": "message_delete"})

    @commands.Cog.listener()
    async def on_bulk_message_delete(self, *_):
        self.events.inc({"type": "bulk_message_delete"})

    @commands.Cog.listener()
    async def on_raw_message_delete(self, *_):
        self.events.inc({"type": "raw_message_delete"})

    @commands.Cog.listener()
    async def on_raw_bulk_message_delete(self, *_):
        self.events.inc({"type": "raw_bulk_message_delete"})

    @commands.Cog.listener()
    async def on_message_edit(self, *_):
        self.events.inc({"type": "message_edit"})

    @commands.Cog.listener()
    async def on_raw_message_edit(self, *_):
        self.events.inc({"type": "raw_message_edit"})

    @commands.Cog.listener()
    async def on_reaction_add(self, *_):
        self.events.inc({"type": "reaction_add"})

    @commands.Cog.listener()
    async def on_raw_reaction_add(self, *_):
        self.events.inc({"type": "raw_reaction_add"})

    @commands.Cog.listener()
    async def on_reaction_remove(self, *_):
        self.events.inc({"type": "reaction_remove"})

    @commands.Cog.listener()
    async def on_raw_reaction_remove(self, *_):
        self.events.inc({"type": "raw_reaction_remove"})

    @commands.Cog.listener()
    async def on_reaction_clear(self, *_):
        self.events.inc({"type": "reaction_clear"})

    @commands.Cog.listener()
    async def on_raw_reaction_clear(self, *_):
        self.events.inc({"type": "raw_reaction_clear"})

    @commands.Cog.listener()
    async def on_reaction_clear_emoji(self, *_):
        self.events.inc({"type": "reaction_clear_emoji"})

    @commands.Cog.listener()
    async def on_raw_reaction_clear_emoji(self, *_):
        self.events.inc({"type": "raw_reaction_clear_emoji"})

    @commands.Cog.listener()
    async def on_private_channel_delete(self, *_):
        self.events.inc({"type": "private_channel_delete"})

    @commands.Cog.listener()
    async def on_private_channel_create(self, *_):
        self.events.inc({"type": "private_channel_create"})

    @commands.Cog.listener()
    async def on_private_channel_update(self, *_):
        self.events.inc({"type": "private_channel_update"})

    @commands.Cog.listener()
    async def on_private_channel_pins_update(self, *_):
        self.events.inc({"type": "private_channel_pins_update"})

    @commands.Cog.listener()
    async def on_guild_channel_delete(self, *_):
        self.events.inc({"type": "guild_channel_delete"})

    @commands.Cog.listener()
    async def on_guild_channel_create(self, *_):
        self.events.inc({"type": "guild_channel_create"})

    @commands.Cog.listener()
    async def on_guild_channel_update(self, *_):
        self.events.inc({"type": "guild_channel_update"})

    @commands.Cog.listener()
    async def on_guild_channel_pins_update(self, *_):
        self.events.inc({"type": "guild_channel_pins_update"})

    @commands.Cog.listener()
    async def on_guild_channel_integrations_update(self, *_):
        self.events.inc({"type": "guild_channel_integrations_update"})

    @commands.Cog.listener()
    async def on_webhooks_update(self, *_):
        self.events.inc({"type": "webhooks_update"})

    @commands.Cog.listener()
    async def on_member_join(self, *_):
        self.events.inc({"type": "member_join"})

    @commands.Cog.listener()
    async def on_member_remove(self, *_):
        self.events.inc({"type": "member_remove"})

    @commands.Cog.listener()
    async def on_member_update(self, *_):
        self.events.inc({"type": "member_update"})

    @commands.Cog.listener()
    async def on_user_update(self, *_):
        self.events.inc({"type": "user_update"})

    @commands.Cog.listener()
    async def on_guild_join(self, *_):
        self.events.inc({"type": "guild_join"})

    @commands.Cog.listener()
    async def on_guild_remove(self, *_):
        self.events.inc({"type": "guild_remove"})

    @commands.Cog.listener()
    async def on_guild_update(self, *_):
        self.events.inc({"type": "guild_update"})

    @commands.Cog.listener()
    async def on_guild_role_create(self, *_):
        self.events.inc({"type": "guild_role_create"})

    @commands.Cog.listener()
    async def on_guild_role_delete(self, *_):
        self.events.inc({"type": "guild_role_delete"})

    @commands.Cog.listener()
    async def on_guild_role_update(self, *_):
        self.events.inc({"type": "guild_role_update"})

    @commands.Cog.listener()
    async def on_guild_emojis_update(self, *_):
        self.events.inc({"type": "guild_emojis_update"})

    @commands.Cog.listener()
    async def on_guild_available(self, *_):
        self.events.inc({"type": "guild_available"})

    @commands.Cog.listener()
    async def on_guild_unavailable(self, *_):
        self.events.inc({"type": "guild_unavailable"})

    @commands.Cog.listener()
    async def on_voice_state_update(self, *_):
        self.events.inc({"type": "voice_state_update"})

    @commands.Cog.listener()
    async def on_member_ban(self, *_):
        self.events.inc({"type": "member_ban"})

    @commands.Cog.listener()
    async def on_member_unban(self, *_):
        self.events.inc({"type": "member_unban"})

    @commands.Cog.listener()
    async def on_invite_create(self, *_):
        self.events.inc({"type": "invite_create"})

    @commands.Cog.listener()
    async def on_invite_delete(self, *_):
        self.events.inc({"type": "invite_delete"})

    @commands.Cog.listener()
    async def on_group_join(self, *_):
        self.events.inc({"type": "group_join"})

    @commands.Cog.listener()
    async def on_group_remove(self, *_):
        self.events.inc({"type": "group_remove"})

    @commands.Cog.listener()
    async def on_relationship_add(self, *_):
        self.events.inc({"type": "relationship_add"})

    @commands.Cog.listener()
    async def on_relationship_remove(self, *_):
        self.events.inc({"type": "relationship_remove"})

    @commands.Cog.listener()
    async def on_relationship_update(self, *_):
        self.events.inc({"type": "relationship_update"})
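# A cog like this is typically attached in the bot's setup code; a minimal,
# hypothetical usage sketch assuming discord.py 1.x (where add_cog is
# synchronous) and default intents:
import discord
from discord.ext import commands

bot = commands.Bot(command_prefix="!", intents=discord.Intents.default())
bot.add_cog(Metrics(bot))  # the cog's tasks start the metrics server on port 9100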
def test_registry_marshall(self):
    format_times = 3

    counter_data = (
        ({"c_sample": "1"}, 100),
        ({"c_sample": "2"}, 200),
        ({"c_sample": "3"}, 300),
        ({"c_sample": "1", "c_subsample": "b"}, 400),
    )
    gauge_data = (
        ({"g_sample": "1"}, 500),
        ({"g_sample": "2"}, 600),
        ({"g_sample": "3"}, 700),
        ({"g_sample": "1", "g_subsample": "b"}, 800),
    )
    summary_data = (
        ({"s_sample": "1"}, range(1000, 2000, 4)),
        ({"s_sample": "2"}, range(2000, 3000, 20)),
        ({"s_sample": "3"}, range(3000, 4000, 13)),
        ({"s_sample": "1", "s_subsample": "b"}, range(4000, 5000, 47)),
    )

    registry = Registry()
    counter = Counter("counter_test", "A counter.", {"type": "counter"})
    gauge = Gauge("gauge_test", "A gauge.", {"type": "gauge"})
    summary = Summary("summary_test", "A summary.", {"type": "summary"})

    # Add data
    [counter.set(c[0], c[1]) for c in counter_data]
    [gauge.set(g[0], g[1]) for g in gauge_data]
    [summary.add(i[0], s) for i in summary_data for s in i[1]]

    registry.register(counter)
    registry.register(gauge)
    registry.register(summary)

    valid_regex = r"""# HELP counter_test A counter.
# TYPE counter_test counter
counter_test{c_sample="1",type="counter"} 100
counter_test{c_sample="2",type="counter"} 200
counter_test{c_sample="3",type="counter"} 300
counter_test{c_sample="1",c_subsample="b",type="counter"} 400
# HELP gauge_test A gauge.
# TYPE gauge_test gauge
gauge_test{g_sample="1",type="gauge"} 500
gauge_test{g_sample="2",type="gauge"} 600
gauge_test{g_sample="3",type="gauge"} 700
gauge_test{g_sample="1",g_subsample="b",type="gauge"} 800
# HELP summary_test A summary.
# TYPE summary_test summary
summary_test{quantile="0.5",s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.5",s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="2",type="summary"} 2\d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.5",s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.5",s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
"""

    f = TextFormatter()
    self.maxDiff = None
    # Check multiple times to ensure multiple calls to marshalling
    # produce the same results
    for i in range(format_times):
        self.assertTrue(re.match(valid_regex, f.marshall(registry).decode()))
class CustomJsonFormatter(jsonlogger.JsonFormatter):
    def add_fields(self, log_record, record, message_dict):
        super().add_fields(log_record, record, message_dict)
        # `now` was assigned in lines truncated from this snippet; a current
        # UTC timestamp is assumed here.
        now = datetime.now(timezone.utc).isoformat()
        log_record["timestamp"] = now
        if log_record.get("level"):
            log_record["level"] = log_record["level"].upper()
        else:
            log_record["level"] = record.levelname


logger = logging.getLogger()
logger.setLevel(getenv("LOG_LEVEL") or 10)
logHandler = logging.StreamHandler()
formatter = CustomJsonFormatter()
logHandler.setFormatter(formatter)
logger.addHandler(logHandler)

prometheus_service = Service()
prometheus_service.registry = Registry()
prometheus_labels = {
    "host": gethostname(),
}
ping_counter = Counter("health_check_counter", "total ping requests.")
latency_metric = Histogram(
    "request_latency_seconds",
    "request latency in seconds.",
    const_labels=prometheus_labels,
    buckets=[0.1, 0.5, 1.0, 5.0],
)
ram_metric = Gauge(
    "memory_usage_bytes", "memory usage in bytes.",
    const_labels=prometheus_labels,
)
cpu_metric = Gauge(
    "cpu_usage_percent", "cpu usage percent.",
    const_labels=prometheus_labels,  # assumed: the snippet was truncated here; parallels ram_metric
)
""" Sometimes you want to expose Prometheus metrics from within an existing web service and don't want to start a separate Prometheus metrics server. This example uses the aioprometheus package to add Prometheus instrumentation to a Vibora application. In this example a registry and a counter metric is instantiated. A '/metrics' route is added to the application and the render function from aioprometheus is called to format the metrics into the appropriate format. """ from aioprometheus import render, Counter, Registry from vibora import Vibora, Request, Response app = Vibora(__name__) app.registry = Registry() app.events_counter = Counter("events", "Number of events.") app.registry.register(app.events_counter) @app.route("/") async def hello(request: Request): app.events_counter.inc({"path": "/"}) return Response(b"hello") @app.route("/metrics") async def handle_metrics(request: Request): """ Negotiate a response format by inspecting the ACCEPTS headers and selecting the most efficient format. Render metrics in the registry into the chosen
async def setUp(self):
    self.registry = Registry()
    self.server = Service(registry=self.registry)
    await self.server.start(addr="127.0.0.1")
    self.metrics_url = self.server.metrics_url
    self.root_url = self.server.root_url
class TestTextExporter(AsyncioTestCase):
    async def setUp(self):
        self.registry = Registry()
        self.server = Service(registry=self.registry, loop=self.loop)
        await self.server.start(addr=TEST_HOST, port=TEST_PORT)
        self.metrics_url = self.server.url

    async def tearDown(self):
        await self.server.stop()

    async def test_invalid_registry(self):
        ''' check only valid registry can be provided '''
        for invalid_registry in ['nope', dict(), list()]:
            with self.assertRaises(Exception) as cm:
                Service(registry=invalid_registry, loop=self.loop)
            self.assertIn('registry must be a Registry, got:',
                          str(cm.exception))

        Service(registry=Registry(), loop=self.loop)

    async def test_counter(self):
        # Add some metrics
        data = (
            ({'data': 1}, 100),
            ({'data': "2"}, 200),
            ({'data': 3}, 300),
            ({'data': 1}, 400),
        )
        c = Counter("test_counter", "Test Counter.", {'test': "test_counter"})
        self.registry.register(c)
        for i in data:
            c.set(i[0], i[1])

        expected_data = """# HELP test_counter Test Counter.
# TYPE test_counter counter
test_counter{data="1",test="test_counter"} 400
test_counter{data="2",test="test_counter"} 200
test_counter{data="3",test="test_counter"} 300
"""

        with aiohttp.ClientSession(loop=self.loop) as session:
            headers = {ACCEPT: 'text/plain; version=0.0.4'}
            async with session.get(self.metrics_url, headers=headers) as resp:
                assert resp.status == 200
                content = await resp.read()
                self.assertEqual("text/plain; version=0.0.4; charset=utf-8",
                                 resp.headers.get(CONTENT_TYPE))
                self.assertEqual(200, resp.status)
                self.assertEqual(expected_data, content.decode())

    async def test_gauge(self):
        # Add some metrics
        data = (
            ({'data': 1}, 100),
            ({'data': "2"}, 200),
            ({'data': 3}, 300),
            ({'data': 1}, 400),
        )
        g = Gauge("test_gauge", "Test Gauge.", {'test': "test_gauge"})
        self.registry.register(g)
        for i in data:
            g.set(i[0], i[1])

        expected_data = """# HELP test_gauge Test Gauge.
# TYPE test_gauge gauge
test_gauge{data="1",test="test_gauge"} 400
test_gauge{data="2",test="test_gauge"} 200
test_gauge{data="3",test="test_gauge"} 300
"""

        with aiohttp.ClientSession(loop=self.loop) as session:
            headers = {ACCEPT: 'text/plain; version=0.0.4'}
            async with session.get(self.metrics_url, headers=headers) as resp:
                assert resp.status == 200
                content = await resp.read()
                self.assertEqual("text/plain; version=0.0.4; charset=utf-8",
                                 resp.headers.get(CONTENT_TYPE))
                self.assertEqual(200, resp.status)
                self.assertEqual(expected_data, content.decode())

    async def test_summary(self):
        # Add some metrics
        data = [3, 5.2, 13, 4]
        label = {'data': 1}
        s = Summary("test_summary", "Test Summary.", {'test': "test_summary"})
        self.registry.register(s)
        for i in data:
            s.add(label, i)

        expected_data = """# HELP test_summary Test Summary.
# TYPE test_summary summary
test_summary_count{data="1",test="test_summary"} 4
test_summary_sum{data="1",test="test_summary"} 25.2
test_summary{data="1",quantile="0.5",test="test_summary"} 4.0
test_summary{data="1",quantile="0.9",test="test_summary"} 5.2
test_summary{data="1",quantile="0.99",test="test_summary"} 5.2
"""

        with aiohttp.ClientSession(loop=self.loop) as session:
            headers = {ACCEPT: 'text/plain; version=0.0.4'}
            async with session.get(self.metrics_url, headers=headers) as resp:
                assert resp.status == 200
                content = await resp.read()
                self.assertEqual("text/plain; version=0.0.4; charset=utf-8",
                                 resp.headers.get(CONTENT_TYPE))
                self.assertEqual(200, resp.status)
                self.assertEqual(expected_data, content.decode())

    async def test_histogram(self):
        # Add some metrics
        data = [3, 5.2, 13, 4]
        label = {'data': 1}
        h = Histogram("histogram_test", "Test Histogram.",
                      {'type': "test_histogram"}, buckets=[5.0, 10.0, 15.0])
        self.registry.register(h)
        for i in data:
            h.add(label, i)

        expected_data = """# HELP histogram_test Test Histogram.
# TYPE histogram_test histogram
histogram_test_bucket{data="1",le="+Inf",type="test_histogram"} 0
histogram_test_bucket{data="1",le="10.0",type="test_histogram"} 1
histogram_test_bucket{data="1",le="15.0",type="test_histogram"} 1
histogram_test_bucket{data="1",le="5.0",type="test_histogram"} 2
histogram_test_count{data="1",type="test_histogram"} 4
histogram_test_sum{data="1",type="test_histogram"} 25.2
"""

        with aiohttp.ClientSession(loop=self.loop) as session:
            headers = {ACCEPT: 'text/plain; version=0.0.4'}
            async with session.get(self.metrics_url, headers=headers) as resp:
                assert resp.status == 200
                content = await resp.read()
                self.assertEqual("text/plain; version=0.0.4; charset=utf-8",
                                 resp.headers.get(CONTENT_TYPE))
                self.assertEqual(200, resp.status)
                self.assertEqual(expected_data, content.decode())

    async def test_all(self):
        counter_data = (
            ({'c_sample': '1'}, 100),
            ({'c_sample': '2'}, 200),
            ({'c_sample': '3'}, 300),
            ({'c_sample': '1', 'c_subsample': 'b'}, 400),
        )
        gauge_data = (
            ({'g_sample': '1'}, 500),
            ({'g_sample': '2'}, 600),
            ({'g_sample': '3'}, 700),
            ({'g_sample': '1', 'g_subsample': 'b'}, 800),
        )
        summary_data = (
            ({'s_sample': '1'}, range(1000, 2000, 4)),
            ({'s_sample': '2'}, range(2000, 3000, 20)),
            ({'s_sample': '3'}, range(3000, 4000, 13)),
            ({'s_sample': '1', 's_subsample': 'b'}, range(4000, 5000, 47)),
        )
        histogram_data = (
            ({'h_sample': '1'}, range(1, 20, 2)),
            ({'h_sample': '2'}, range(1, 20, 2)),
            ({'h_sample': '3'}, range(1, 20, 2)),
            ({'h_sample': '1', 'h_subsample': 'b'}, range(1, 20, 2)),
        )

        counter = Counter("counter_test", "A counter.", {'type': "counter"})
        gauge = Gauge("gauge_test", "A gauge.", {'type': "gauge"})
        summary = Summary("summary_test", "A summary.", {'type': "summary"})
        histogram = Histogram("histogram_test", "A histogram.",
                              {'type': "histogram"}, buckets=[5.0, 10.0, 15.0])

        self.registry.register(counter)
        self.registry.register(gauge)
        self.registry.register(summary)
        self.registry.register(histogram)

        # Add data
        [counter.set(c[0], c[1]) for c in counter_data]
        [gauge.set(g[0], g[1]) for g in gauge_data]
        [summary.add(i[0], s) for i in summary_data for s in i[1]]
        [histogram.add(i[0], h) for i in histogram_data for h in i[1]]

        expected_data = """# HELP counter_test A counter.
# TYPE counter_test counter
counter_test{c_sample="1",c_subsample="b",type="counter"} 400
counter_test{c_sample="1",type="counter"} 100
counter_test{c_sample="2",type="counter"} 200
counter_test{c_sample="3",type="counter"} 300
# HELP gauge_test A gauge.
# TYPE gauge_test gauge
gauge_test{g_sample="1",g_subsample="b",type="gauge"} 800
gauge_test{g_sample="1",type="gauge"} 500
gauge_test{g_sample="2",type="gauge"} 600
gauge_test{g_sample="3",type="gauge"} 700
# HELP histogram_test A histogram.
# TYPE histogram_test histogram
histogram_test_bucket{h_sample="1",h_subsample="b",le="+Inf",type="histogram"} 2
histogram_test_bucket{h_sample="1",h_subsample="b",le="10.0",type="histogram"} 2
histogram_test_bucket{h_sample="1",h_subsample="b",le="15.0",type="histogram"} 3
histogram_test_bucket{h_sample="1",h_subsample="b",le="5.0",type="histogram"} 3
histogram_test_bucket{h_sample="1",le="+Inf",type="histogram"} 2
histogram_test_bucket{h_sample="1",le="10.0",type="histogram"} 2
histogram_test_bucket{h_sample="1",le="15.0",type="histogram"} 3
histogram_test_bucket{h_sample="1",le="5.0",type="histogram"} 3
histogram_test_bucket{h_sample="2",le="+Inf",type="histogram"} 2
histogram_test_bucket{h_sample="2",le="10.0",type="histogram"} 2
histogram_test_bucket{h_sample="2",le="15.0",type="histogram"} 3
histogram_test_bucket{h_sample="2",le="5.0",type="histogram"} 3
histogram_test_bucket{h_sample="3",le="+Inf",type="histogram"} 2
histogram_test_bucket{h_sample="3",le="10.0",type="histogram"} 2
histogram_test_bucket{h_sample="3",le="15.0",type="histogram"} 3
histogram_test_bucket{h_sample="3",le="5.0",type="histogram"} 3
histogram_test_count{h_sample="1",h_subsample="b",type="histogram"} 10
histogram_test_count{h_sample="1",type="histogram"} 10
histogram_test_count{h_sample="2",type="histogram"} 10
histogram_test_count{h_sample="3",type="histogram"} 10
histogram_test_sum{h_sample="1",h_subsample="b",type="histogram"} 100.0
histogram_test_sum{h_sample="1",type="histogram"} 100.0
histogram_test_sum{h_sample="2",type="histogram"} 100.0
histogram_test_sum{h_sample="3",type="histogram"} 100.0
# HELP summary_test A summary.
# TYPE summary_test summary
summary_test_count{s_sample="1",s_subsample="b",type="summary"} 22
summary_test_count{s_sample="1",type="summary"} 250
summary_test_count{s_sample="2",type="summary"} 50
summary_test_count{s_sample="3",type="summary"} 77
summary_test_sum{s_sample="1",s_subsample="b",type="summary"} 98857.0
summary_test_sum{s_sample="1",type="summary"} 374500.0
summary_test_sum{s_sample="2",type="summary"} 124500.0
summary_test_sum{s_sample="3",type="summary"} 269038.0
summary_test{quantile="0.5",s_sample="1",s_subsample="b",type="summary"} 4235.0
summary_test{quantile="0.5",s_sample="1",type="summary"} 1272.0
summary_test{quantile="0.5",s_sample="2",type="summary"} 2260.0
summary_test{quantile="0.5",s_sample="3",type="summary"} 3260.0
summary_test{quantile="0.9",s_sample="1",s_subsample="b",type="summary"} 4470.0
summary_test{quantile="0.9",s_sample="1",type="summary"} 1452.0
summary_test{quantile="0.9",s_sample="2",type="summary"} 2440.0
summary_test{quantile="0.9",s_sample="3",type="summary"} 3442.0
summary_test{quantile="0.99",s_sample="1",s_subsample="b",type="summary"} 4517.0
summary_test{quantile="0.99",s_sample="1",type="summary"} 1496.0
summary_test{quantile="0.99",s_sample="2",type="summary"} 2500.0
summary_test{quantile="0.99",s_sample="3",type="summary"} 3494.0
"""

        with aiohttp.ClientSession(loop=self.loop) as session:
            headers = {ACCEPT: 'text/plain; version=0.0.4'}
            async with session.get(self.metrics_url, headers=headers) as resp:
                assert resp.status == 200
                content = await resp.read()
                self.assertEqual("text/plain; version=0.0.4; charset=utf-8",
                                 resp.headers.get(CONTENT_TYPE))
                self.assertEqual(200, resp.status)
                self.assertEqual(expected_data, content.decode())
def test_registry_marshall(self):
    format_times = 3

    counter_data = (
        ({"c_sample": "1"}, 100),
        ({"c_sample": "2"}, 200),
        ({"c_sample": "3"}, 300),
        ({"c_sample": "1", "c_subsample": "b"}, 400),
    )
    gauge_data = (
        ({"g_sample": "1"}, 500),
        ({"g_sample": "2"}, 600),
        ({"g_sample": "3"}, 700),
        ({"g_sample": "1", "g_subsample": "b"}, 800),
    )
    summary_data = (
        ({"s_sample": "1"}, range(1000, 2000, 4)),
        ({"s_sample": "2"}, range(2000, 3000, 20)),
        ({"s_sample": "3"}, range(3000, 4000, 13)),
        ({"s_sample": "1", "s_subsample": "b"}, range(4000, 5000, 47)),
    )

    registry = Registry()
    counter = Counter("counter_test", "A counter.", {"type": "counter"})
    gauge = Gauge("gauge_test", "A gauge.", {"type": "gauge"})
    summary = Summary("summary_test", "A summary.", {"type": "summary"})

    # Add data
    [counter.set(c[0], c[1]) for c in counter_data]
    [gauge.set(g[0], g[1]) for g in gauge_data]
    [summary.add(i[0], s) for i in summary_data for s in i[1]]

    registry.register(counter)
    registry.register(gauge)
    registry.register(summary)

    valid_regex = r"""# HELP counter_test A counter.
# TYPE counter_test counter
counter_test{c_sample="1",type="counter"} 100
counter_test{c_sample="2",type="counter"} 200
counter_test{c_sample="3",type="counter"} 300
counter_test{c_sample="1",c_subsample="b",type="counter"} 400
# HELP gauge_test A gauge.
# TYPE gauge_test gauge
gauge_test{g_sample="1",type="gauge"} 500
gauge_test{g_sample="2",type="gauge"} 600
gauge_test{g_sample="3",type="gauge"} 700
gauge_test{g_sample="1",g_subsample="b",type="gauge"} 800
# HELP summary_test A summary.
# TYPE summary_test summary
summary_test{quantile="0.5",s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="1",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.5",s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="2",type="summary"} 2\d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="2",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.5",s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="3",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.5",s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.9",s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test{quantile="0.99",s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test_count{s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
summary_test_sum{s_sample="1",s_subsample="b",type="summary"} \d*(?:.\d*)?
"""

    f = text.TextFormatter()
    self.maxDiff = None
    # Check multiple times to ensure multiple calls to marshalling
    # produce the same results
    for i in range(format_times):
        self.assertTrue(re.match(valid_regex, f.marshall(registry).decode()))
def create_app(self):
    app = Quart(__name__)

    # Turn off Quart standard logging
    app.logger.disabled = True
    log = logging.getLogger("quart.serving")
    log.disabled = True

    """
    Prometheus Metrics Exporter endpoint.
    https://github.com/claws/aioprometheus
    https://github.com/claws/aioprometheus/blob/master/examples/frameworks/quart-example.py

    The metrics are intended to emulate Stepfunction CloudWatch metrics.
    https://docs.aws.amazon.com/step-functions/latest/dg/procedure-cw-metrics.html
    """
    registry = Registry()
    for metric in self.system_metrics.values():
        registry.register(metric)
    for metric in self.execution_metrics.values():
        registry.register(metric)
    for metric in self.task_metrics.values():
        registry.register(metric)

    @app.route("/metrics")
    async def handle_metrics():
        if self.system_metrics:
            self.system_metrics.collect()
        content, http_headers = render(registry,
                                       request.headers.getlist("accept"))
        return content, http_headers

    """
    Flask/Quart "catch-all" URL see
    https://gist.github.com/fitiavana07/bf4eb97b20bbe3853681e153073c0e5e

    As Quart is an asynchronous framework based on asyncio, it is necessary
    to explicitly add async and await keywords. The most notable place in
    which to do this is route functions. see
    https://pgjones.gitlab.io/quart/how_to_guides/flask_migration.html
    """
    @app.route("/", defaults={"path": ""}, methods=["POST"])
    @app.route("/<path:path>", methods=["POST"])
    async def handle_post(path):
        """
        Perform initial validation of the HTTP request. The AWS Step
        Functions API is a slightly "weird" REST API as it mostly seems to
        rely on POST and rather than using HTTP resources it uses the
        x-amz-target header to specify the action to be performed.
        """
        if not request.content_type == "application/x-amz-json-1.0":
            return "Unexpected Content-Type {}".format(
                request.content_type), 400

        target = request.headers.get("x-amz-target")
        if not target:
            return "Missing header x-amz-target", 400
        if not target.startswith("AWSStepFunctions."):
            return "Malformed header x-amz-target", 400
        action = target.split(".")[1]
        # print(action)

        """
        request.data is one of the common calls that requires awaiting
        https://pgjones.gitlab.io/quart/how_to_guides/flask_migration.html
        """
        data = await request.data
        try:
            params = json.loads(data.decode("utf8"))
        except ValueError as e:
            params = ""
            self.logger.error(
                "Message body {} does not contain valid JSON".format(data))

        # ------------------------------------------------------------------
        """
        Define nested functions as handlers for each supported API action.
        Using nested functions so we can use the context from handle_post.
        That the methods are prefixed with "aws_api_" is a mitigation
        against accidentally or deliberately placing an invalid action in
        the API.
        """
        def aws_api_CreateStateMachine():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_CreateStateMachine.html
            """
            name = params.get("name")
            if not valid_name(name):
                self.logger.warning(
                    "RestAPI CreateStateMachine: {} is an invalid "
                    "name".format(name))
                return aws_error("InvalidName"), 400

            role_arn = params.get("roleArn")
            if not valid_role_arn(role_arn):
                self.logger.warning(
                    "RestAPI CreateStateMachine: {} is an invalid "
                    "Role ARN".format(role_arn))
                return aws_error("InvalidArn"), 400

            # Form stateMachineArn from roleArn and name
            arn = parse_arn(role_arn)
            state_machine_arn = create_arn(
                service="states",
                region=self.region,
                account=arn["account"],
                resource_type="stateMachine",
                resource=name,
            )

            # Get State Machine type (STANDARD or EXPRESS) if supplied
            type = params.get("type", "STANDARD")
            if type not in {"STANDARD", "EXPRESS"}:
                self.logger.error(
                    "RestAPI CreateStateMachine: State Machine type {} "
                    "is not supported".format(type))
                return aws_error("StateMachineTypeNotSupported"), 400

            """
            Look up stateMachineArn. Use get() not get_cached_view() here
            as calls to CreateStateMachine might reasonably *expect* no
            match.
            """
            match = self.asl_store.get(state_machine_arn)
            if match:
                # Info seems more appropriate than error here as creation
                # is an idempotent action.
                self.logger.info(
                    "RestAPI CreateStateMachine: State Machine {} already "
                    "exists".format(state_machine_arn))
                return aws_error("StateMachineAlreadyExists"), 400

            definition = params.get("definition", "")
            """
            First check if the definition length has exceeded the 1048576
            character limit described in the CreateStateMachine API page.
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_CreateStateMachine.html
            """
            if len(definition) == 0 or len(definition) > MAX_STATE_MACHINE_LENGTH:
                self.logger.error(
                    "RestAPI CreateStateMachine: Invalid definition size "
                    "for State Machine '{}'.".format(name))
                return aws_error("InvalidDefinition"), 400

            try:
                definition = json.loads(definition)
            except ValueError as e:
                definition = None
                self.logger.error(
                    "RestAPI CreateStateMachine: State Machine definition "
                    "{} does not contain valid JSON".format(
                        params.get("definition")))
                return aws_error("InvalidDefinition"), 400

            if not (name and role_arn and definition):
                self.logger.warning(
                    "RestAPI CreateStateMachine: name, roleArn and "
                    "definition must be specified")
                return aws_error("MissingRequiredParameter"), 400

            # TODO ASL Validator??

            creation_date = time.time()
            self.asl_store[state_machine_arn] = {
                "creationDate": creation_date,
                "definition": definition,
                "name": name,
                "roleArn": role_arn,
                "stateMachineArn": state_machine_arn,
                "updateDate": creation_date,
                "status": "ACTIVE",
                "type": type,
            }

            resp = {
                "creationDate": creation_date,
                "stateMachineArn": state_machine_arn,
            }
            return jsonify(resp), 200

        def aws_api_ListStateMachines():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_ListStateMachines.html
            """
            # TODO handle nextToken stuff
            next_token = ""

            """
            Populate response using list and dict comprehensions
            https://www.pythonforbeginners.com/basics/list-comprehensions-in-python
            https://stackoverflow.com/questions/5352546/extract-subset-of-key-value-pairs-from-python-dictionary-object
            """
            state_machines = [{
                k1: v[k1]
                for k1 in ("creationDate", "name", "stateMachineArn", "type")
            } for k, v in self.asl_store.items()]

            resp = {"stateMachines": state_machines}
            if next_token:
                resp["nextToken"] = next_token
            return jsonify(resp), 200

        def aws_api_DescribeStateMachine():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_DescribeStateMachine.html
            """
            state_machine_arn = params.get("stateMachineArn")
            if not state_machine_arn:
                self.logger.warning(
                    "RestAPI DescribeStateMachine: stateMachineArn must "
                    "be specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_state_machine_arn(state_machine_arn):
                self.logger.warning(
                    "RestAPI DescribeStateMachine: {} is an invalid State "
                    "Machine ARN".format(state_machine_arn))
                return aws_error("InvalidArn"), 400

            """
            Look up stateMachineArn. Using get_cached_view() here means
            that the state_machine is JSON serialisable, as the cached view
            is a simple dict rather than say a RedisDict.
            """
            state_machine = self.asl_store.get_cached_view(state_machine_arn)
            if not state_machine:
                self.logger.info(
                    "RestAPI DescribeStateMachine: State Machine {} does "
                    "not exist".format(state_machine_arn))
                return aws_error("StateMachineDoesNotExist"), 400

            """
            In the API the "definition" field is actually a string not a
            JSON object, hence the json.dumps() here. We do the conversion
            here rather than storing it as a string because the State
            Engine uses the deserialised definition as a key part of its
            core state transition behaviour.
            """
            resp = state_machine.copy()
            resp["definition"] = json.dumps(state_machine["definition"])
            return jsonify(resp), 200

        def aws_api_DescribeStateMachineForExecution():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_DescribeStateMachineForExecution.html
            """
            execution_arn = params.get("executionArn")
            if not execution_arn:
                self.logger.warning(
                    "RestAPI DescribeStateMachineForExecution: "
                    "executionArn must be specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_execution_arn(execution_arn):
                self.logger.warning(
                    "RestAPI DescribeStateMachineForExecution: {} is an "
                    "invalid Execution ARN".format(execution_arn))
                return aws_error("InvalidArn"), 400

            # Look up executionArn
            execution = self.executions.get(execution_arn)
            if not execution:
                self.logger.info(
                    "RestAPI DescribeStateMachineForExecution: Execution "
                    "{} does not exist".format(execution_arn))
                return aws_error("ExecutionDoesNotExist"), 400

            state_machine_arn = execution.get("stateMachineArn")
            if not valid_state_machine_arn(state_machine_arn):
                self.logger.warning(
                    "RestAPI DescribeStateMachineForExecution: {} is an "
                    "invalid State Machine ARN".format(state_machine_arn))
                return aws_error("InvalidArn"), 400

            # Look up stateMachineArn
            state_machine = self.asl_store.get_cached_view(state_machine_arn)
            if not state_machine:
                self.logger.info(
                    "RestAPI DescribeStateMachineForExecution: State "
                    "Machine {} does not exist".format(state_machine_arn))
                return aws_error("StateMachineDoesNotExist"), 400

            """
            As with DescribeStateMachine the "definition" field is actually
            a string not a JSON object, hence the json.dumps() here.
            """
            resp = {
                k: state_machine[k]
                for k in ("definition", "name", "roleArn",
                          "stateMachineArn", "updateDate")
            }
            resp["definition"] = json.dumps(state_machine["definition"])
            return jsonify(resp), 200

        def aws_api_UpdateStateMachine():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_UpdateStateMachine.html
            """
            state_machine_arn = params.get("stateMachineArn")
            if not state_machine_arn:
                self.logger.warning(
                    "RestAPI UpdateStateMachine: stateMachineArn must be "
                    "specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_state_machine_arn(state_machine_arn):
                self.logger.warning(
                    "RestAPI UpdateStateMachine: {} is an invalid State "
                    "Machine ARN".format(state_machine_arn))
                return aws_error("InvalidArn"), 400

            """
            Look up stateMachineArn. Use get() rather than
            get_cached_view() as we are going to be updating the retrieved
            State Machine.
            """
            state_machine = self.asl_store.get(state_machine_arn)
            if not state_machine:
                self.logger.info(
                    "RestAPI UpdateStateMachine: State Machine {} does "
                    "not exist".format(state_machine_arn))
                return aws_error("StateMachineDoesNotExist"), 400

            role_arn = params.get("roleArn")
            if role_arn:
                if not valid_role_arn(role_arn):
                    self.logger.warning(
                        "RestAPI UpdateStateMachine: {} is an invalid "
                        "Role ARN".format(role_arn))
                    return aws_error("InvalidArn"), 400
                state_machine["roleArn"] = role_arn

            definition = params.get("definition", "")
            if definition:
                """
                First check if the definition length has exceeded the
                1048576 character limit described in the UpdateStateMachine
                API page.
                https://docs.aws.amazon.com/step-functions/latest/apireference/API_UpdateStateMachine.html
                """
                if len(definition) == 0 or len(definition) > MAX_STATE_MACHINE_LENGTH:
                    # Original message said "CreateStateMachine" and used an
                    # undefined `name` variable; corrected to reference the
                    # ARN available in this scope.
                    self.logger.error(
                        "RestAPI UpdateStateMachine: Invalid definition "
                        "size for State Machine '{}'.".format(
                            state_machine_arn))
                    return aws_error("InvalidDefinition"), 400

                try:
                    definition = json.loads(definition)
                except ValueError as e:
                    definition = None
                    self.logger.error(
                        "RestAPI UpdateStateMachine: State Machine "
                        "definition {} does not contain valid JSON".format(
                            params.get("definition")))
                    return aws_error("InvalidDefinition"), 400

                # TODO ASL Validator??

                state_machine["definition"] = definition

            if not role_arn and not definition:
                self.logger.warning(
                    "RestAPI UpdateStateMachine: either roleArn or "
                    "definition must be specified")
                return aws_error("MissingRequiredParameter"), 400

            update_date = time.time()
            state_machine["updateDate"] = update_date
            self.asl_store[state_machine_arn] = state_machine

            resp = {"updateDate": update_date}
            return jsonify(resp), 200

        def aws_api_DeleteStateMachine():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_DeleteStateMachine.html

            TODO This should really mark the state machine for deletion
            and "The state machine itself is deleted after all executions
            are completed or deleted."
            """
            state_machine_arn = params.get("stateMachineArn")
            if not state_machine_arn:
                self.logger.warning(
                    "RestAPI DeleteStateMachine: stateMachineArn must be "
                    "specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_state_machine_arn(state_machine_arn):
                self.logger.warning(
                    "RestAPI DeleteStateMachine: {} is an invalid State "
                    "Machine ARN".format(state_machine_arn))
                return aws_error("InvalidArn"), 400

            # Look up stateMachineArn
            state_machine = self.asl_store.get_cached_view(state_machine_arn)
            if not state_machine:
                self.logger.info(
                    "RestAPI DeleteStateMachine: State Machine {} does "
                    "not exist".format(state_machine_arn))
                return aws_error("StateMachineDoesNotExist"), 400

            del self.asl_store[state_machine_arn]
            return "", 200

        def aws_api_StartExecution():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_StartExecution.html
            """
            # print(params)
            state_machine_arn = params.get("stateMachineArn")
            if not state_machine_arn:
                self.logger.warning(
                    "RestAPI StartExecution: stateMachineArn must be "
                    "specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_state_machine_arn(state_machine_arn):
                self.logger.warning(
                    "RestAPI StartExecution: {} is an invalid State "
                    "Machine ARN".format(state_machine_arn))
                return aws_error("InvalidArn"), 400

            """
            If name isn't provided create one from a UUID.
            TODO names should be unique within a 90 day period, at the
            moment there is no code to check for uniqueness of provided
            names so client code that doesn't honour this may currently
            succeed in this implementation but fail if calling real AWS
            StepFunctions.
            """
            name = params.get("name", str(uuid.uuid4()))
            if not valid_name(name):
                self.logger.warning(
                    "RestAPI StartExecution: {} is an invalid name".format(
                        name))
                return aws_error("InvalidName"), 400

            input = params.get("input", "{}")
            """
            First check if the input length has exceeded the 262144
            character quota described in Stepfunction Quotas page.
            https://docs.aws.amazon.com/step-functions/latest/dg/limits.html
            """
            if len(input) > MAX_DATA_LENGTH:
                self.logger.error(
                    "RestAPI StartExecution: input size for execution "
                    "'{}' exceeds the maximum number of characters "
                    "service limit.".format(name))
                return aws_error("InvalidExecutionInput"), 400

            try:
                input = json.loads(input)
            except TypeError as e:
                self.logger.error(
                    "RestAPI StartExecution: Invalid input, {}".format(e))
                return aws_error("InvalidExecutionInput"), 400
            except ValueError as e:
                self.logger.error(
                    "RestAPI StartExecution: input {} does not contain "
                    "valid JSON".format(input))
                return aws_error("InvalidExecutionInput"), 400

            # Look up stateMachineArn
            state_machine = self.asl_store.get_cached_view(state_machine_arn)
            if not state_machine:
                self.logger.info(
                    "RestAPI StartExecution: State Machine {} does not "
                    "exist".format(state_machine_arn))
                return aws_error("StateMachineDoesNotExist"), 400

            # Form executionArn from stateMachineArn and name
            arn = parse_arn(state_machine_arn)
            execution_arn = create_arn(
                service="states",
                region=arn.get("region", self.region),
                account=arn["account"],
                resource_type="execution",
                resource=arn["resource"] + ":" + name,
            )

            with opentracing.tracer.start_active_span(
                    operation_name="StartExecution:ExecutionLaunching",
                    child_of=span_context("http_headers", request.headers,
                                          self.logger),
                    tags={
                        "component": "rest_api",
                        "execution_arn": execution_arn,
                    }) as scope:
                """
                The application context is described in the AWS
                documentation:
                https://docs.aws.amazon.com/step-functions/latest/dg/input-output-contextobject.html
                """
                # https://stackoverflow.com/questions/8556398/generate-rfc-3339-timestamp-in-python
                start_time = datetime.now(
                    timezone.utc).astimezone().isoformat()
                context = {
                    "Tracer": inject_span("text_map", scope.span,
                                          self.logger),
                    "Execution": {
                        "Id": execution_arn,
                        "Input": input,
                        "Name": name,
                        "RoleArn": state_machine.get("roleArn"),
                        "StartTime": start_time,
                    },
                    "State": {"EnteredTime": start_time, "Name": ""},  # Start state
                    "StateMachine": {
                        "Id": state_machine_arn,
                        "Name": state_machine.get("name"),
                    },
                }
                event = {"data": input, "context": context}

                """
                threadsafe=True is important here as the RestAPI runs in a
                different thread to the main event_dispatcher loop.
                """
                try:
                    self.event_dispatcher.publish(
                        event, threadsafe=True, start_execution=True)
                except Exception:
                    message = ("RestAPI StartExecution: Internal messaging "
                               "error, start message could not be "
                               "published.")
                    self.logger.error(message)
                    return aws_error("InternalError", message), 500

            resp = {"executionArn": execution_arn, "startDate": time.time()}
            return jsonify(resp), 200

        def aws_api_ListExecutions():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_ListExecutions.html
            """
            state_machine_arn = params.get("stateMachineArn")
            if not state_machine_arn:
                self.logger.warning(
                    "RestAPI ListExecutions: stateMachineArn must be "
                    "specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_state_machine_arn(state_machine_arn):
                self.logger.warning(
                    "RestAPI ListExecutions: {} is an invalid State "
                    "Machine ARN".format(state_machine_arn))
                return aws_error("InvalidArn"), 400

            # Look up stateMachineArn
            state_machine = self.asl_store.get_cached_view(state_machine_arn)
            if not state_machine:
                self.logger.info(
                    "RestAPI ListExecutions: State Machine {} does not "
                    "exist".format(state_machine_arn))
                return aws_error("StateMachineDoesNotExist"), 400

            status_filter = params.get("statusFilter")
            if status_filter and status_filter not in {
                "RUNNING", "SUCCEEDED", "FAILED", "TIMED_OUT", "ABORTED",
            }:
                status_filter = None

            """
            Populate response using list and dict comprehensions
            https://www.pythonforbeginners.com/basics/list-comprehensions-in-python
            https://stackoverflow.com/questions/5352546/extract-subset-of-key-value-pairs-from-python-dictionary-object

            TODO handle nextToken stuff. Note that ListExecutions is
            potentially a very expensive call as there might well be a
            large number of executions for any given State Machine and
            moreover the execution details are stored as Redis hashes that
            are themselves keyed by the execution ARN. In other words it
            is not *natively* a list and under the covers listing the
            executions is implemented by a redis.scan. One option for
            improving things might be to use the next_token to wrap a scan
            cursor. That approach should work as the maxResults in the API
            call is only a hint and the actual number of results returned
            per call might be fewer than the specified maximum, so that
            fits somewhat to the constraints of Redis scan cursors.
            """
            next_token = ""
            executions = [
                {
                    k1: v[k1]
                    for k1 in ("executionArn", "name", "startDate",
                               "stateMachineArn", "status", "stopDate")
                }
                for k, v in self.executions.items()
                if v["stateMachineArn"] == state_machine_arn
                and (status_filter is None or v["status"] == status_filter)
            ]

            resp = {"executions": executions}
            if next_token:
                resp["nextToken"] = next_token
            return jsonify(resp), 200

        def aws_api_DescribeExecution():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_DescribeExecution.html
            """
            execution_arn = params.get("executionArn")
            if not execution_arn:
                self.logger.warning(
                    "RestAPI DescribeExecution: executionArn must be "
                    "specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_execution_arn(execution_arn):
                self.logger.warning(
                    "RestAPI DescribeExecution: {} is an invalid "
                    "Execution ARN".format(execution_arn))
                return aws_error("InvalidArn"), 400

            # Look up executionArn
            execution = self.executions.get(execution_arn)
            if not execution:
                self.logger.info(
                    "RestAPI DescribeExecution: Execution {} does not "
                    "exist".format(execution_arn))
                return aws_error("ExecutionDoesNotExist"), 400

            if not isinstance(execution, dict):  # May be (non JSON) RedisDict
                execution = dict(execution)
            return jsonify(execution), 200

        def aws_api_GetExecutionHistory():
            """
            https://docs.aws.amazon.com/step-functions/latest/apireference/API_GetExecutionHistory.html
            """
            # print(params)
            execution_arn = params.get("executionArn")
            if not execution_arn:
                self.logger.warning(
                    "RestAPI GetExecutionHistory: executionArn must be "
                    "specified")
                return aws_error("MissingRequiredParameter"), 400
            if not valid_execution_arn(execution_arn):
                self.logger.warning(
                    "RestAPI GetExecutionHistory: {} is an invalid "
                    "Execution ARN".format(execution_arn))
                return aws_error("InvalidArn"), 400

            reverse_order = params.get("reverseOrder", False)

            # Look up executionArn
            history = self.execution_history.get(execution_arn)
            if not history:
                self.logger.info(
                    "RestAPI GetExecutionHistory: Execution {} does not "
                    "exist".format(execution_arn))
                return aws_error("ExecutionDoesNotExist"), 400

            """
            Reverse via slicing: [start:stop:step] so step is -1
            https://stackoverflow.com/questions/3940128/how-can-i-reverse-a-list-in-python

            TODO handle nextToken stuff. Note that GetExecutionHistory is
            potentially an expensive call if the history is large. The
            store self.execution_history has list semantics, but is backed
            by an external (e.g. Redis) store. Under the covers it will do
            a redis.lrange, so the next_token behaviour when implemented
            should "slice" the appropriate range. Note that doing this for
            GetExecutionHistory should be easier than for ListExecutions -
            see comment in ListExecutions for why.
            """
            if reverse_order:
                history = history[::-1]
            else:
                history = history[:]

            next_token = ""
            resp = {"events": history}
            if next_token:
                resp["nextToken"] = next_token
            return jsonify(resp), 200

        def aws_api_InvalidAction():
            self.logger.error("RestAPI invalid action: {}".format(action))
            return "InvalidAction", 400

        # ------------------------------------------------------------------

        # Use the API action to dynamically invoke the appropriate handler.
        try:
            value, code = locals().get("aws_api_" + action,
                                       aws_api_InvalidAction)()
            return value, code
        except Exception as e:
            self.logger.error(
                "RestAPI action {} failed unexpectedly with exception: "
                "{}".format(action, e))
            return "InternalError", 500

    return app