Beispiel #1
0
    def test_sensor(self, mock_describe_job, hook_init, mock_client):
        """Run the tuning sensor through a full in-progress -> completed cycle.

        The describe mock yields three canned responses in order (in progress,
        stopping, completed). The test then checks that the job was described
        once per response and that the hook initializer received the
        connection id given to the sensor.

        NOTE(review): the mock arguments are presumably injected by patch
        decorators that sit above this method, outside the visible chunk.
        """
        # A patched initializer has to return None, like a real __init__.
        hook_init.return_value = None

        describe_responses = [
            DESCRIBE_TUNING_INPROGRESS_RESPONSE,
            DESCRIBE_TUNING_STOPPING_RESPONSE,
            DESCRIBE_TUNING_COMPLETED_RESPONSE,
        ]
        mock_describe_job.side_effect = describe_responses

        tuning_sensor = SageMakerTuningSensor(
            job_name='test_job_name',
            aws_conn_id='aws_test',
            poke_interval=2,
            task_id='test_task',
        )
        tuning_sensor.execute(None)

        # One describe call per canned response: the sensor stops as soon as
        # the completed response has been consumed.
        self.assertEqual(mock_describe_job.call_count, len(describe_responses))

        # The hook must have been initialized with the sensor's connection id
        # on every one of those rounds.
        expected_hook_calls = [
            mock.call(aws_conn_id='aws_test') for _ in describe_responses
        ]
        hook_init.assert_has_calls(expected_hook_calls)
Beispiel #2
0
 def test_sensor_with_failure(self, mock_describe_job, mock_client):
     """Check that a failed tuning job makes the sensor raise.

     The describe mock returns a single failed-job response; executing the
     sensor is expected to raise ``AirflowException`` after describing the
     job exactly once, by name.
     """
     mock_describe_job.side_effect = [DESCRIBE_TUNING_FAILED_RESPONSE]

     sensor_kwargs = {
         'task_id': 'test_task',
         'poke_interval': 2,
         'aws_conn_id': 'aws_test',
         'job_name': 'test_job_name',
     }
     failing_sensor = SageMakerTuningSensor(**sensor_kwargs)

     with pytest.raises(AirflowException):
         failing_sensor.execute(None)

     # Only one describe call is expected: the first (failed) response ends
     # the run.
     mock_describe_job.assert_called_once_with('test_job_name')