コード例 #1
0
class TestRequestTracker(MatrixTestCaseUsingMockAWS):
    """
    Unit tests for RequestTracker.

    setUp creates the test tables and the results bucket and writes a request
    table entry for self.request_id, so every test starts from an initialized
    request.
    """
    @mock.patch("matrix.common.date.get_datetime_now")
    def setUp(self, mock_get_datetime_now):
        """
        Create the test tables and results bucket and register one request.

        get_datetime_now is patched for the duration of setUp only, and returns
        self.stub_date while the request table entry is created.
        """
        super(TestRequestTracker, self).setUp()
        self.stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = self.stub_date

        self.request_id = str(uuid.uuid4())
        self.request_tracker = RequestTracker(self.request_id)
        self.dynamo_handler = DynamoHandler()

        self.create_test_data_version_table()
        self.create_test_deployment_table()
        self.create_test_request_table()
        self.create_s3_results_bucket()

        self.init_test_data_version_table()
        self.init_test_deployment_table()

        self.dynamo_handler.create_request_table_entry(
            self.request_id, "test_format", ["test_field_1", "test_field_2"],
            "test_feature")

    def test_is_initialized(self):
        """A registered request id is initialized; an unknown id is not."""
        self.assertTrue(self.request_tracker.is_initialized)

        new_request_tracker = RequestTracker("test_uuid")
        self.assertFalse(new_request_tracker.is_initialized)

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.generate_request_hash"
    )
    def test_request_hash(self, mock_generate_request_hash):
        """
        request_hash is not generated when MATRIX_VERSION is set, is generated
        and stored in Dynamo on first access otherwise, and is not regenerated
        on later accesses.
        """
        with self.subTest("Test skip generation in API deployments:"):
            os.environ['MATRIX_VERSION'] = "test_version"
            try:
                self.assertEqual(self.request_tracker.request_hash, "N/A")
                mock_generate_request_hash.assert_not_called()

                stored_request_hash = self.dynamo_handler.get_table_item(
                    DynamoTable.REQUEST_TABLE,
                    key=self.request_id)[RequestTableField.REQUEST_HASH.value]

                self.assertEqual(self.request_tracker._request_hash, "N/A")
                self.assertEqual(stored_request_hash, "N/A")
            finally:
                # subTest records a failure and lets the test continue, so the
                # variable has to be removed here. Otherwise a failed assertion
                # above would leave it set for the following subtests and for
                # every test that runs afterwards.
                os.environ.pop('MATRIX_VERSION', None)

        with self.subTest(
                "Test generation and storage in Dynamo on first access"):
            mock_generate_request_hash.return_value = "test_hash"
            self.assertEqual(self.request_tracker.request_hash, "test_hash")
            mock_generate_request_hash.assert_called_once()

            stored_request_hash = self.dynamo_handler.get_table_item(
                DynamoTable.REQUEST_TABLE,
                key=self.request_id)[RequestTableField.REQUEST_HASH.value]

            self.assertEqual(self.request_tracker._request_hash, "test_hash")
            self.assertEqual(stored_request_hash, "test_hash")

        with self.subTest("Test immediate retrieval on future accesses"):
            self.assertEqual(self.request_tracker.request_hash, "test_hash")
            # Still exactly one call: the second access must not regenerate.
            mock_generate_request_hash.assert_called_once()

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.request_hash",
        new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.data_version",
        new_callable=mock.PropertyMock)
    def test_s3_results_prefix(self, mock_data_version, mock_request_hash):
        """The results prefix is "<data_version>/<request_hash>"."""
        mock_data_version.return_value = "test_data_version"
        mock_request_hash.return_value = "test_request_hash"

        self.assertEqual(self.request_tracker.s3_results_prefix,
                         "test_data_version/test_request_hash")

    @mock.patch("matrix.common.request.request_tracker.RequestTracker.format",
                new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.request_hash",
        new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.data_version",
        new_callable=mock.PropertyMock)
    def test_s3_results_key(self, mock_data_version, mock_request_hash,
                            mock_format):
        """The results key ends in the format, plus ".zip" for csv and mtx."""
        mock_data_version.return_value = "test_data_version"
        mock_request_hash.return_value = "test_request_hash"
        mock_format.return_value = "loom"

        self.assertEqual(
            self.request_tracker.s3_results_key,
            f"test_data_version/test_request_hash/{self.request_id}.loom")

        mock_format.return_value = "csv"
        self.assertEqual(
            self.request_tracker.s3_results_key,
            f"test_data_version/test_request_hash/{self.request_id}.csv.zip")

        mock_format.return_value = "mtx"
        self.assertEqual(
            self.request_tracker.s3_results_key,
            f"test_data_version/test_request_hash/{self.request_id}.mtx.zip")

    @mock.patch("matrix.common.aws.dynamo_handler.DynamoHandler.get_table_item"
                )
    def test_data_version(self, mock_get_table_item):
        """data_version is read from Dynamo once and then served from the instance."""
        mock_get_table_item.return_value = {
            RequestTableField.DATA_VERSION.value: 0
        }

        with self.subTest("Test Dynamo read on first access"):
            self.assertEqual(self.request_tracker.data_version, 0)
            mock_get_table_item.assert_called_once()

        with self.subTest("Test cached access on successive reads"):
            self.assertEqual(self.request_tracker.data_version, 0)
            # Still exactly one call, even though the cached value (0) is falsy.
            mock_get_table_item.assert_called_once()

    def test_format(self):
        """format returns the value the request was registered with."""
        self.assertEqual(self.request_tracker.format, "test_format")

    def test_metadata_fields(self):
        """metadata_fields returns the list the request was registered with."""
        self.assertEqual(self.request_tracker.metadata_fields,
                         ["test_field_1", "test_field_2"])

    def test_feature(self):
        """feature returns the value the request was registered with."""
        self.assertEqual(self.request_tracker.feature, "test_feature")

    def test_batch_job_id(self):
        """batch_job_id is None until a job id has been written to the request table."""
        self.assertIsNone(self.request_tracker.batch_job_id)

        field_enum = RequestTableField.BATCH_JOB_ID
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        self.assertEqual(self.request_tracker.batch_job_id, "123-123")

    @mock.patch(
        "matrix.common.aws.batch_handler.BatchHandler.get_batch_job_status")
    def test_batch_job_status(self, mock_get_job_status):
        """batch_job_status returns what the batch handler reports for the stored job id."""
        mock_get_job_status.return_value = "FAILED"
        field_enum = RequestTableField.BATCH_JOB_ID
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")

        self.assertEqual(self.request_tracker.batch_job_status, "FAILED")

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.num_bundles",
        new_callable=mock.PropertyMock)
    def test_num_bundles_interval(self, mock_num_bundles):
        """num_bundles is mapped onto 500-wide intervals, lower bound inclusive."""
        mock_num_bundles.return_value = 0
        self.assertEqual(self.request_tracker.num_bundles_interval, "0-499")

        mock_num_bundles.return_value = 1
        self.assertEqual(self.request_tracker.num_bundles_interval, "0-499")

        mock_num_bundles.return_value = 500
        self.assertEqual(self.request_tracker.num_bundles_interval, "500-999")

        mock_num_bundles.return_value = 1234
        self.assertEqual(self.request_tracker.num_bundles_interval,
                         "1000-1499")

    def test_creation_date(self):
        """creation_date is the date that was current when the entry was created in setUp."""
        self.assertEqual(self.request_tracker.creation_date, self.stub_date)

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    def test_error(self, mock_cw_put):
        """error is "" initially; log_error stores the message and emits one error metric."""
        self.assertEqual(self.request_tracker.error, "")

        self.request_tracker.log_error("test error")
        self.assertEqual(self.request_tracker.error, "test error")
        mock_cw_put.assert_called_once_with(
            metric_name=MetricName.REQUEST_ERROR, metric_value=1)

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.create_request_table_entry"
    )
    def test_initialize_request(self, mock_create_request_table_entry,
                                mock_create_cw_metric):
        """initialize_request falls back to the default fields and feature and emits a metric."""
        self.request_tracker.initialize_request("test_format")

        mock_create_request_table_entry.assert_called_once_with(
            self.request_id, "test_format", DEFAULT_FIELDS, DEFAULT_FEATURE)
        mock_create_cw_metric.assert_called_once()

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.metadata_fields",
        new_callable=mock.PropertyMock)
    @mock.patch(
        "matrix.common.query.cell_query_results_reader.CellQueryResultsReader.load_results"
    )
    @mock.patch(
        "matrix.common.query.query_results_reader.QueryResultsReader._parse_manifest"
    )
    def test_generate_request_hash(self, mock_parse_manifest,
                                   mock_load_results, mock_metadata_fields):
        """
        The request hash is the MD5 hex digest of feature, format, each metadata
        field and each cell key, fed in that order.
        """
        mock_load_results.return_value = pandas.DataFrame(
            index=["test_cell_key_1", "test_cell_key_2"])
        mock_metadata_fields.return_value = ["test_field_1", "test_field_2"]

        h = hashlib.md5()
        h.update(self.request_tracker.feature.encode())
        h.update(self.request_tracker.format.encode())
        h.update("test_field_1".encode())
        h.update("test_field_2".encode())
        h.update("test_cell_key_1".encode())
        h.update("test_cell_key_2".encode())

        self.assertEqual(self.request_tracker.generate_request_hash(),
                         h.hexdigest())

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_expect_subtask_execution(self, mock_increment_table_field):
        """Expecting a driver execution increments the expected-driver counter by 1."""
        self.request_tracker.expect_subtask_execution(Subtask.DRIVER)

        mock_increment_table_field.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.EXPECTED_DRIVER_EXECUTIONS, 1)

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_complete_subtask_execution(self, mock_increment_table_field):
        """Completing a driver execution increments the completed-driver counter by 1."""
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)

        mock_increment_table_field.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_DRIVER_EXECUTIONS, 1)

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.s3_results_prefix",
        new_callable=mock.PropertyMock)
    def test_lookup_cached_result(self, mock_s3_results_prefix):
        """Only objects below "<prefix>/" count as cached results; the first one is returned."""
        mock_s3_results_prefix.return_value = "test_prefix"
        s3_handler = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])

        with self.subTest("Do not match in S3 'directories'"):
            # An object stored at the prefix itself is not below "<prefix>/".
            s3_handler.store_content_in_s3("test_prefix", "test_content")
            self.assertEqual(self.request_tracker.lookup_cached_result(), "")

        with self.subTest("Successfully retrieve a result key"):
            s3_handler.store_content_in_s3("test_prefix/test_result_1",
                                           "test_content")
            s3_handler.store_content_in_s3("test_prefix/test_result_2",
                                           "test_content")
            self.assertEqual(self.request_tracker.lookup_cached_result(),
                             "test_prefix/test_result_1")

    def test_is_request_complete(self):
        """The request is complete once an object is stored under its results key."""
        self.assertFalse(self.request_tracker.is_request_complete())

        s3_handler = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])

        s3_handler.store_content_in_s3(
            f"{self.request_tracker.s3_results_key}/{self.request_id}.{self.request_tracker.format}",
            "")

        self.assertTrue(self.request_tracker.is_request_complete())

    def test_is_request_ready_for_conversion(self):
        """The request is ready for conversion after 3 query executions have completed."""
        self.assertFalse(
            self.request_tracker.is_request_ready_for_conversion())
        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_QUERY_EXECUTIONS, 3)
        self.assertTrue(self.request_tracker.is_request_ready_for_conversion())

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    def test_complete_request(self, mock_cw_put):
        """complete_request emits the completion metrics and both duration metrics, in order."""
        duration = 1

        self.request_tracker.complete_request(duration)

        expected_calls = [
            mock.call(metric_name=MetricName.CONVERSION_COMPLETION,
                      metric_value=1),
            mock.call(metric_name=MetricName.REQUEST_COMPLETION,
                      metric_value=1),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[
                          {
                              'Name': "Number of Bundles",
                              'Value': mock.ANY
                          },
                          {
                              'Name': "Output Format",
                              'Value': mock.ANY
                          },
                      ]),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[
                          {
                              'Name': "Number of Bundles",
                              'Value': mock.ANY
                          },
                      ])
        ]
        mock_cw_put.assert_has_calls(expected_calls)

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.creation_date",
        new_callable=mock.PropertyMock)
    @mock.patch("matrix.common.aws.s3_handler.S3Handler.exists")
    def test_is_expired(self, mock_exists, mock_creation_date, mock_log_error):
        """
        A request is expired only when its matrix is missing from S3 and it was
        created more than 30 days ago; an error is logged only in that case.
        """
        with self.subTest("Expired"):
            mock_exists.return_value = False
            mock_creation_date.return_value = date.to_string(
                date.get_datetime_now() - timedelta(days=30, minutes=1))

            self.assertTrue(self.request_tracker.is_expired)
            mock_log_error.assert_called_once()
            # Start the remaining subtests with a clean call record.
            mock_log_error.reset_mock()

        with self.subTest(
                "Not expired. Matrix DNE but not past expiration date"):
            mock_exists.return_value = False
            mock_creation_date.return_value = date.to_string(
                date.get_datetime_now() - timedelta(days=29))

            self.assertFalse(self.request_tracker.is_expired)
            mock_log_error.assert_not_called()

        with self.subTest("Not expired. Matrix exists"):
            mock_exists.return_value = True

            self.assertFalse(self.request_tracker.is_expired)
            mock_log_error.assert_not_called()

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.creation_date",
        new_callable=mock.PropertyMock)
    def test_timeout(self, mock_creation_date, mock_log_error):
        """A request times out once it is more than 36 hours old, logging one error."""
        # no timeout
        mock_creation_date.return_value = date.to_string(
            date.get_datetime_now() - timedelta(hours=35, minutes=59))
        self.assertFalse(self.request_tracker.timeout)

        # timeout
        mock_creation_date.return_value = date.to_string(
            date.get_datetime_now() - timedelta(hours=36, minutes=1))
        self.assertTrue(self.request_tracker.timeout)
        mock_log_error.assert_called_once()

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.set_table_field_with_value"
    )
    def test_write_batch_job_id_to_db(self, mock_set_table_field_with_value):
        """write_batch_job_id_to_db writes the id to the BATCH_JOB_ID field of the request."""
        self.request_tracker.write_batch_job_id_to_db("123-123")
        mock_set_table_field_with_value.assert_called_once_with(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.BATCH_JOB_ID, "123-123")
