class TestRequestTracker(MatrixTestCaseUsingMockAWS):
    """Unit tests for RequestTracker against mocked AWS resources
    (DynamoDB tables and the S3 results bucket)."""

    @mock.patch("matrix.common.date.get_datetime_now")
    def setUp(self, mock_get_datetime_now):
        super(TestRequestTracker, self).setUp()
        # Pin "now" so the CREATION_DATE written to Dynamo is deterministic.
        self.stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = self.stub_date
        self.request_id = str(uuid.uuid4())
        self.request_tracker = RequestTracker(self.request_id)
        self.dynamo_handler = DynamoHandler()

        # Create and seed the mocked Dynamo tables and S3 results bucket,
        # then register one request entry for self.request_id.
        self.create_test_data_version_table()
        self.create_test_deployment_table()
        self.create_test_request_table()
        self.create_s3_results_bucket()
        self.init_test_data_version_table()
        self.init_test_deployment_table()
        self.dynamo_handler.create_request_table_entry(
            self.request_id, "test_format",
            ["test_field_1", "test_field_2"], "test_feature")

    def test_is_initialized(self):
        # The request created in setUp has a request-table entry.
        self.assertTrue(self.request_tracker.is_initialized)

        # An id with no table entry is not initialized.
        new_request_tracker = RequestTracker("test_uuid")
        self.assertFalse(new_request_tracker.is_initialized)

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.generate_request_hash"
    )
    def test_request_hash(self, mock_generate_request_hash):
        with self.subTest("Test skip generation in API deployments:"):
            # Presence of MATRIX_VERSION marks an API deployment, where
            # hash generation is skipped (stays "N/A").
            os.environ['MATRIX_VERSION'] = "test_version"
            self.assertEqual(self.request_tracker.request_hash, "N/A")
            mock_generate_request_hash.assert_not_called()
            stored_request_hash = self.dynamo_handler.get_table_item(
                DynamoTable.REQUEST_TABLE,
                key=self.request_id)[RequestTableField.REQUEST_HASH.value]
            self.assertEqual(self.request_tracker._request_hash, "N/A")
            self.assertEqual(stored_request_hash, "N/A")
            del os.environ['MATRIX_VERSION']

        with self.subTest(
                "Test generation and storage in Dynamo on first access"):
            mock_generate_request_hash.return_value = "test_hash"
            self.assertEqual(self.request_tracker.request_hash, "test_hash")
            mock_generate_request_hash.assert_called_once()
            stored_request_hash = self.dynamo_handler.get_table_item(
                DynamoTable.REQUEST_TABLE,
                key=self.request_id)[RequestTableField.REQUEST_HASH.value]
            self.assertEqual(self.request_tracker._request_hash, "test_hash")
            self.assertEqual(stored_request_hash, "test_hash")

        with self.subTest("Test immediate retrieval on future accesses"):
            # Cached value is returned without regenerating the hash.
            self.assertEqual(self.request_tracker.request_hash, "test_hash")
            mock_generate_request_hash.assert_called_once()

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.request_hash",
        new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.data_version",
        new_callable=mock.PropertyMock)
    def test_s3_results_prefix(self, mock_data_version, mock_request_hash):
        mock_data_version.return_value = "test_data_version"
        mock_request_hash.return_value = "test_request_hash"
        self.assertEqual(self.request_tracker.s3_results_prefix,
                         "test_data_version/test_request_hash")

    @mock.patch("matrix.common.request.request_tracker.RequestTracker.format",
                new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.request_hash",
        new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.data_version",
        new_callable=mock.PropertyMock)
    def test_s3_results_key(self, mock_data_version, mock_request_hash,
                            mock_format):
        mock_data_version.return_value = "test_data_version"
        mock_request_hash.return_value = "test_request_hash"

        # loom results are not zipped; csv and mtx results are.
        mock_format.return_value = "loom"
        self.assertEqual(
            self.request_tracker.s3_results_key,
            f"test_data_version/test_request_hash/{self.request_id}.loom")

        mock_format.return_value = "csv"
        self.assertEqual(
            self.request_tracker.s3_results_key,
            f"test_data_version/test_request_hash/{self.request_id}.csv.zip")

        mock_format.return_value = "mtx"
        self.assertEqual(
            self.request_tracker.s3_results_key,
            f"test_data_version/test_request_hash/{self.request_id}.mtx.zip")

    @mock.patch("matrix.common.aws.dynamo_handler.DynamoHandler.get_table_item"
                )
    def test_data_version(self, mock_get_table_item):
        mock_get_table_item.return_value = {
            RequestTableField.DATA_VERSION.value: 0
        }
        with self.subTest("Test Dynamo read on first access"):
            self.assertEqual(self.request_tracker.data_version, 0)
            mock_get_table_item.assert_called_once()
        with self.subTest("Test cached access on successive reads"):
            # Still only one Dynamo read: the value is cached.
            self.assertEqual(self.request_tracker.data_version, 0)
            mock_get_table_item.assert_called_once()

    def test_format(self):
        self.assertEqual(self.request_tracker.format, "test_format")

    def test_metadata_fields(self):
        self.assertEqual(self.request_tracker.metadata_fields,
                         ["test_field_1", "test_field_2"])

    def test_feature(self):
        self.assertEqual(self.request_tracker.feature, "test_feature")

    def test_batch_job_id(self):
        # Unset ("N/A") batch job ids are surfaced as None.
        self.assertEqual(self.request_tracker.batch_job_id, None)

        field_enum = RequestTableField.BATCH_JOB_ID
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        self.assertEqual(self.request_tracker.batch_job_id, "123-123")

    @mock.patch(
        "matrix.common.aws.batch_handler.BatchHandler.get_batch_job_status")
    def test_batch_job_status(self, mock_get_job_status):
        mock_get_job_status.return_value = "FAILED"
        field_enum = RequestTableField.BATCH_JOB_ID
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        self.assertEqual(self.request_tracker.batch_job_status, "FAILED")

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.num_bundles",
        new_callable=mock.PropertyMock)
    def test_num_bundles_interval(self, mock_num_bundles):
        # Intervals are 500 wide: [0-499], [500-999], [1000-1499], ...
        mock_num_bundles.return_value = 0
        self.assertEqual(self.request_tracker.num_bundles_interval, "0-499")

        mock_num_bundles.return_value = 1
        self.assertEqual(self.request_tracker.num_bundles_interval, "0-499")

        mock_num_bundles.return_value = 500
        self.assertEqual(self.request_tracker.num_bundles_interval, "500-999")

        mock_num_bundles.return_value = 1234
        self.assertEqual(self.request_tracker.num_bundles_interval,
                         "1000-1499")

    def test_creation_date(self):
        self.assertEqual(self.request_tracker.creation_date, self.stub_date)

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    def test_error(self, mock_cw_put):
        # No error recorded yet.
        self.assertEqual(self.request_tracker.error, "")

        # Logging an error stores the message and emits a CloudWatch metric.
        self.request_tracker.log_error("test error")
        self.assertEqual(self.request_tracker.error, "test error")
        mock_cw_put.assert_called_once_with(
            metric_name=MetricName.REQUEST_ERROR, metric_value=1)

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.create_request_table_entry"
    )
    def test_initialize_request(self, mock_create_request_table_entry,
                                mock_create_cw_metric):
        self.request_tracker.initialize_request("test_format")
        # Defaults are applied when metadata fields / feature are omitted.
        mock_create_request_table_entry.assert_called_once_with(
            self.request_id, "test_format", DEFAULT_FIELDS, DEFAULT_FEATURE)
        mock_create_cw_metric.assert_called_once()

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.metadata_fields",
        new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.query.cell_query_results_reader.CellQueryResultsReader.load_results"
    )
    @mock.patch(
        "matrix.common.query.query_results_reader.QueryResultsReader._parse_manifest"
    )
    def test_generate_request_hash(self, mock_parse_manifest,
                                   mock_load_results, mock_metadata_fields):
        mock_load_results.return_value = pandas.DataFrame(
            index=["test_cell_key_1", "test_cell_key_2"])
        mock_metadata_fields.return_value = ["test_field_1", "test_field_2"]

        # Recompute the expected digest: feature, format, metadata fields,
        # then each cell key, folded into a single md5.
        h = hashlib.md5()
        h.update(self.request_tracker.feature.encode())
        h.update(self.request_tracker.format.encode())
        h.update("test_field_1".encode())
        h.update("test_field_2".encode())
        h.update("test_cell_key_1".encode())
        h.update("test_cell_key_2".encode())

        self.assertEqual(self.request_tracker.generate_request_hash(),
                         h.hexdigest())

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_expect_subtask_execution(self, mock_increment_table_field):
        self.request_tracker.expect_subtask_execution(Subtask.DRIVER)
        mock_increment_table_field.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.EXPECTED_DRIVER_EXECUTIONS, 1)

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_complete_subtask_execution(self, mock_increment_table_field):
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)
        mock_increment_table_field.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_DRIVER_EXECUTIONS, 1)

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.s3_results_prefix",
        new_callable=mock.PropertyMock)
    def test_lookup_cached_result(self, mock_s3_results_prefix):
        mock_s3_results_prefix.return_value = "test_prefix"
        s3_handler = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])

        with self.subTest("Do not match in S3 'directories'"):
            # An object at the bare prefix (no trailing slash) must not match.
            s3_handler.store_content_in_s3("test_prefix", "test_content")
            self.assertEqual(self.request_tracker.lookup_cached_result(), "")

        with self.subTest("Successfully retrieve a result key"):
            # The first key under the prefix is returned.
            s3_handler.store_content_in_s3("test_prefix/test_result_1",
                                           "test_content")
            s3_handler.store_content_in_s3("test_prefix/test_result_2",
                                           "test_content")
            self.assertEqual(self.request_tracker.lookup_cached_result(),
                             "test_prefix/test_result_1")

    def test_is_request_complete(self):
        self.assertFalse(self.request_tracker.is_request_complete())

        s3_handler = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])
        s3_handler.store_content_in_s3(
            f"{self.request_tracker.s3_results_key}/{self.request_id}.{self.request_tracker.format}",
            "")
        self.assertTrue(self.request_tracker.is_request_complete())

    def test_is_request_ready_for_conversion(self):
        # Expected vs. completed query executions must match before
        # conversion may begin.
        self.assertFalse(
            self.request_tracker.is_request_ready_for_conversion())

        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_QUERY_EXECUTIONS, 3)
        self.assertTrue(self.request_tracker.is_request_ready_for_conversion())

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    def test_complete_request(self, mock_cw_put):
        duration = 1
        self.request_tracker.complete_request(duration)
        # Completion emits four metrics: conversion/request completion and
        # two duration metrics with different dimension sets.
        expected_calls = [
            mock.call(metric_name=MetricName.CONVERSION_COMPLETION,
                      metric_value=1),
            mock.call(metric_name=MetricName.REQUEST_COMPLETION,
                      metric_value=1),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[
                          {
                              'Name': "Number of Bundles",
                              'Value': mock.ANY
                          },
                          {
                              'Name': "Output Format",
                              'Value': mock.ANY
                          },
                      ]),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[
                          {
                              'Name': "Number of Bundles",
                              'Value': mock.ANY
                          },
                      ])
        ]
        mock_cw_put.assert_has_calls(expected_calls)

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.creation_date",
        new_callable=mock.PropertyMock)
    @mock.patch("matrix.common.aws.s3_handler.S3Handler.exists")
    def test_is_expired(self, mock_exists, mock_creation_date,
                        mock_log_error):
        with self.subTest("Expired"):
            # Older than 30 days AND the matrix no longer exists in S3.
            mock_exists.return_value = False
            mock_creation_date.return_value = date.to_string(
                date.get_datetime_now() - timedelta(days=30, minutes=1))
            self.assertTrue(self.request_tracker.is_expired)
            mock_log_error.assert_called_once()

        mock_log_error.reset_mock()
        with self.subTest(
                "Not expired. Matrix DNE but not past expiration date"):
            mock_exists.return_value = False
            mock_creation_date.return_value = date.to_string(
                date.get_datetime_now() - timedelta(days=29))
            self.assertFalse(self.request_tracker.is_expired)
            mock_log_error.assert_not_called()

        with self.subTest("Not expired. Matrix exists"):
            mock_exists.return_value = True
            self.assertFalse(self.request_tracker.is_expired)
            mock_log_error.assert_not_called()

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.creation_date",
        new_callable=mock.PropertyMock)
    def test_timeout(self, mock_creation_date, mock_log_error):
        # no timeout: just under the 36 hour limit
        mock_creation_date.return_value = date.to_string(
            date.get_datetime_now() - timedelta(hours=35, minutes=59))
        self.assertFalse(self.request_tracker.timeout)

        # timeout: just over the 36 hour limit; logs an error
        mock_creation_date.return_value = date.to_string(
            date.get_datetime_now() - timedelta(hours=36, minutes=1))
        self.assertTrue(self.request_tracker.timeout)
        mock_log_error.assert_called_once()

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.set_table_field_with_value"
    )
    def test_write_batch_job_id_to_db(self, mock_set_table_field_with_value):
        self.request_tracker.write_batch_job_id_to_db("123-123")
        mock_set_table_field_with_value.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.BATCH_JOB_ID, "123-123")
