def test_filter_records_filtered(self):
    """
    Exercise filter_records with explicit record types and filters.

    Three scenarios are checked in sequence on one aggregator:
    malformed argument combinations are expected to raise, a
    single-type query keeps only the records that pass its filter,
    and a two-type query yields the expected survivors per type.
    """

    def reject_all(_record):
        return False

    aggregator = RecordAggregator()

    # Filters given without record types, and record types that are
    # all None, are both expected to be rejected.
    with self.assertRaises(Exception):
        aggregator.filter_records(filters=[reject_all])
    with self.assertRaises(Exception):
        aggregator.filter_records(record_types=[None, None],
                                  filters=[reject_all])

    # Populate the aggregator with three throughput records.
    for perf_output in ("Throughput: 5 infer/sec\n\n\n\n",
                        "Throughput: 1 infer/sec\n\n\n\n",
                        "Throughput: 10 infer/sec\n\n\n\n"):
        aggregator.insert(PerfThroughput(perf_output))

    # A ">= 5" filter should leave exactly the 5 and 10 records.
    at_least_five = aggregator.filter_records(
        record_types=[PerfThroughput],
        filters=[lambda rec: rec.value() >= 5],
    ).get_records()
    self.assertEqual(len(at_least_five[PerfThroughput]), 2)
    throughput_values = []
    for rec in at_least_five[PerfThroughput]:
        throughput_values.append(rec.value())
    self.assertIn(5, throughput_values)
    self.assertIn(10, throughput_values)

    # Add two latency records alongside the throughputs.
    for perf_output in ("Avg latency: 3 ms\n\n\n\n",
                        "Avg latency: 6 ms\n\n\n\n"):
        aggregator.insert(PerfLatency(perf_output))

    # Query both record types at once, one filter per type.
    # NOTE(review): the expected counts below suggest filters are
    # paired positionally with record_types -- confirm against
    # RecordAggregator.filter_records.
    by_type = aggregator.filter_records(
        record_types=[PerfLatency, PerfThroughput],
        filters=[lambda rec: rec.value() == 3,
                 lambda rec: rec.value() < 5],
    ).get_records()

    values_by_type = {}
    for record_type in (PerfLatency, PerfThroughput):
        values_by_type[record_type] = [
            rec.value() for rec in by_type[record_type]
        ]

    self.assertEqual(len(by_type[PerfLatency]), 1)
    self.assertIn(3, values_by_type[PerfLatency])
    self.assertEqual(len(by_type[PerfThroughput]), 1)
    self.assertIn(1, values_by_type[PerfThroughput])
def test_filter_records_default(self):
    """
    Check filter_records called with no arguments.

    A single throughput record is inserted; the unfiltered result,
    indexed by that record's header, must expose a first entry whose
    header and value match the inserted record.
    """
    aggregator = RecordAggregator()

    # Seed the aggregator with one throughput record.
    inserted = PerfThroughput(5)
    aggregator.insert(inserted)

    # Fetch everything back and pick the first entry stored under
    # the inserted record's header.
    unfiltered = aggregator.filter_records()
    same_header_records = unfiltered[inserted.header()]
    first_hit = same_header_records[0]

    found_header = first_hit.header()
    wanted_header = inserted.header()
    self.assertEqual(found_header,
                     wanted_header,
                     msg="Headers do not match after filter_records")

    found_value = first_hit.value()
    wanted_value = inserted.value()
    self.assertEqual(found_value,
                     wanted_value,
                     msg="Values do not match after filter_records")
def test_filter_records_default(self):
    """
    Check filter_records called with no arguments.

    A single throughput record is inserted; the unfiltered result,
    indexed by the PerfThroughput type, must expose a first entry
    that is a PerfThroughput carrying the inserted record's value.
    """
    aggregator = RecordAggregator()

    # Seed the aggregator with one throughput record.
    inserted = PerfThroughput("Throughput: 5 infer/sec\n\n\n\n")
    aggregator.insert(inserted)

    # Fetch everything back and pick the first throughput entry.
    unfiltered = aggregator.filter_records()
    throughputs = unfiltered[PerfThroughput]
    first_hit = throughputs[0]

    self.assertIsInstance(
        first_hit,
        PerfThroughput,
        msg="Record types do not match after filter_records")

    found_value = first_hit.value()
    wanted_value = inserted.value()
    self.assertEqual(found_value,
                     wanted_value,
                     msg="Values do not match after filter_records")
def test_filter_records_filtered(self):
    """
    Exercise filter_records with explicit headers and filters.

    Three scenarios are checked in sequence on one aggregator:
    malformed argument combinations are expected to raise, a
    single-header query keeps only the values that pass its filter,
    and a two-header query yields one surviving record per header.
    """

    def reject_all(_value):
        return False

    aggregator = RecordAggregator()

    # Filters given without headers, and headers that name nothing
    # inserted so far, are both expected to be rejected.
    with self.assertRaises(Exception):
        aggregator.filter_records(filters=[reject_all])
    with self.assertRaises(Exception):
        aggregator.filter_records(headers=["header1", "header2"],
                                  filters=[reject_all])

    # Populate the aggregator with throughputs 5, 1 and 10; the first
    # record is kept around to look up the throughput header.
    throughput = PerfThroughput(5)
    aggregator.insert(throughput)
    for extra_value in (1, 10):
        aggregator.insert(PerfThroughput(extra_value))

    # A ">= 5" filter should leave exactly two throughput records.
    at_least_five = aggregator.filter_records(
        headers=[throughput.header()],
        filters=[lambda val: val >= 5],
    )
    self.assertEqual(len(at_least_five[throughput.header()]), 2)

    # Add latencies 3 and 6; the first record supplies the header.
    latency = PerfLatency(3)
    aggregator.insert(latency)
    aggregator.insert(PerfLatency(6))

    # Query both headers at once, one filter per header.
    # NOTE(review): the expected counts below suggest filters are
    # paired positionally with headers -- confirm against
    # RecordAggregator.filter_records.
    wanted_headers = [latency.header(), throughput.header()]
    per_header_filters = [lambda val: val == 3, lambda val: val < 5]
    by_header = aggregator.filter_records(headers=wanted_headers,
                                          filters=per_header_filters)

    self.assertEqual(len(by_header[throughput.header()]), 1)
    self.assertEqual(len(by_header[latency.header()]), 1)