コード例 #2
0
class RequestTracker:
    """
    Provides an interface for tracking a request's parameters and state.

    State is read from and written to the request table through DynamoHandler.
    Some parameters are kept on the instance after they have been read.
    """
    def __init__(self, request_id: str):
        """
        :param request_id: str Id of the request to track; also set as the logger's correlation id
        """
        Logging.set_correlation_id(logger, request_id)

        self.request_id = request_id
        # "N/A" marks a request hash that has not been resolved yet.
        self._request_hash = "N/A"
        self._data_version = None
        self._num_bundles = None
        self._format = None
        self._metadata_fields = None
        self._feature = None

        self.dynamo_handler = DynamoHandler()
        self.cloudwatch_handler = CloudwatchHandler()
        self.batch_handler = BatchHandler()

    @property
    def is_initialized(self) -> bool:
        """
        Whether or not an entry for this request can be read from the request table.
        :return: bool False if reading the entry raises MatrixException, else True
        """
        try:
            self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                               key=self.request_id)
        except MatrixException:
            return False

        return True

    @property
    def request_hash(self) -> str:
        """
        Unique hash generated using request parameters.
        If a request hash does not exist, one will be attempted to be generated.
        A generated hash is written back to the request table.
        :return: str Request hash, or "N/A" if none exists and none could be generated
        """
        if self._request_hash == "N/A":
            self._request_hash = self.dynamo_handler.get_table_item(
                DynamoTable.REQUEST_TABLE,
                key=self.request_id)[RequestTableField.REQUEST_HASH.value]

            # Do not generate request hash in API requests to avoid timeouts.
            # Presence of MATRIX_VERSION indicates API deployment.
            if self._request_hash == "N/A" and not os.getenv('MATRIX_VERSION'):
                try:
                    self._request_hash = self.generate_request_hash()
                    self.dynamo_handler.set_table_field_with_value(
                        DynamoTable.REQUEST_TABLE, self.request_id,
                        RequestTableField.REQUEST_HASH, self._request_hash)
                except MatrixQueryResultsNotFound as e:
                    # Best effort: the hash stays "N/A" and is retried on the next access.
                    logger.warning(f"Failed to generate a request hash. {e}")

        return self._request_hash

    @property
    def s3_results_prefix(self) -> str:
        """
        The S3 prefix where results for this request hash are stored in the results bucket.
        :return: str S3 prefix
        """
        return f"{self.data_version}/{self.request_hash}"

    @property
    def s3_results_key(self) -> str:
        """
        The S3 key where matrix results for this request are stored in the results bucket.
        Results in the csv and mtx formats carry an additional ".zip" suffix.
        :return: str S3 key
        """
        is_compressed = self.format in (MatrixFormat.CSV.value,
                                        MatrixFormat.MTX.value)

        return f"{self.data_version}/{self.request_hash}/{self.request_id}.{self.format}" + \
               (".zip" if is_compressed else "")

    @property
    def data_version(self) -> int:
        """
        The Redshift data version this request is generated on.
        Read from the request table on first access, then served from the instance.
        :return: int Data version
        """
        if self._data_version is None:
            self._data_version = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.DATA_VERSION.value]

        return self._data_version

    @property
    def num_bundles(self) -> int:
        """
        The number of bundles in the request.
        The request table is read again for as long as the stored value is falsy.
        :return: int Number of bundles
        """
        if not self._num_bundles:
            self._num_bundles = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.NUM_BUNDLES.value]
        return self._num_bundles

    @property
    def num_bundles_interval(self) -> str:
        """
        Returns the interval string that num_bundles corresponds to.
        Intervals are 500 wide with both bounds inclusive.
        :return: the interval string e.g. "0-499"
        """
        interval_size = 500

        index = int(self.num_bundles / interval_size)
        return f"{index * interval_size}-{(index * interval_size) + interval_size - 1}"

    @property
    def format(self) -> str:
        """
        The request's user specified output file format of the resultant expression matrix.
        :return: str The file format (one of MatrixFormat)
        """
        if not self._format:
            self._format = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.FORMAT.value]
        return self._format

    @property
    def metadata_fields(self) -> list:
        """
        The request's user-specified list of metadata fields to include in the resultant expression matrix.
        :return:  list List of metadata fields
        """
        if not self._metadata_fields:
            self._metadata_fields = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.METADATA_FIELDS.value]
        return self._metadata_fields

    @property
    def feature(self) -> str:
        """
        The request's user-specified feature type (gene|transcript) of the resultant expression matrix.
        :return: str Feature (gene|transcript)
        """
        if not self._feature:
            self._feature = \
                self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                   key=self.request_id)[RequestTableField.FEATURE.value]
        return self._feature

    @property
    def batch_job_id(self) -> str:
        """
        The batch job id for matrix conversion corresponding with a request.
        Read from the request table on every access.
        :return: str The batch job id, or None if it is missing, empty or "N/A"
        """
        table_item = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE, key=self.request_id)
        batch_job_id = table_item.get(RequestTableField.BATCH_JOB_ID.value)
        if not batch_job_id or batch_job_id == "N/A":
            return None
        else:
            return batch_job_id

    @property
    def batch_job_status(self) -> str:
        """
        The batch job status for matrix conversion corresponding with a request.
        :return: str The batch job status, or None if the request has no batch job id
        """
        status = None
        if self.batch_job_id:
            status = self.batch_handler.get_batch_job_status(self.batch_job_id)
        return status

    @property
    def creation_date(self) -> str:
        """
        The creation date of matrix service request.
        :return: str creation date
        """
        return self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE,
            key=self.request_id)[RequestTableField.CREATION_DATE.value]

    @property
    def is_expired(self) -> bool:
        """
        Whether or not the request has expired and the matrix in S3 has been deleted.
        A request is expired when it was created more than 30 days ago and its
        results key no longer exists in the results bucket. An error is logged
        for the request when it is expired.
        :return: bool
        """
        s3_results_bucket_handler = S3Handler(
            os.environ['MATRIX_RESULTS_BUCKET'])
        is_past_expiration = date.to_datetime(
            self.creation_date) < date.get_datetime_now() - timedelta(days=30)
        is_expired = not s3_results_bucket_handler.exists(
            self.s3_results_key) and is_past_expiration

        if is_expired:
            self.log_error(
                "This request has expired after 30 days and is no longer available for download. "
                "A new matrix can be generated by resubmitting the POST request to /v1/matrix."
            )

        return is_expired

    @property
    def timeout(self) -> bool:
        """
        Whether or not the request was created longer ago than the timeout window.
        An error is logged for the request when it has timed out.
        :return: bool True if the request was created more than 36 hours ago, else False
        """
        timeout_hours = 36

        timeout = date.to_datetime(
            self.creation_date) < date.get_datetime_now() - timedelta(hours=timeout_hours)
        if timeout:
            # The message is built from the threshold so the two cannot drift apart.
            self.log_error(
                f"This request has timed out after {timeout_hours} hours. "
                "Please try again by resubmitting the POST request.")
        return timeout

    @property
    def error(self) -> str:
        """
        The user-friendly message describing the latest error the request raised.
        :return: str The error message if one exists, else empty string
        """
        error = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE,
            key=self.request_id)[RequestTableField.ERROR_MESSAGE.value]
        return error if error else ""

    def initialize_request(self,
                           fmt: str,
                           metadata_fields: list = DEFAULT_FIELDS,
                           feature: str = DEFAULT_FEATURE) -> None:
        """Initialize the request id in the request state table. Put request metric to cloudwatch.
        :param fmt: Request output format for matrix conversion
        :param metadata_fields: Metadata fields to include in expression matrix
        :param feature: Feature type to generate expression counts for (one of MatrixFeature)
        """
        self.dynamo_handler.create_request_table_entry(self.request_id, fmt,
                                                       metadata_fields,
                                                       feature)
        self.cloudwatch_handler.put_metric_data(metric_name=MetricName.REQUEST,
                                                metric_value=1)

    def generate_request_hash(self) -> str:
        """
        Generates a request hash uniquely identifying a request by its input parameters.
        The hash is the MD5 hex digest of the feature, the format, every metadata
        field and every cell key of the cell query results, fed in that order.
        Requires cell query results to exist, else raises MatrixQueryResultsNotFound.
        :return: str Request hash
        """
        cell_manifest_key = f"s3://{os.environ['MATRIX_QUERY_RESULTS_BUCKET']}/{self.request_id}/cell_metadata_manifest"
        reader = CellQueryResultsReader(cell_manifest_key)
        cell_df = reader.load_results()
        cellkeys = cell_df.index

        h = hashlib.md5()
        h.update(self.feature.encode())
        h.update(self.format.encode())

        for field in self.metadata_fields:
            h.update(field.encode())

        for key in cellkeys:
            h.update(key.encode())

        return h.hexdigest()

    def expect_subtask_execution(self, subtask: Subtask):
        """
        Expect the execution of 1 Subtask by tracking it in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.
        :param subtask: The expected Subtask to be executed (DRIVER or CONVERTER).
        :raises KeyError: if no expected-executions field is mapped to the subtask
        """
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.EXPECTED_DRIVER_EXECUTIONS,
            Subtask.CONVERTER: RequestTableField.EXPECTED_CONVERTER_EXECUTIONS,
        }

        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def complete_subtask_execution(self, subtask: Subtask):
        """
        Counts the completed execution of 1 Subtask in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.
        :param subtask: The executed Subtask (DRIVER, QUERY or CONVERTER).
        :raises KeyError: if no completed-executions field is mapped to the subtask
        """
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
            Subtask.QUERY: RequestTableField.COMPLETED_QUERY_EXECUTIONS,
            Subtask.CONVERTER:
            RequestTableField.COMPLETED_CONVERTER_EXECUTIONS,
        }

        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def lookup_cached_result(self) -> str:
        """
        Retrieves the S3 key of an existing matrix result that corresponds to this request's request hash.
        Returns "" if no such result exists
        :return: S3 key of cached result
        """
        results_bucket = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])
        # The trailing slash restricts the listing to keys below the prefix.
        objects = results_bucket.ls(f"{self.s3_results_prefix}/")

        if objects:
            return objects[0]['Key']
        return ""

    def is_request_ready_for_conversion(self) -> bool:
        """
        Checks whether the request has completed all queries
        and is ready for conversion
        :return: bool True if complete, else False
        """
        request_state = self.dynamo_handler.get_table_item(
            DynamoTable.REQUEST_TABLE, key=self.request_id)
        queries_complete = (
            request_state[RequestTableField.EXPECTED_QUERY_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_QUERY_EXECUTIONS.value])
        return queries_complete

    def is_request_complete(self) -> bool:
        """
        Checks whether the request has completed, i.e. whether its results key exists
        in the results bucket.
        :return: bool True if complete, else False
        """
        results_bucket = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])
        return results_bucket.exists(self.s3_results_key)

    def complete_request(self, duration: float):
        """
        Log the completion of a matrix request in CloudWatch Metrics
        The duration is reported twice: once by bundle interval and output format,
        and once by bundle interval only.
        :param duration: The time in seconds the request took to complete
        """
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.CONVERSION_COMPLETION, metric_value=1)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_COMPLETION, metric_value=1)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[
                {
                    'Name': "Number of Bundles",
                    'Value': self.num_bundles_interval
                },
                {
                    'Name': "Output Format",
                    'Value': self.format
                },
            ])
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[
                {
                    'Name': "Number of Bundles",
                    'Value': self.num_bundles_interval
                },
            ])

    def log_error(self, message: str):
        """
        Logs the latest error this request reported overwriting the previously logged error.
        The message is stored in the request table and an error metric is put to CloudWatch.
        :param message: str The error message to log
        """
        logger.debug(message)
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.ERROR_MESSAGE, message)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_ERROR, metric_value=1)

    def write_batch_job_id_to_db(self, batch_job_id: str):
        """
        Logs the batch job id for matrix conversion to state table
        :param batch_job_id: str The batch job id to store for this request
        """
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.BATCH_JOB_ID, batch_job_id)
コード例 #3
0
    def test_invalidate_cache_entries(self, mock_put_metric_data):
        """
        Invalidate by request hash and by request id, then check what was removed.

        Scenario:
        - Four request ids are registered; ids 1 and 2 share hash 1, ids 3 and 4 share hash 2
        - One object per request is stored in the results bucket
        - Hash 1 and id 3 are invalidated
        - Ids 1, 2 and 3 must have lost their S3 object and have an error message other than 0
        - Id 4 must still have its S3 object and an error message equal to 0
        """
        fmt = "test_format"
        payload = "test_content"
        hash_a = "test_hash_1"
        hash_b = "test_hash_2"

        # (request id, request hash) pairs, in the order they are registered.
        requests = [
            ("test_id_1", hash_a),
            ("test_id_2", hash_a),
            ("test_id_3", hash_b),
            ("test_id_4", hash_b),
        ]
        s3_keys = {rid: f"0/{rhash}/{rid}.{fmt}" for rid, rhash in requests}
        invalidated = ("test_id_1", "test_id_2", "test_id_3")

        # Register every request, then attach its request hash.
        db = DynamoHandler()
        for rid, _ in requests:
            db.create_request_table_entry(rid, fmt)
        for rid, rhash in requests:
            db.set_table_field_with_value(
                table=DynamoTable.REQUEST_TABLE,
                key=rid,
                field_enum=RequestTableField.REQUEST_HASH,
                field_value=rhash)

        # Store one result per request and make sure all of them are there.
        bucket = S3Handler(os.environ['MATRIX_RESULTS_BUCKET'])
        for rid, _ in requests:
            bucket.store_content_in_s3(s3_keys[rid], payload)
        for rid, _ in requests:
            self.assertTrue(bucket.exists(s3_keys[rid]))

        invalidate_cache_entries(request_ids=["test_id_3"],
                                 request_hashes=[hash_a])

        # Read back the error message of every request before asserting anything.
        errors = {}
        for rid, _ in requests:
            item = db.get_table_item(table=DynamoTable.REQUEST_TABLE, key=rid)
            errors[rid] = item[RequestTableField.ERROR_MESSAGE.value]

        for rid, _ in requests:
            still_stored = bucket.exists(s3_keys[rid])
            if rid in invalidated:
                self.assertFalse(still_stored)
            else:
                self.assertTrue(still_stored)

        for rid, _ in requests:
            if rid in invalidated:
                self.assertNotEqual(errors[rid], 0)
            else:
                self.assertEqual(errors[rid], 0)