class RequestTracker:
    """
    Provides an interface for tracking a request's parameters and state.
    """

    def __init__(self, request_id: str):
        Logging.set_correlation_id(logger, request_id)

        self.request_id = request_id
        # Lazily-populated caches of request table fields.
        self._request_hash = "N/A"
        self._data_version = None
        self._num_bundles = None
        self._format = None
        self._metadata_fields = None
        self._feature = None

        self.dynamo_handler = DynamoHandler()
        self.cloudwatch_handler = CloudwatchHandler()
        self.batch_handler = BatchHandler()

    @property
    def is_initialized(self) -> bool:
        """
        Whether a request table entry exists for this request id.
        :return: bool True if the entry exists
        """
        try:
            self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                               key=self.request_id)
        except MatrixException:
            return False
        return True

    @property
    def request_hash(self) -> str:
        """
        Unique hash generated using request parameters.
        If a request hash does not exist, one will be attempted to be generated.
        :return: str Request hash
        """
        if self._request_hash == "N/A":
            self._request_hash = self.dynamo_handler.get_table_item(
                DynamoTable.REQUEST_TABLE,
                key=self.request_id)[RequestTableField.REQUEST_HASH.value]

        # Do not generate request hash in API requests to avoid timeouts.
        # Presence of MATRIX_VERSION indicates API deployment.
        if self._request_hash == "N/A" and not os.getenv('MATRIX_VERSION'):
            try:
                self._request_hash = self.generate_request_hash()
                self.dynamo_handler.set_table_field_with_value(
                    DynamoTable.REQUEST_TABLE, self.request_id,
                    RequestTableField.REQUEST_HASH, self._request_hash)
            except MatrixQueryResultsNotFound as e:
                # Best-effort: leave the hash as "N/A" until results exist.
                logger.warning(f"Failed to generate a request hash. {e}")

        return self._request_hash

    @property
    def s3_results_prefix(self) -> str:
        """
        The S3 prefix where results for this request hash are stored in the results bucket.
        :return: str S3 prefix
        """
        return f"{self.data_version}/{self.request_hash}"

    @property
    def s3_results_key(self) -> str:
        """
        The S3 key where matrix results for this request are stored in the results bucket.
        :return: str S3 key
        """
        # csv and mtx outputs are produced as zip archives; loom is not.
        is_compressed = self.format == MatrixFormat.CSV.value or self.format == MatrixFormat.MTX.value
        return f"{self.data_version}/{self.request_hash}/{self.request_id}.{self.format}" + \
            (".zip" if is_compressed else "")

    @property
    def data_version(self) -> int:
        """
        The Redshift data version this request is generated on.
        :return: int Data version
        """
        if self._data_version is None:
            self._data_version = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.DATA_VERSION.value]
        return self._data_version

    @property
    def num_bundles(self) -> int:
        """
        The number of bundles in the request.
        :return: int Number of bundles
        """
        # Compare against None (not truthiness) so a legitimate count of 0
        # is cached rather than re-read from Dynamo on every access.
        if self._num_bundles is None:
            self._num_bundles = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.NUM_BUNDLES.value]
        return self._num_bundles

    @property
    def num_bundles_interval(self) -> str:
        """
        Returns the interval string that num_bundles corresponds to.
        :return: the interval string e.g. "0-499"
        """
        interval_size = 500
        index = int(self.num_bundles / interval_size)
        return f"{index * interval_size}-{(index * interval_size) + interval_size - 1}"

    @property
    def format(self) -> str:
        """
        The request's user specified output file format of the resultant expression matrix.
        :return: str The file format (one of MatrixFormat)
        """
        if not self._format:
            self._format = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.FORMAT.value]
        return self._format

    @property
    def metadata_fields(self) -> list:
        """
        The request's user-specified list of metadata fields to include in the resultant expression matrix.
        :return: list List of metadata fields
        """
        if not self._metadata_fields:
            self._metadata_fields = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.METADATA_FIELDS.value]
        return self._metadata_fields

    @property
    def feature(self) -> str:
        """
        The request's user-specified feature type (gene|transcript) of the resultant expression matrix.
        :return: str Feature (gene|transcript)
        """
        if not self._feature:
            self._feature = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.FEATURE.value]
        return self._feature

    @property
    def batch_job_id(self) -> str:
        """
        The batch job id for matrix conversion corresponding with a request.
        :return: str The batch job id, or None when unset or "N/A"
        """
        table_item = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE, key=self.request_id)
        batch_job_id = table_item.get(RequestTableField.BATCH_JOB_ID.value)
        if not batch_job_id or batch_job_id == "N/A":
            return None
        else:
            return batch_job_id

    @property
    def batch_job_status(self) -> str:
        """
        The batch job status for matrix conversion corresponding with a request.
        :return: str The batch job status, or None when no batch job exists
        """
        status = None
        if self.batch_job_id:
            status = self.batch_handler.get_batch_job_status(self.batch_job_id)
        return status

    @property
    def creation_date(self) -> str:
        """
        The creation date of matrix service request.
        :return: str creation date
        """
        return self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE,
            key=self.request_id)[RequestTableField.CREATION_DATE.value]

    @property
    def is_expired(self):
        """
        Whether or not the request has expired and the matrix in S3 has been deleted.
        A request is expired when it is past the 30 day retention window AND
        its result object no longer exists in the results bucket.
        :return: bool
        """
        s3_results_bucket_handler = S3Handler(
            os.environ['MATRIX_RESULTS_BUCKET'])
        is_past_expiration = date.to_datetime(
            self.creation_date) < date.get_datetime_now() - timedelta(days=30)
        is_expired = not s3_results_bucket_handler.exists(
            self.s3_results_key) and is_past_expiration

        if is_expired:
            self.log_error(
                "This request has expired after 30 days and is no longer available for download. "
                "A new matrix can be generated by resubmitting the POST request to /v1/matrix."
            )
        return is_expired

    @property
    def timeout(self) -> bool:
        """
        Whether the request has exceeded the 36 hour processing limit.
        Logs a user-facing error when the request has timed out.
        :return: bool True if timed out
        """
        timeout = date.to_datetime(
            self.creation_date) < date.get_datetime_now() - timedelta(hours=36)
        if timeout:
            # Fixed: the message previously said "12 hours" (contradicting the
            # 36 hour cutoff above) and was missing the space between sentences.
            self.log_error(
                "This request has timed out after 36 hours. "
                "Please try again by resubmitting the POST request.")
        return timeout

    @property
    def error(self) -> str:
        """
        The user-friendly message describing the latest error the request raised.
        :return: str The error message if one exists, else empty string
        """
        error = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE,
            key=self.request_id)[RequestTableField.ERROR_MESSAGE.value]
        return error if error else ""

    def initialize_request(self,
                           fmt: str,
                           metadata_fields: list = DEFAULT_FIELDS,
                           feature: str = DEFAULT_FEATURE) -> None:
        """Initialize the request id in the request state table. Put request metric to cloudwatch.
        :param fmt: Request output format for matrix conversion
        :param metadata_fields: Metadata fields to include in expression matrix
        :param feature: Feature type to generate expression counts for (one of MatrixFeature)
        """
        self.dynamo_handler.create_request_table_entry(self.request_id, fmt,
                                                       metadata_fields,
                                                       feature)
        self.cloudwatch_handler.put_metric_data(metric_name=MetricName.REQUEST,
                                                metric_value=1)

    def generate_request_hash(self) -> str:
        """
        Generates a request hash uniquely identifying a request by its input parameters.
        Requires cell query results to exist, else raises MatrixQueryResultsNotFound.
        :return: str Request hash
        """
        cell_manifest_key = f"s3://{os.environ['MATRIX_QUERY_RESULTS_BUCKET']}/{self.request_id}/cell_metadata_manifest"
        reader = CellQueryResultsReader(cell_manifest_key)
        cell_df = reader.load_results()
        cellkeys = cell_df.index

        # Digest order matters for reproducibility: feature, format,
        # metadata fields, then each cell key.
        h = hashlib.md5()
        h.update(self.feature.encode())
        h.update(self.format.encode())
        for field in self.metadata_fields:
            h.update(field.encode())
        for key in cellkeys:
            h.update(key.encode())

        request_hash = h.hexdigest()
        return request_hash

    def expect_subtask_execution(self, subtask: Subtask):
        """
        Expect the execution of 1 Subtask by tracking it in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.
        :param subtask: The expected Subtask to be executed.
        """
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.EXPECTED_DRIVER_EXECUTIONS,
            Subtask.CONVERTER: RequestTableField.EXPECTED_CONVERTER_EXECUTIONS,
        }
        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def complete_subtask_execution(self, subtask: Subtask):
        """
        Counts the completed execution of 1 Subtask in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.
        :param subtask: The executed Subtask.
        """
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
            Subtask.QUERY: RequestTableField.COMPLETED_QUERY_EXECUTIONS,
            Subtask.CONVERTER: RequestTableField.COMPLETED_CONVERTER_EXECUTIONS,
        }
        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def lookup_cached_result(self) -> str:
        """
        Retrieves the S3 key of an existing matrix result that corresponds to
        this request's request hash.
        Returns "" if no such result exists
        :return: S3 key of cached result
        """
        results_bucket = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])
        # The trailing slash prevents matching a bare object at the prefix.
        objects = results_bucket.ls(f"{self.s3_results_prefix}/")
        if len(objects) > 0:
            return objects[0]['Key']
        return ""

    def is_request_ready_for_conversion(self) -> bool:
        """
        Checks whether the request has completed all queries and is ready for conversion
        :return: bool True if complete, else False
        """
        request_state = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE, key=self.request_id)
        queries_complete = (
            request_state[RequestTableField.EXPECTED_QUERY_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_QUERY_EXECUTIONS.value])
        return queries_complete

    def is_request_complete(self) -> bool:
        """
        Checks whether the request has completed.
        :return: bool True if complete, else False
        """
        results_bucket = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])
        return results_bucket.exists(self.s3_results_key)

    def complete_request(self, duration: float):
        """
        Log the completion of a matrix request in CloudWatch Metrics
        :param duration: The time in seconds the request took to complete
        """
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.CONVERSION_COMPLETION, metric_value=1)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_COMPLETION, metric_value=1)
        # Duration is reported twice: once sliced by bundle-count and output
        # format, once by bundle-count alone.
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[
                {
                    'Name': "Number of Bundles",
                    'Value': self.num_bundles_interval
                },
                {
                    'Name': "Output Format",
                    'Value': self.format
                },
            ])
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[
                {
                    'Name': "Number of Bundles",
                    'Value': self.num_bundles_interval
                },
            ])

    def log_error(self, message: str):
        """
        Logs the latest error this request reported overwriting the previously logged error.
        :param message: str The error message to log
        """
        logger.debug(message)
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.ERROR_MESSAGE, message)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_ERROR, metric_value=1)

    def write_batch_job_id_to_db(self, batch_job_id: str):
        """
        Logs the batch job id for matrix conversion to state table
        :param batch_job_id: The batch job id to persist
        """
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.BATCH_JOB_ID, batch_job_id)
def test_invalidate_cache_entries(self, mock_put_metric_data):
    """
    Setup:
    - Create four request ids mapping to two request hashes
    - Invalidate hash 1 (ids 1, 2) and id 3
    - Verify ids 1, 2 and 3 have been invalidated
    - Verify id 4 has not been invalidated
    """
    request_hash_1 = "test_hash_1"
    request_hash_2 = "test_hash_2"
    request_id_1 = "test_id_1"
    request_id_2 = "test_id_2"
    request_id_3 = "test_id_3"
    request_id_4 = "test_id_4"
    test_format = "test_format"
    test_content = "test_content"
    # Result objects live under <data_version>/<request_hash>/<request_id>.<format>.
    s3_key_1 = f"0/{request_hash_1}/{request_id_1}.{test_format}"
    s3_key_2 = f"0/{request_hash_1}/{request_id_2}.{test_format}"
    s3_key_3 = f"0/{request_hash_2}/{request_id_3}.{test_format}"
    s3_key_4 = f"0/{request_hash_2}/{request_id_4}.{test_format}"

    # Register all four requests and assign their request hashes.
    dynamo_handler = DynamoHandler()
    dynamo_handler.create_request_table_entry(request_id_1, test_format)
    dynamo_handler.create_request_table_entry(request_id_2, test_format)
    dynamo_handler.create_request_table_entry(request_id_3, test_format)
    dynamo_handler.create_request_table_entry(request_id_4, test_format)
    dynamo_handler.set_table_field_with_value(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_1,
        field_enum=RequestTableField.REQUEST_HASH,
        field_value=request_hash_1)
    dynamo_handler.set_table_field_with_value(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_2,
        field_enum=RequestTableField.REQUEST_HASH,
        field_value=request_hash_1)
    dynamo_handler.set_table_field_with_value(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_3,
        field_enum=RequestTableField.REQUEST_HASH,
        field_value=request_hash_2)
    dynamo_handler.set_table_field_with_value(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_4,
        field_enum=RequestTableField.REQUEST_HASH,
        field_value=request_hash_2)

    # Store a result object for each request and confirm all four exist.
    s3_results_bucket_handler = S3Handler(
        os.environ['MATRIX_RESULTS_BUCKET'])
    s3_results_bucket_handler.store_content_in_s3(s3_key_1, test_content)
    s3_results_bucket_handler.store_content_in_s3(s3_key_2, test_content)
    s3_results_bucket_handler.store_content_in_s3(s3_key_3, test_content)
    s3_results_bucket_handler.store_content_in_s3(s3_key_4, test_content)
    self.assertTrue(s3_results_bucket_handler.exists(s3_key_1))
    self.assertTrue(s3_results_bucket_handler.exists(s3_key_2))
    self.assertTrue(s3_results_bucket_handler.exists(s3_key_3))
    self.assertTrue(s3_results_bucket_handler.exists(s3_key_4))

    # Invalidate by hash (covers ids 1 and 2) and by id (covers id 3).
    invalidate_cache_entries(request_ids=[request_id_3],
                             request_hashes=[request_hash_1])

    error_1 = dynamo_handler.get_table_item(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_1)[RequestTableField.ERROR_MESSAGE.value]
    error_2 = dynamo_handler.get_table_item(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_2)[RequestTableField.ERROR_MESSAGE.value]
    error_3 = dynamo_handler.get_table_item(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_3)[RequestTableField.ERROR_MESSAGE.value]
    error_4 = dynamo_handler.get_table_item(
        table=DynamoTable.REQUEST_TABLE,
        key=request_id_4)[RequestTableField.ERROR_MESSAGE.value]

    # Invalidated entries lose their S3 result object...
    self.assertFalse(s3_results_bucket_handler.exists(s3_key_1))
    self.assertFalse(s3_results_bucket_handler.exists(s3_key_2))
    self.assertFalse(s3_results_bucket_handler.exists(s3_key_3))
    self.assertTrue(s3_results_bucket_handler.exists(s3_key_4))
    # ...and get an error message written. NOTE(review): ERROR_MESSAGE
    # appears to default to 0 for untouched entries — confirm against
    # DynamoHandler.create_request_table_entry.
    self.assertNotEqual(error_1, 0)
    self.assertNotEqual(error_2, 0)
    self.assertNotEqual(error_3, 0)
    self.assertEqual(error_4, 0)
class TestDynamoHandler(MatrixTestCaseUsingMockAWS):
    """
    Environment variables are set in tests/unit/__init__.py
    """

    def setUp(self):
        super(TestDynamoHandler, self).setUp()
        self.dynamo = boto3.resource("dynamodb",
                                     region_name=os.environ['AWS_DEFAULT_REGION'])
        self.data_version_table_name = os.environ['DYNAMO_DATA_VERSION_TABLE_NAME']
        self.request_table_name = os.environ['DYNAMO_REQUEST_TABLE_NAME']
        self.request_id = str(uuid.uuid4())
        self.data_version = 1
        self.format = "zarr"

        # Create and seed the mocked Dynamo tables used by the handler.
        self.create_test_data_version_table()
        self.create_test_deployment_table()
        self.create_test_request_table()
        self.init_test_data_version_table()
        self.init_test_deployment_table()

        self.handler = DynamoHandler()

    def _get_data_version_table_response_and_entry(self):
        # Fetch the raw batch_get_item response plus the single matching
        # entry for self.data_version.
        data_version_primary_key = self.handler._get_dynamo_table_primary_key_from_enum(DynamoTable.DATA_VERSION_TABLE)
        response = self.dynamo.batch_get_item(
            RequestItems={
                self.data_version_table_name: {
                    'Keys': [{data_version_primary_key: self.data_version}]
                }
            }
        )
        entry = response['Responses'][self.data_version_table_name][0]
        return response, entry

    def _get_request_table_response_and_entry(self):
        # Fetch the raw batch_get_item response plus the single matching
        # entry for self.request_id.
        response = self.dynamo.batch_get_item(
            RequestItems={
                self.request_table_name: {
                    'Keys': [{'RequestId': self.request_id}]
                }
            }
        )
        entry = response['Responses'][self.request_table_name][0]
        return response, entry

    @mock.patch("matrix.common.v1_api_handler.V1ApiHandler.describe_filter")
    @mock.patch("matrix.common.date.get_datetime_now")
    def test_create_data_version_table_entry(self, mock_get_datetime_now, mock_describe_filter):
        # Pin "now" and the project cell counts reported by the V1 API.
        stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = stub_date
        stub_cell_counts = {
            'test_project_uuid_1': 10,
            'test_project_uuid_2': 100
        }
        mock_describe_filter.return_value = {
            'cell_counts': stub_cell_counts
        }

        self.handler.create_data_version_table_entry(self.data_version)
        response, entry = self._get_data_version_table_response_and_entry()

        # Expected schema-version mapping: enum member value -> version.
        metadata_schema_versions = {}
        for schema_name in SUPPORTED_METADATA_SCHEMA_VERSIONS:
            metadata_schema_versions[schema_name.value] = SUPPORTED_METADATA_SCHEMA_VERSIONS[schema_name]

        self.assertEqual(len(response['Responses'][self.data_version_table_name]), 1)
        self.assertTrue(all(field.value in entry for field in DataVersionTableField))
        self.assertEqual(entry[DataVersionTableField.DATA_VERSION.value], self.data_version)
        self.assertEqual(entry[DataVersionTableField.CREATION_DATE.value], stub_date)
        self.assertEqual(entry[DataVersionTableField.PROJECT_CELL_COUNTS.value], stub_cell_counts)
        self.assertEqual(entry[DataVersionTableField.METADATA_SCHEMA_VERSIONS.value], metadata_schema_versions)

    @mock.patch("matrix.common.date.get_datetime_now")
    def test_create_request_table_entry(self, mock_get_datetime_now):
        stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = stub_date
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()

        # A new entry contains every RequestTableField with its defaults.
        self.assertEqual(len(response['Responses'][self.request_table_name]), 1)
        self.assertTrue(all(field.value in entry for field in RequestTableField))
        self.assertEqual(entry[RequestTableField.FORMAT.value], self.format)
        self.assertEqual(entry[RequestTableField.METADATA_FIELDS.value], DEFAULT_FIELDS)
        self.assertEqual(entry[RequestTableField.FEATURE.value], "gene")
        self.assertEqual(entry[RequestTableField.DATA_VERSION.value], 0)
        self.assertEqual(entry[RequestTableField.REQUEST_HASH.value], "N/A")
        self.assertEqual(entry[RequestTableField.EXPECTED_DRIVER_EXECUTIONS.value], 1)
        self.assertEqual(entry[RequestTableField.EXPECTED_CONVERTER_EXECUTIONS.value], 1)
        self.assertEqual(entry[RequestTableField.CREATION_DATE.value], stub_date)

    def test_increment_table_field_request_table_path(self):
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 0)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)

        # Incrementing one field must not affect the others.
        field_enum = RequestTableField.COMPLETED_DRIVER_EXECUTIONS
        self.handler.increment_table_field(DynamoTable.REQUEST_TABLE, self.request_id, field_enum, 5)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 5)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)

    def test_set_table_field_with_value(self):
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "N/A")

        field_enum = RequestTableField.BATCH_JOB_ID
        self.handler.set_table_field_with_value(DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "123-123")

    def test_increment_field(self):
        # Exercises the private _increment_field path directly.
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 0)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)

        key_dict = {"RequestId": self.request_id}
        field_enum = RequestTableField.COMPLETED_DRIVER_EXECUTIONS
        self.handler._increment_field(self.handler._get_dynamo_table_resource_from_enum(DynamoTable.REQUEST_TABLE),
                                      key_dict, field_enum, 15)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 15)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)

    def test_set_field(self):
        # Exercises the private _set_field path directly.
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "N/A")

        key_dict = {"RequestId": self.request_id}
        field_enum = RequestTableField.BATCH_JOB_ID
        self.handler._set_field(self.handler._get_dynamo_table_resource_from_enum(DynamoTable.REQUEST_TABLE),
                                key_dict, field_enum, "123-123")
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "123-123")

    def test_get_request_table_entry(self):
        self.handler.create_request_table_entry(self.request_id, self.format)
        entry = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, key=self.request_id)
        self.assertEqual(entry[RequestTableField.EXPECTED_DRIVER_EXECUTIONS.value], 1)

        # get_table_item reflects subsequent field updates.
        key_dict = {"RequestId": self.request_id}
        field_enum = RequestTableField.COMPLETED_DRIVER_EXECUTIONS
        self.handler._increment_field(self.handler._get_dynamo_table_resource_from_enum(DynamoTable.REQUEST_TABLE),
                                      key_dict, field_enum, 15)
        entry = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, key=self.request_id)
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 15)

    def test_get_table_item(self):
        # Missing key raises; existing key returns the stored entry.
        self.assertRaises(MatrixException, self.handler.get_table_item,
                          DynamoTable.REQUEST_TABLE, key=self.request_id)

        self.handler.create_request_table_entry(self.request_id, self.format)
        entry = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, key=self.request_id)
        self.assertEqual(entry[RequestTableField.ROW_COUNT.value], 0)

    def test_filter_table_items(self):
        # No entries yet: nothing matches.
        items = self.handler.filter_table_items(
            table=DynamoTable.REQUEST_TABLE,
            attrs={RequestTableField.REQUEST_HASH.value: "N/A"}
        )
        self.assertEqual(len(items), 0)

        # Two entries share REQUEST_HASH "N/A" but differ in FORMAT.
        self.handler.create_request_table_entry(self.request_id, self.format)
        self.handler.create_request_table_entry(str(uuid.uuid4()), "test_format")

        items = self.handler.filter_table_items(
            table=DynamoTable.REQUEST_TABLE,
            attrs={RequestTableField.REQUEST_HASH.value: "N/A"}
        )
        self.assertEqual(len(items), 2)

        # Multiple attrs are ANDed together.
        items = self.handler.filter_table_items(
            table=DynamoTable.REQUEST_TABLE,
            attrs={RequestTableField.REQUEST_HASH.value: "N/A",
                   RequestTableField.FORMAT.value: self.format}
        )
        self.assertEqual(len(items), 1)
        self.assertEqual(items[0][RequestTableField.REQUEST_ID.value], self.request_id)