コード例 #4
0
class TestDynamoHandler(MatrixTestCaseUsingMockAWS):
    """
    Tests DynamoHandler reads and writes on the data version and request tables.

    Environment variables are set in tests/unit/__init__.py
    """
    def setUp(self):
        """Create and initialise the test tables, then build the handler under test."""
        super().setUp()

        self.dynamo = boto3.resource("dynamodb", region_name=os.environ['AWS_DEFAULT_REGION'])
        self.data_version_table_name = os.environ['DYNAMO_DATA_VERSION_TABLE_NAME']
        self.request_table_name = os.environ['DYNAMO_REQUEST_TABLE_NAME']
        self.request_id = str(uuid.uuid4())
        self.data_version = 1
        self.format = "zarr"

        for create_table in (self.create_test_data_version_table,
                             self.create_test_deployment_table,
                             self.create_test_request_table):
            create_table()

        for init_table in (self.init_test_data_version_table,
                           self.init_test_deployment_table):
            init_table()

        self.handler = DynamoHandler()

    def _batch_get_first_item(self, table_name, key):
        """Batch-get a single key from `table_name`; return the raw response and its first item."""
        response = self.dynamo.batch_get_item(RequestItems={table_name: {'Keys': [key]}})
        return response, response['Responses'][table_name][0]

    def _get_data_version_table_response_and_entry(self):
        """Return (response, entry) for this test's data version, read directly through boto3."""
        primary_key = self.handler._get_dynamo_table_primary_key_from_enum(DynamoTable.DATA_VERSION_TABLE)
        return self._batch_get_first_item(self.data_version_table_name,
                                          {primary_key: self.data_version})

    def _get_request_table_response_and_entry(self):
        """Return (response, entry) for this test's request id, read directly through boto3."""
        return self._batch_get_first_item(self.request_table_name,
                                          {'RequestId': self.request_id})

    def _create_request_entry(self):
        """Create the request table entry for this test's request id and format."""
        self.handler.create_request_table_entry(self.request_id, self.format)

    def _request_entry(self):
        """Return only the request table entry, discarding the raw response."""
        _, entry = self._get_request_table_response_and_entry()
        return entry

    def _request_table_resource(self):
        """Return the handler's table resource for the request table."""
        return self.handler._get_dynamo_table_resource_from_enum(DynamoTable.REQUEST_TABLE)

    def _assert_completed_executions(self, driver, converter):
        """Assert the completed driver/converter execution counters currently stored."""
        entry = self._request_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], driver)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], converter)

    @mock.patch("matrix.common.v1_api_handler.V1ApiHandler.describe_filter")
    @mock.patch("matrix.common.date.get_datetime_now")
    def test_create_data_version_table_entry(self, mock_get_datetime_now, mock_describe_filter):
        """A new data version entry holds every field with the stubbed date and cell counts."""
        stub_date = '2019-03-18T180907.136216Z'
        stub_cell_counts = {'test_project_uuid_1': 10, 'test_project_uuid_2': 100}
        mock_get_datetime_now.return_value = stub_date
        mock_describe_filter.return_value = {'cell_counts': stub_cell_counts}

        self.handler.create_data_version_table_entry(self.data_version)
        response, entry = self._get_data_version_table_response_and_entry()

        expected_schema_versions = {
            schema_name.value: SUPPORTED_METADATA_SCHEMA_VERSIONS[schema_name]
            for schema_name in SUPPORTED_METADATA_SCHEMA_VERSIONS
        }

        self.assertEqual(len(response['Responses'][self.data_version_table_name]), 1)

        for field in DataVersionTableField:
            self.assertIn(field.value, entry)

        expected_values = {
            DataVersionTableField.DATA_VERSION: self.data_version,
            DataVersionTableField.CREATION_DATE: stub_date,
            DataVersionTableField.PROJECT_CELL_COUNTS: stub_cell_counts,
            DataVersionTableField.METADATA_SCHEMA_VERSIONS: expected_schema_versions,
        }
        for field, expected in expected_values.items():
            self.assertEqual(entry[field.value], expected)

    @mock.patch("matrix.common.date.get_datetime_now")
    def test_create_request_table_entry(self, mock_get_datetime_now):
        """A new request entry holds every field, populated with the expected initial values."""
        stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = stub_date

        self._create_request_entry()
        response, entry = self._get_request_table_response_and_entry()

        self.assertEqual(len(response['Responses'][self.request_table_name]), 1)

        for field in RequestTableField:
            self.assertIn(field.value, entry)

        expected_values = {
            RequestTableField.FORMAT: self.format,
            RequestTableField.METADATA_FIELDS: DEFAULT_FIELDS,
            RequestTableField.FEATURE: "gene",
            RequestTableField.DATA_VERSION: 0,
            RequestTableField.REQUEST_HASH: "N/A",
            RequestTableField.EXPECTED_DRIVER_EXECUTIONS: 1,
            RequestTableField.EXPECTED_CONVERTER_EXECUTIONS: 1,
            RequestTableField.CREATION_DATE: stub_date,
        }
        for field, expected in expected_values.items():
            self.assertEqual(entry[field.value], expected)

    def test_increment_table_field_request_table_path(self):
        """increment_table_field bumps only the requested counter."""
        self._create_request_entry()
        self._assert_completed_executions(driver=0, converter=0)

        self.handler.increment_table_field(DynamoTable.REQUEST_TABLE,
                                           self.request_id,
                                           RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
                                           5)
        self._assert_completed_executions(driver=5, converter=0)

    def test_set_table_field_with_value(self):
        """set_table_field_with_value overwrites the stored batch job id."""
        batch_job_field = RequestTableField.BATCH_JOB_ID

        self._create_request_entry()
        self.assertEqual(self._request_entry()[batch_job_field.value], "N/A")

        self.handler.set_table_field_with_value(DynamoTable.REQUEST_TABLE,
                                                self.request_id,
                                                batch_job_field,
                                                "123-123")
        self.assertEqual(self._request_entry()[batch_job_field.value], "123-123")

    def test_increment_field(self):
        """The private _increment_field bumps only the requested counter."""
        self._create_request_entry()
        self._assert_completed_executions(driver=0, converter=0)

        self.handler._increment_field(self._request_table_resource(),
                                      {"RequestId": self.request_id},
                                      RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
                                      15)
        self._assert_completed_executions(driver=15, converter=0)

    def test_set_field(self):
        """The private _set_field overwrites the stored batch job id."""
        batch_job_field = RequestTableField.BATCH_JOB_ID

        self._create_request_entry()
        self.assertEqual(self._request_entry()[batch_job_field.value], "N/A")

        self.handler._set_field(self._request_table_resource(),
                                {"RequestId": self.request_id},
                                batch_job_field,
                                "123-123")
        self.assertEqual(self._request_entry()[batch_job_field.value], "123-123")

    def test_get_request_table_entry(self):
        """get_table_item reflects both the initial entry and later increments."""
        self._create_request_entry()
        item = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, key=self.request_id)
        self.assertEqual(item[RequestTableField.EXPECTED_DRIVER_EXECUTIONS.value], 1)

        self.handler._increment_field(self._request_table_resource(),
                                      {"RequestId": self.request_id},
                                      RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
                                      15)
        item = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, key=self.request_id)
        self.assertEqual(item[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 15)

    def test_get_table_item(self):
        """get_table_item raises MatrixException before the entry exists and returns it afterwards."""
        with self.assertRaises(MatrixException):
            self.handler.get_table_item(DynamoTable.REQUEST_TABLE, key=self.request_id)

        self._create_request_entry()
        item = self.handler.get_table_item(DynamoTable.REQUEST_TABLE, key=self.request_id)
        self.assertEqual(item[RequestTableField.ROW_COUNT.value], 0)

    def test_filter_table_items(self):
        """filter_table_items returns only the items matching every given attribute."""
        hash_filter = {RequestTableField.REQUEST_HASH.value: "N/A"}

        matches = self.handler.filter_table_items(table=DynamoTable.REQUEST_TABLE,
                                                  attrs=hash_filter)
        self.assertEqual(len(matches), 0)

        self._create_request_entry()
        self.handler.create_request_table_entry(str(uuid.uuid4()), "test_format")

        matches = self.handler.filter_table_items(table=DynamoTable.REQUEST_TABLE,
                                                  attrs=hash_filter)
        self.assertEqual(len(matches), 2)

        # Narrowing by format leaves only the entry created with self.format.
        matches = self.handler.filter_table_items(
            table=DynamoTable.REQUEST_TABLE,
            attrs={RequestTableField.REQUEST_HASH.value: "N/A",
                   RequestTableField.FORMAT.value: self.format}
        )
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0][RequestTableField.REQUEST_ID.value], self.request_id)
コード例 #5
0
class TestDynamoHandler(MatrixTestCaseUsingMockAWS):
    """
    Tests DynamoHandler reads and writes on the request table.

    Environment variables are set in tests/unit/__init__.py
    """
    def setUp(self):
        """Create the test request table and build the handler under test."""
        super().setUp()

        region = os.environ['AWS_DEFAULT_REGION']
        self.dynamo = boto3.resource("dynamodb", region_name=region)
        self.request_table_name = os.environ['DYNAMO_REQUEST_TABLE_NAME']
        self.request_id = str(uuid.uuid4())
        self.format = "zarr"

        self.create_test_request_table()

        self.handler = DynamoHandler()

    def _get_request_table_response_and_entry(self):
        """Return (response, entry) for this test's request id, read directly through boto3."""
        request_items = {
            self.request_table_name: {'Keys': [{'RequestId': self.request_id}]}
        }
        response = self.dynamo.batch_get_item(RequestItems=request_items)
        return response, response['Responses'][self.request_table_name][0]

    def _create_request_entry(self):
        """Create the request table entry for this test's request id and format."""
        self.handler.create_request_table_entry(self.request_id, self.format)

    def _stored_value(self, field):
        """Return the value currently stored for `field` in this test's request entry."""
        _, entry = self._get_request_table_response_and_entry()
        return entry[field.value]

    def _assert_completed_executions(self, driver, converter):
        """Assert the completed driver/converter execution counters currently stored."""
        _, entry = self._get_request_table_response_and_entry()
        self.assertEqual(entry[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], driver)
        self.assertEqual(entry[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value], converter)

    @mock.patch("matrix.common.date.get_datetime_now")
    def test_create_request_table_entry(self, mock_get_datetime_now):
        """A new request entry holds every field, populated with the expected initial values."""
        stub_date = '2019-03-18T180907.136216Z'
        mock_get_datetime_now.return_value = stub_date

        self._create_request_entry()
        response, entry = self._get_request_table_response_and_entry()

        self.assertEqual(len(response['Responses'][self.request_table_name]), 1)

        for field in RequestTableField:
            self.assertIn(field.value, entry)

        expected_values = {
            RequestTableField.FORMAT: self.format,
            RequestTableField.EXPECTED_DRIVER_EXECUTIONS: 1,
            RequestTableField.EXPECTED_CONVERTER_EXECUTIONS: 1,
            RequestTableField.CREATION_DATE: stub_date,
        }
        for field, expected in expected_values.items():
            self.assertEqual(entry[field.value], expected)

    def test_increment_table_field_request_table_path(self):
        """increment_table_field bumps only the requested counter."""
        self._create_request_entry()
        self._assert_completed_executions(driver=0, converter=0)

        self.handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.COMPLETED_DRIVER_EXECUTIONS, 5)
        self._assert_completed_executions(driver=5, converter=0)

    def test_set_table_field_with_value(self):
        """set_table_field_with_value overwrites the stored batch job id."""
        batch_job_field = RequestTableField.BATCH_JOB_ID

        self._create_request_entry()
        self.assertEqual(self._stored_value(batch_job_field), "N/A")

        self.handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, batch_job_field, "123-123")
        self.assertEqual(self._stored_value(batch_job_field), "123-123")

    def test_increment_field(self):
        """The private _increment_field bumps only the requested counter."""
        self._create_request_entry()
        self._assert_completed_executions(driver=0, converter=0)

        self.handler._increment_field(
            self.handler._request_table, {"RequestId": self.request_id},
            RequestTableField.COMPLETED_DRIVER_EXECUTIONS, 15)
        self._assert_completed_executions(driver=15, converter=0)

    def test_set_field(self):
        """The private _set_field overwrites the stored batch job id."""
        batch_job_field = RequestTableField.BATCH_JOB_ID

        self._create_request_entry()
        self.assertEqual(self._stored_value(batch_job_field), "N/A")

        self.handler._set_field(
            self.handler._request_table, {"RequestId": self.request_id},
            batch_job_field, "123-123")
        self.assertEqual(self._stored_value(batch_job_field), "123-123")

    def test_get_request_table_entry(self):
        """get_table_item reflects both the initial entry and later increments."""
        self._create_request_entry()
        item = self.handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                           request_id=self.request_id)
        self.assertEqual(item[RequestTableField.EXPECTED_DRIVER_EXECUTIONS.value], 1)

        self.handler._increment_field(
            self.handler._request_table, {"RequestId": self.request_id},
            RequestTableField.COMPLETED_DRIVER_EXECUTIONS, 15)
        item = self.handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                           request_id=self.request_id)
        self.assertEqual(item[RequestTableField.COMPLETED_DRIVER_EXECUTIONS.value], 15)

    def test_get_table_item(self):
        """get_table_item raises MatrixException before the entry exists and returns it afterwards."""
        with self.assertRaises(MatrixException):
            self.handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                        request_id=self.request_id)

        self._create_request_entry()
        item = self.handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                           request_id=self.request_id)
        self.assertEqual(item[RequestTableField.ROW_COUNT.value], 0)
コード例 #6
0
class RequestTracker:
    """
    Provides an interface for tracking a request's parameters and state.

    Request state is read from and written to the request table through DynamoHandler,
    metrics are published through CloudwatchHandler, and the conversion job status is
    read through BatchHandler.
    """
    def __init__(self, request_id: str):
        """
        :param request_id: str ID of the request whose request table entry is tracked
        """
        Logging.set_correlation_id(logger, request_id)

        self.request_id = request_id
        # Lazily populated caches for the matching properties; None means "not fetched yet".
        self._num_bundles = None
        self._format = None

        self.dynamo_handler = DynamoHandler()
        self.cloudwatch_handler = CloudwatchHandler()
        self.batch_handler = BatchHandler()

    def _get_request_table_item(self):
        """
        Fetch this request's current entry from the request table.
        :return: The request table item for self.request_id
        """
        return self.dynamo_handler.get_table_item(DynamoTable.REQUEST_TABLE,
                                                  request_id=self.request_id)

    @property
    def is_initialized(self) -> bool:
        """
        Whether an entry for this request can be read from the request table.
        :return: bool False if the lookup raises MatrixException, else True
        """
        try:
            self._get_request_table_item()
        except MatrixException:
            return False

        return True

    @property
    def num_bundles(self) -> int:
        """
        The number of bundles in the request.
        Fetched from the request table on first access and cached afterwards.
        :return: int Number of bundles
        """
        # Compare against None so a legitimately cached 0 is not re-fetched on every access.
        if self._num_bundles is None:
            self._num_bundles = self._get_request_table_item()[RequestTableField.NUM_BUNDLES.value]
        return self._num_bundles

    @property
    def num_bundles_interval(self) -> str:
        """
        Returns the interval string that num_bundles corresponds to.
        :return: the interval string e.g. "0-499"
        """
        interval_size = 500

        # int() truncates towards zero; kept as-is so existing results do not change.
        index = int(self.num_bundles / interval_size)
        lower_bound = index * interval_size
        return f"{lower_bound}-{lower_bound + interval_size - 1}"

    @property
    def format(self) -> str:
        """
        The request's user specified output file format of the resultant expression matrix.
        Fetched from the request table on first access and cached afterwards.
        :return: str The file format (one of MatrixFormat)
        """
        if self._format is None:
            self._format = self._get_request_table_item()[RequestTableField.FORMAT.value]
        return self._format

    @property
    def batch_job_id(self) -> str:
        """
        The batch job id for matrix conversion corresponding with a request.
        :return: str The batch job id, or None if it is missing, empty or still "N/A"
        """
        batch_job_id = self._get_request_table_item().get(RequestTableField.BATCH_JOB_ID.value)
        if not batch_job_id or batch_job_id == "N/A":
            return None
        return batch_job_id

    @property
    def batch_job_status(self) -> str:
        """
        The batch job status for matrix conversion corresponding with a request.
        :return: str The batch job status, or None if no batch job id is recorded
        """
        # Read the id once: every access to the property performs a table lookup.
        batch_job_id = self.batch_job_id
        if batch_job_id:
            return self.batch_handler.get_batch_job_status(batch_job_id)
        return None

    @property
    def creation_date(self) -> str:
        """
        The creation date of matrix service request.
        :return: str creation date
        """
        return self._get_request_table_item()[RequestTableField.CREATION_DATE.value]

    @property
    def timeout(self) -> bool:
        """
        Whether the request was created more than 12 hours ago.
        As a side effect, a timed out request has the timeout error logged via log_error.
        :return: bool True if the request has timed out, else False
        """
        timeout = date.to_datetime(
            self.creation_date) < date.get_datetime_now() - timedelta(hours=12)
        if timeout:
            # The two sentences need an explicit separating space between the adjacent literals.
            self.log_error(
                "This request has timed out after 12 hours. "
                "Please try again by resubmitting the POST request.")
        return timeout

    @property
    def error(self) -> str:
        """
        The user-friendly message describing the latest error the request raised.
        :return: str The error message if one exists, else empty string
        """
        error = self._get_request_table_item()[RequestTableField.ERROR_MESSAGE.value]
        return error if error else ""

    def initialize_request(self, fmt: str) -> None:
        """Initialize the request id in the request state table. Put request metric to cloudwatch.
        :param fmt: Request output format for matrix conversion
        """
        self.dynamo_handler.create_request_table_entry(self.request_id, fmt)
        self.cloudwatch_handler.put_metric_data(metric_name=MetricName.REQUEST,
                                                metric_value=1)

    def expect_subtask_execution(self, subtask: Subtask):
        """
        Expect the execution of 1 Subtask by tracking it in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.
        :param subtask: The expected Subtask to be executed (DRIVER or CONVERTER).
        :raises KeyError: if the subtask has no expected-executions field mapped here
        """
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.EXPECTED_DRIVER_EXECUTIONS,
            Subtask.CONVERTER: RequestTableField.EXPECTED_CONVERTER_EXECUTIONS,
        }

        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def complete_subtask_execution(self, subtask: Subtask):
        """
        Counts the completed execution of 1 Subtask in DynamoDB.
        A Subtask is executed either by a Lambda or AWS Batch.
        :param subtask: The executed Subtask (DRIVER, QUERY or CONVERTER).
        :raises KeyError: if the subtask has no completed-executions field mapped here
        """
        subtask_to_dynamo_field_name = {
            Subtask.DRIVER: RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
            Subtask.QUERY: RequestTableField.COMPLETED_QUERY_EXECUTIONS,
            Subtask.CONVERTER: RequestTableField.COMPLETED_CONVERTER_EXECUTIONS,
        }

        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE, self.request_id,
            subtask_to_dynamo_field_name[subtask], 1)

    def is_request_complete(self) -> bool:
        """
        Checks whether the request has completed,
        i.e. if all expected queries and converters have completed.
        :return: bool True if complete, else False
        """
        request_state = self._get_request_table_item()
        queries_complete = (
            request_state[RequestTableField.EXPECTED_QUERY_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_QUERY_EXECUTIONS.value])
        converter_complete = (
            request_state[RequestTableField.EXPECTED_CONVERTER_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_CONVERTER_EXECUTIONS.value])

        return queries_complete and converter_complete

    def is_request_ready_for_conversion(self) -> bool:
        """
        Checks whether the request has completed all queries
        and is ready for conversion
        :return: bool True if all expected queries have completed, else False
        """
        request_state = self._get_request_table_item()
        return (
            request_state[RequestTableField.EXPECTED_QUERY_EXECUTIONS.value] ==
            request_state[RequestTableField.COMPLETED_QUERY_EXECUTIONS.value])

    def complete_request(self, duration: float):
        """
        Log the completion of a matrix request in CloudWatch Metrics
        :param duration: The time in seconds the request took to complete
        """
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.CONVERSION_COMPLETION, metric_value=1)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_COMPLETION, metric_value=1)

        bundles_dimension = {
            'Name': "Number of Bundles",
            'Value': self.num_bundles_interval
        }
        format_dimension = {
            'Name': "Output Format",
            'Value': self.format
        }
        # The duration is published twice: once per (bundles, format) pair and once per bundles only.
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[bundles_dimension, format_dimension])
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.DURATION,
            metric_value=duration,
            metric_dimensions=[dict(bundles_dimension)])

    def log_error(self, message: str):
        """
        Logs the latest error this request reported overwriting the previously logged error.
        Also publishes a request error metric to CloudWatch.
        :param message: str The error message to log
        """
        logger.debug(message)
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.ERROR_MESSAGE, message)
        self.cloudwatch_handler.put_metric_data(
            metric_name=MetricName.REQUEST_ERROR, metric_value=1)

    def write_batch_job_id_to_db(self, batch_job_id: str):
        """
        Logs the batch job id for matrix conversion to state table
        :param batch_job_id: str The batch job id to store for this request
        """
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id,
            RequestTableField.BATCH_JOB_ID, batch_job_id)