class TestDynamoHandler(MatrixTestCaseUsingMockAWS):
    """
    Environment variables are set in tests/unit/__init__.py
    """
    def setUp(self):
        """Create a mock request table and a fresh handler plus per-test identifiers."""
        super(TestDynamoHandler, self).setUp()
        # Raw boto3 resource used to verify table contents independently of the handler.
        self.dynamo = boto3.resource("dynamodb", region_name=os.environ['AWS_DEFAULT_REGION'])
        self.request_table_name = os.environ['DYNAMO_REQUEST_TABLE_NAME']
        self.request_id = str(uuid.uuid4())
        self.format = "zarr"
        self.create_test_request_table()
        self.handler = DynamoHandler()

    def _get_request_table_response_and_entry(self):
        """Fetch the raw batch_get_item response and the single entry for self.request_id."""
        response = self.dynamo.batch_get_item(
            RequestItems={
                self.request_table_name: {
                    'Keys': [{'RequestId': self.request_id}]
                }
            })
        entry = response['Responses'][self.request_table_name][0]
        return response, entry

    @mock.patch("matrix.common.date.get_datetime_now")
    def test_create_request_table_entry(self, mock_get_datetime_now):
        """A new entry has all fields, default execution counts and the stubbed creation date."""
        stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = stub_date
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(len(response['Responses'][self.request_table_name]), 1)
        # Every enum member must be materialized as a field on the new entry.
        self.assertTrue(all(field.value in entry for field in RequestTableField))
        self.assertEqual(entry[RequestTableField.FORMAT.value], self.format)
        self.assertEqual(entry[RequestTableField.EXPECTED_DRIVER_EXECUTIONS.value], 1)
        self.assertEqual(entry[RequestTableField.EXPECTED_CONVERTER_EXECUTIONS.value], 1)
        self.assertEqual(entry[RequestTableField.CREATION_DATE.value], stub_date)

    def test_increment_table_field_request_table_path(self):
        """increment_table_field bumps only the targeted counter on the request table."""
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 0)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)
        field_enum = RequestTableField.COMPLETED_DRIVER_EXECUTIONS
        self.handler.increment_table_field(DynamoTable.REQUEST_TABLE, self.request_id, field_enum, 5)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 5)
        # The sibling counter must be untouched.
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)

    def test_set_table_field_with_value(self):
        """set_table_field_with_value overwrites a field via the public API."""
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "N/A")
        field_enum = RequestTableField.BATCH_JOB_ID
        self.handler.set_table_field_with_value(DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "123-123")

    def test_increment_field(self):
        """The private _increment_field applies an arbitrary delta to one counter."""
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 0)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)
        key_dict = {"RequestId": self.request_id}
        field_enum = RequestTableField.COMPLETED_DRIVER_EXECUTIONS
        self.handler._increment_field(self.handler._request_table, key_dict, field_enum, 15)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 15)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], 0)

    def test_set_field(self):
        """The private _set_field overwrites a field given an explicit key dict."""
        self.handler.create_request_table_entry(self.request_id, self.format)
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "N/A")
        key_dict = {"RequestId": self.request_id}
        field_enum = RequestTableField.BATCH_JOB_ID
        self.handler._set_field(self.handler._request_table, key_dict, field_enum, "123-123")
        response, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.BATCH_JOB_ID.value], "123-123")

    def test_get_request_table_entry(self):
        """get_table_item reflects both initial defaults and later increments."""
        self.handler.create_request_table_entry(self.request_id, self.format)
        entry = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, request_id=self.request_id)
        self.assertEqual(entry[RequestTableField.EXPECTED_DRIVER_EXECUTIONS.value], 1)
        key_dict = {"RequestId": self.request_id}
        field_enum = RequestTableField.COMPLETED_DRIVER_EXECUTIONS
        self.handler._increment_field(self.handler._request_table, key_dict, field_enum, 15)
        entry = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, request_id=self.request_id)
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 15)

    def test_get_table_item(self):
        """get_table_item raises MatrixException for a missing id, then succeeds after creation."""
        self.assertRaises(MatrixException, self.handler.get_table_item, DynamoTable.REQUEST_TABLE, request_id=self.request_id)
        self.handler.create_request_table_entry(self.request_id, self.format)
        entry = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, request_id=self.request_id)
        self.assertEqual(entry[RequestTableField.ROW_COUNT.value], 0)