コード例 #7
0
class TestRequestTracker(MatrixTestCaseUsingMockAWS):
    @mock.patch("matrix.common.date.get_datetime_now")
    def setUp(self, mock_get_datetime_now):
        # Shared fixture: a request tracker for a fresh request id, a Dynamo
        # handler, a test request table and one entry in it for that request.
        super().setUp()

        # get_datetime_now is patched for the duration of setUp so that it
        # returns a fixed string; test_creation_date compares against it.
        stub_date = '2019-03-18T180907.136216Z'
        self.stub_date = stub_date
        mock_get_datetime_now.return_value = stub_date

        request_id = str(uuid.uuid4())
        self.request_id = request_id
        self.request_tracker = RequestTracker(request_id)
        self.dynamo_handler = DynamoHandler()

        self.create_test_request_table()

        self.dynamo_handler.create_request_table_entry(request_id, "test_format")

    def test_format(self):
        # setUp creates the request entry with the format "test_format".
        tracker = self.request_tracker
        self.assertEqual(tracker.format, "test_format")

    def test_batch_job_id(self):
        # Nothing has written a batch job id for the request created in
        # setUp, so the tracker is expected to report None. assertIsNone
        # checks identity and yields a clearer failure message than
        # assertEqual(..., None).
        self.assertIsNone(self.request_tracker.batch_job_id)

        # After the id is stored in the request table the tracker returns it.
        field_enum = RequestTableField.BATCH_JOB_ID
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE, self.request_id, field_enum, "123-123")
        self.assertEqual(self.request_tracker.batch_job_id, "123-123")

    @mock.patch(
        "matrix.common.aws.batch_handler.BatchHandler.get_batch_job_status")
    def test_batch_job_status(self, mock_get_job_status):
        # The batch handler's status lookup is mocked to report "FAILED".
        mock_get_job_status.return_value = "FAILED"

        # Store a batch job id on the request before reading the status.
        self.dynamo_handler.set_table_field_with_value(
            DynamoTable.REQUEST_TABLE,
            self.request_id,
            RequestTableField.BATCH_JOB_ID,
            "123-123",
        )

        status = self.request_tracker.batch_job_status
        self.assertEqual(status, "FAILED")

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.num_bundles",
        new_callable=mock.PropertyMock)
    def test_num_bundles_interval(self, mock_num_bundles):
        # (mocked num_bundles value, expected interval label), checked in order.
        cases = (
            (0, "0-499"),
            (1, "0-499"),
            (500, "500-999"),
            (1234, "1000-1499"),
        )

        for num_bundles, expected_interval in cases:
            mock_num_bundles.return_value = num_bundles
            self.assertEqual(self.request_tracker.num_bundles_interval,
                             expected_interval)

    def test_creation_date(self):
        # The request entry was created in setUp while get_datetime_now was
        # patched to return stub_date.
        creation_date = self.request_tracker.creation_date
        self.assertEqual(creation_date, self.stub_date)

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    def test_error(self, mock_cw_put):
        tracker = self.request_tracker

        # A freshly created request reports an empty error message.
        self.assertEqual(tracker.error, "")

        # After log_error the message is readable back from the tracker and
        # exactly one REQUEST_ERROR metric has been put.
        tracker.log_error("test error")
        self.assertEqual(tracker.error, "test error")
        mock_cw_put.assert_called_once_with(metric_name=MetricName.REQUEST_ERROR,
                                            metric_value=1)

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.create_request_table_entry"
    )
    def test_initialize_request(self, mock_create_request_table_entry,
                                mock_create_cw_metric):
        # Both the table-entry creation and the metric call are mocked, so
        # only the calls made by initialize_request are observed here.
        request_format = "test_format"
        self.request_tracker.initialize_request(request_format)

        # One table entry for this request id/format, and one metric put.
        mock_create_request_table_entry.assert_called_once_with(self.request_id,
                                                                request_format)
        mock_create_cw_metric.assert_called_once()

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_expect_subtask_execution(self, mock_increment_table_field):
        self.request_tracker.expect_subtask_execution(Subtask.DRIVER)

        # Expecting a DRIVER subtask must bump the expected-driver counter
        # of this request by one, exactly once.
        expected_args = (
            DynamoTable.REQUEST_TABLE,
            self.request_id,
            RequestTableField.EXPECTED_DRIVER_EXECUTIONS,
            1,
        )
        mock_increment_table_field.assert_called_once_with(*expected_args)

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.increment_table_field")
    def test_complete_subtask_execution(self, mock_increment_table_field):
        self.request_tracker.complete_subtask_execution(Subtask.DRIVER)

        # Completing a DRIVER subtask must bump the completed-driver counter
        # of this request by one, exactly once.
        expected_args = (
            DynamoTable.REQUEST_TABLE,
            self.request_id,
            RequestTableField.COMPLETED_DRIVER_EXECUTIONS,
            1,
        )
        mock_increment_table_field.assert_called_once_with(*expected_args)

    def test_is_request_complete(self):
        tracker = self.request_tracker

        # Nothing has been recorded as completed yet.
        self.assertFalse(tracker.is_request_complete())

        # Record one completed converter execution and three completed query
        # executions, in that order, then expect the request to be complete.
        completed_counts = (
            (RequestTableField.COMPLETED_CONVERTER_EXECUTIONS, 1),
            (RequestTableField.COMPLETED_QUERY_EXECUTIONS, 3),
        )
        for field, count in completed_counts:
            self.dynamo_handler.increment_table_field(
                DynamoTable.REQUEST_TABLE, self.request_id, field, count)

        self.assertTrue(tracker.is_request_complete())

    def test_is_request_ready_for_conversion(self):
        tracker = self.request_tracker

        # Not ready while no query executions are recorded as completed.
        self.assertFalse(tracker.is_request_ready_for_conversion())

        # Ready once three completed query executions have been recorded.
        self.dynamo_handler.increment_table_field(
            DynamoTable.REQUEST_TABLE,
            self.request_id,
            RequestTableField.COMPLETED_QUERY_EXECUTIONS,
            3,
        )
        self.assertTrue(tracker.is_request_ready_for_conversion())

    @mock.patch(
        "matrix.common.aws.cloudwatch_handler.CloudwatchHandler.put_metric_data"
    )
    def test_complete_request(self, mock_cw_put):
        duration = 1
        self.request_tracker.complete_request(duration)

        # Dimension values are not pinned by this test, only their names.
        bundles_dimension = {'Name': "Number of Bundles", 'Value': mock.ANY}
        format_dimension = {'Name': "Output Format", 'Value': mock.ANY}

        # Two completion counters, then the duration metric twice: once with
        # both dimensions and once with the bundle-count dimension only.
        mock_cw_put.assert_has_calls([
            mock.call(metric_name=MetricName.CONVERSION_COMPLETION,
                      metric_value=1),
            mock.call(metric_name=MetricName.REQUEST_COMPLETION,
                      metric_value=1),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[bundles_dimension, format_dimension]),
            mock.call(metric_name=MetricName.DURATION,
                      metric_value=duration,
                      metric_dimensions=[bundles_dimension]),
        ])

    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.log_error")
    @mock.patch(
        "matrix.common.request.request_tracker.RequestTracker.creation_date",
        new_callable=mock.PropertyMock)
    def test_timeout(self, mock_creation_date, mock_log_error):
        def created_ago(**elapsed):
            # Creation-date string for a request created `elapsed` before now.
            return date.to_string(date.get_datetime_now() - timedelta(**elapsed))

        # no timeout: created 11h59m ago
        mock_creation_date.return_value = created_ago(hours=11, minutes=59)
        self.assertFalse(self.request_tracker.timeout)

        # timeout: created 12h01m ago, and the error is logged exactly once
        mock_creation_date.return_value = created_ago(hours=12, minutes=1)
        self.assertTrue(self.request_tracker.timeout)
        mock_log_error.assert_called_once()

    @mock.patch(
        "matrix.common.aws.dynamo_handler.DynamoHandler.set_table_field_with_value"
    )
    def test_write_batch_job_id_to_db(self, mock_set_table_field_with_value):
        batch_job_id = "123-123"
        self.request_tracker.write_batch_job_id_to_db(batch_job_id)

        # The id must be written to this request's BATCH_JOB_ID field once.
        expected_args = (
            DynamoTable.REQUEST_TABLE,
            self.request_id,
            RequestTableField.BATCH_JOB_ID,
            batch_job_id,
        )
        mock_set_table_field_with_value.assert_called_once_with(*expected_args)