class RequestTracker:
    """
    Provides an interface for tracking a request's parameters and state.
    """
    def __init__(self, request_id: str):
        """
        :param request_id: UUID identifying the matrix service request to track.
        """
        Logging.set_correlation_id(logger, request_id)

        self.request_id = request_id
        # Lazily-loaded caches for values read from DynamoDB (None == not yet fetched).
        self._num_bundles = None
        self._format = None

        self.dynamo_handler = DynamoHandler()
        self.cloudwatch_handler = CloudwatchHandler()
        self.batch_handler = BatchHandler()

    @property
    def is_initialized(self) -> bool:
        """
        Whether a request table entry exists for this request id.

        :return: bool True if the entry exists, else False
        """
        try:
            self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                               request_id=self.request_id)
        except MatrixException:
            return False
        return True

    @property
    def num_bundles(self) -> int:
        """
        The number of bundles in the request.

        :return: int Number of bundles
        """
        # `is None` (not truthiness): a legitimate value of 0 must also be cached,
        # otherwise every access would re-query DynamoDB.
        if self._num_bundles is None:
            self._num_bundles = \
                self.dynamo_handler.get_table_item(
                    DynamoTable.REQUEST_TABLE,
                    request_id=self.request_id)[RequestTableField.NUM_BUNDLES.value]
        return self._num_bundles

    @property
    def num_bundles_interval(self) -> str:
        """
        Returns the interval string that num_bundles corresponds to.

        :return: the interval string e.g. "0-499"
        """
        interval_size = 500
        index = int(self.num_bundles / interval_size)
        return f"{index * interval_size}-{(index * interval_size) + interval_size - 1}"

    @property
    def format(self) -> str:
        """
        The request's user specified output file format of the resultant expression matrix.

        :return: str The file format (one of MatrixFormat)
        """
        # `is None` for consistency with num_bundles caching.
        if self._format is None:
            self._format = \
                self.dynamo_handler.get_table_item(
                    DynamoTable.REQUEST_TABLE,
                    request_id=self.request_id)[RequestTableField.FORMAT.value]
        return self._format

    @property
    def batch_job_id(self) -> str:
        """
        The batch job id for matrix conversion corresponding with a request.

        :return: str The batch job id, or None if not yet set (stored as "N/A")
        """
        table_item = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE, request_id=self.request_id)
        batch_job_id = table_item.get(RequestTableField.BATCH_JOB_ID.value)
        # "N/A" is the table's placeholder for "no job yet"; surface it as None.
        if not batch_job_id or batch_job_id == "N/A":
            return None
        else:
            return batch_job_id

    @property
    def batch_job_status(self) -> str:
        """
        The batch job status for matrix conversion corresponding with a request.

        :return: str The batch job status, or None if no batch job id exists
        """
        status = None
        if self.batch_job_id:
            status = self.batch_handler.get_batch_job_status(self.batch_job_id)
        return status

    @property
    def creation_date(self) -> str:
        """
        The creation date of matrix service request.

        :return: str creation date
        """
        return self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE,
            request_id=self.request_id)[RequestTableField.CREATION_DATE.value]

    @property
    def timeout(self) -> bool:
        """
        Whether the request is older than the 12 hour limit.
        Logs an error (side effect) when the request has timed out.

        :return: bool True if timed out, else False
        """
        timeout = date.to_datetime(
            self.creation_date) < date.get_datetime_now() - timedelta(hours=12)
        if timeout:
            # Trailing space added so the two literals join into one readable
            # sentence break (was "…hours.Please try again…").
            self.log_error(
                "This request has timed out after 12 hours. "
                "Please try again by resubmitting the POST request.")
        return timeout

    @property
    def error(self) -> str:
        """
        The user-friendly message describing the latest error the request raised.

        :return: str The error message if one exists, else empty string
        """
        error = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE,
            request_id=self.request_id)[RequestTableField.ERROR_MESSAGE.value]
        return error if error else ""

    def initialize_request(self, fmt: str) -> None:
        """Initialize the request id in the request state table. Put request metric to cloudwatch.

        :param fmt: Request output format for matrix conversion
        """
        self.dynamo_handler.create_request_table_entry(self.request_id, fmt)
        self.cloudwatch_handler.put_metric_data(metric_name=MetricName.REQUEST,
                                                metric_value=1)

    def expect_subtask_execution(self, subtask: Subtask):
        """
        Expect the execution of 1 Subtask by tracking it in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.

        :param subtask: The expected Subtask to be executed.
        """
        # NOTE(review): no QUERY entry here (unlike complete_subtask_execution) —
        # presumably expected query counts are written elsewhere; confirm before adding.
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.EXPECTED_DRIVER_EXECUTIONS,
            Subtask.CONVERTER: RequestTableField.EXPECTED_CONVERTER_EXECUTIONS,
        }
        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def complete_subtask_execution(self, subtask: Subtask):
        """
        Counts the completed execution of 1 Subtask in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.

        :param subtask: The executed Subtask.
        """
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
            Subtask.QUERY: RequestTableField.COMPLETED_QUERY_EXECUTIONS,
            Subtask.CONVERTER: RequestTableField.COMPLETED_CONVERTER_EXECUTIONS,
        }
        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def is_request_complete(self) -> bool:
        """
        Checks whether the request has completed,
        i.e. if all expected reducers and converters have completed.

        :return: bool True if complete, else False
        """
        request_state = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE, request_id=self.request_id)
        queries_complete = (
            request_state[RequestTableField.EXPECTED_QUERY_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_QUERY_EXECUTIONS.value])
        converter_complete = (
            request_state[RequestTableField.EXPECTED_CONVERTER_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value])
        return queries_complete and converter_complete

    def is_request_ready_for_conversion(self) -> bool:
        """
        Checks whether the request has completed all queries
        and is ready for conversion

        :return: bool True if complete, else False
        """
        request_state = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE, request_id=self.request_id)
        queries_complete = (
            request_state[RequestTableField.EXPECTED_QUERY_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_QUERY_EXECUTIONS.value])
        return queries_complete

    def complete_request(self, duration: float):
        """
        Log the completion of a matrix request in CloudWatch Metrics

        :param duration: The time in seconds the request took to complete
        """
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.CONVERSION_COMPLETION, metric_value=1)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_COMPLETION, metric_value=1)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[
                {
                    'Name': "Number of Bundles",
                    'Value': self.num_bundles_interval
                },
                {
                    'Name': "Output Format",
                    'Value': self.format
                },
            ])
        # Duration is published a second time with only the bundle-count dimension;
        # presumably for aggregation across output formats — confirm before removing.
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[
                {
                    'Name': "Number of Bundles",
                    'Value': self.num_bundles_interval
                },
            ])

    def log_error(self, message: str):
        """
        Logs the latest error this request reported
        overwriting the previously logged error.

        :param message: str The error message to log
        """
        logger.debug(message)
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.ERROR_MESSAGE, message)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_ERROR, metric_value=1)

    def write_batch_job_id_to_db(self, batch_job_id: str):
        """
        Logs the batch job id for matrix conversion to state table

        :param batch_job_id: The AWS Batch job id to persist.
        """
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.BATCH_JOB_ID, batch_job_id)
class TestRequestTracker(MatrixTestCaseUsingMockAWS):
    """Unit tests for RequestTracker against a mocked AWS environment."""

    # Patching setUp itself: the mocked clock is only live while setUp runs,
    # which pins the entry's creation date to stub_date.
    @mock.patch("matrix.common.date.get_datetime_now")
    def setUp(self, mock_get_datetime_now):
        super(TestRequestTracker, self).setUp()
        self.stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = self.stub_date

        self.request_id = str(uuid.uuid4())
        self.request_tracker = RequestTracker(self.request_id)
        self.dynamo_handler = DynamoHandler()
        self.create_test_request_table()
        self.dynamo_handler.create_request_table_entry(self.request_id, "test_format")

    def test_format(self):
        """The tracker surfaces the format stored at entry creation."""
        self.assertEqual(self.request_tracker.format, "test_format")

    def test_batch_job_id(self):
        """batch_job_id is None until a real id replaces the "N/A" placeholder."""
        self.assertEqual(self.request_tracker.batch_job_id, None)

        field_enum = RequestTableField.BATCH_JOB_ID
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        self.assertEqual(self.request_tracker.batch_job_id, "123-123")

    @mock.patch("matrix.common.aws.batch_handler.BatchHandler.get_batch_job_status")
    def test_batch_job_status(self, mock_get_job_status):
        """batch_job_status delegates to BatchHandler once a job id exists."""
        mock_get_job_status.return_value = "FAILED"
        field_enum = RequestTableField.BATCH_JOB_ID
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        self.assertEqual(self.request_tracker.batch_job_status, "FAILED")

    @mock.patch("matrix.common.request.request_tracker.RequestTracker.num_bundles",
                new_callable=mock.PropertyMock)
    def test_num_bundles_interval(self, mock_num_bundles):
        """Bundle counts map to half-open 500-wide interval labels."""
        mock_num_bundles.return_value = 0
        self.assertEqual(self.request_tracker.num_bundles_interval, "0-499")

        mock_num_bundles.return_value = 1
        self.assertEqual(self.request_tracker.num_bundles_interval, "0-499")

        # Boundary value lands in the next interval.
        mock_num_bundles.return_value = 500
        self.assertEqual(self.request_tracker.num_bundles_interval, "500-999")

        mock_num_bundles.return_value = 1234
        self.assertEqual(self.request_tracker.num_bundles_interval, "1000-1499")

    def test_creation_date(self):
        """creation_date is the stubbed timestamp from setUp."""
        self.assertEqual(self.request_tracker.creation_date, self.stub_date)

    @mock.patch("matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data")
    def test_error(self, mock_cw_put):
        """error is "" initially, reflects log_error, and emits one error metric."""
        self.assertEqual(self.request_tracker.error, "")

        self.request_tracker.log_error("test error")
        self.assertEqual(self.request_tracker.error, "test error")
        mock_cw_put.assert_called_once_with(
            metric_name=MetricName.REQUEST_ERROR, metric_value=1)

    @mock.patch("matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data")
    @mock.patch("matrix.common.aws.dynamo_handler.DynamoHandler.create_request_table_entry")
    def test_initialize_request(self, mock_create_request_table_entry, mock_create_cw_metric):
        """initialize_request writes the table entry and a request metric."""
        self.request_tracker.initialize_request("test_format")
        mock_create_request_table_entry.assert_called_once_with(self.request_id, "test_format")
        mock_create_cw_metric.assert_called_once()

    @mock.patch("matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_expect_subtask_execution(self, mock_increment_table_field):
        """Expecting a DRIVER subtask bumps EXPECTED_DRIVER_EXECUTIONS by 1."""
        self.request_tracker.expect_subtask_execution(Subtask.DRIVER)
        mock_increment_table_field.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.EXPECTED_DRIVER_EXECUTIONS, 1)

    @mock.patch("matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_complete_subtask_execution(self, mock_increment_table_field):
        """Completing a DRIVER subtask bumps COMPLETED_DRIVER_EXECUTIONS by 1."""
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)
        mock_increment_table_field.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_DRIVER_EXECUTIONS, 1)

    def test_is_request_complete(self):
        """Request is complete only once query and converter counters match expectations."""
        self.assertFalse(self.request_tracker.is_request_complete())

        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_CONVERTER_EXECUTIONS, 1)
        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_QUERY_EXECUTIONS, 3)
        self.assertTrue(self.request_tracker.is_request_complete())

    def test_is_request_ready_for_conversion(self):
        """Readiness for conversion requires only the query counters to match."""
        self.assertFalse(self.request_tracker.is_request_ready_for_conversion())

        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_QUERY_EXECUTIONS, 3)
        self.assertTrue(self.request_tracker.is_request_ready_for_conversion())

    @mock.patch("matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data")
    def test_complete_request(self, mock_cw_put):
        """complete_request publishes completion metrics and two DURATION metrics."""
        duration = 1
        self.request_tracker.complete_request(duration)
        expected_calls = [
            mock.call(metric_name=MetricName.CONVERSION_COMPLETION, metric_value=1),
            mock.call(metric_name=MetricName.REQUEST_COMPLETION, metric_value=1),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[
                          {
                              'Name': "Number of Bundles",
                              'Value': mock.ANY
                          },
                          {
                              'Name': "Output Format",
                              'Value': mock.ANY
                          },
                      ]),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[
                          {
                              'Name': "Number of Bundles",
                              'Value': mock.ANY
                          },
                      ])
        ]
        mock_cw_put.assert_has_calls(expected_calls)

    @mock.patch("matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch("matrix.common.request.request_tracker.RequestTracker.creation_date",
                new_callable=mock.PropertyMock)
    def test_timeout(self, mock_creation_date, mock_log_error):
        """timeout flips to True past 12 hours and logs an error exactly once."""
        # no timeout
        mock_creation_date.return_value = date.to_string(
            date.get_datetime_now() - timedelta(hours=11, minutes=59))
        self.assertFalse(self.request_tracker.timeout)

        # timeout
        mock_creation_date.return_value = date.to_string(
            date.get_datetime_now() - timedelta(hours=12, minutes=1))
        self.assertTrue(self.request_tracker.timeout)
        mock_log_error.assert_called_once()

    @mock.patch("matrix.common.aws.dynamo_handler.DynamoHandler.set_table_field_with_value")
    def test_write_batch_job_id_to_db(self, mock_set_table_field_with_value):
        """write_batch_job_id_to_db persists the id to the BATCH_JOB_ID field."""
        self.request_tracker.write_batch_job_id_to_db("123-123")
        mock_set_table_field_with_value.assert_called_once_with(
            DynamoTable.REQUEST_TABLE,
            self.request_id,
            RequestTableField.BATCH_JOB_ID,
            "123-123")