Ejemplo n.º 1
0
    def test_sensor(self, mock_describe_job, hook_init, mock_client):
        """Poll a mocked tuning job until it completes and check hook setup.

        The describe mock yields the in-progress, stopping and completed
        fixtures in turn; the sensor is expected to stop after the third.
        """
        hook_init.return_value = None

        describe_fixtures = [
            DESCRIBE_TUNING_INPROGRESS_RESPONSE,
            DESCRIBE_TUNING_STOPPING_RESPONSE,
            DESCRIBE_TUNING_COMPELETED_RESPONSE,
        ]
        mock_describe_job.side_effect = describe_fixtures

        tuning_sensor = SageMakerTuningSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
        )
        tuning_sensor.execute(None)

        # Three describe calls: polling ends once the completed fixture is seen.
        self.assertEqual(mock_describe_job.call_count, 3)

        # The hook constructor should have been invoked (three times) with
        # the sensor's connection id.
        expected_hook_calls = [mock.call(aws_conn_id='aws_test')] * 3
        hook_init.assert_has_calls(expected_hook_calls)

    def test_calls_until_a_terminal_state(
            self, mock_describe_job, hook_init, mock_client):
        """Keep poking the mocked tuning job until it reaches a terminal state.

        Four describe fixtures are queued (in progress, stopping, stopped,
        completed) and the test expects the sensor to consume all four.
        """
        hook_init.return_value = None

        mock_describe_job.side_effect = [
            DESCRIBE_TUNING_INPROGRESS_RETURN,
            DESCRIBE_TUNING_STOPPING_RETURN,
            DESCRIBE_TUNING_STOPPED_RETURN,
            DESCRIBE_TUNING_COMPELETED_RETURN,
        ]

        sensor_kwargs = dict(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            region_name='us-east-1',
        )
        tuning_sensor = SageMakerTuningSensor(**sensor_kwargs)
        tuning_sensor.execute(None)

        # Four describe calls: polling ends only at the completed fixture.
        self.assertEqual(mock_describe_job.call_count, 4)

        # The most recent hook construction must carry both the connection
        # id and the explicitly configured region.
        hook_init.assert_called_with(
            aws_conn_id='aws_test',
            region_name='us-east-1',
        )
Ejemplo n.º 3
0
 def test_sensor_with_failure(self, mock_describe_job, mock_client):
     """A failed tuning job makes the sensor raise after a single poke."""
     mock_describe_job.side_effect = [DESCRIBE_TUNING_FAILED_RESPONSE]

     failing_sensor = SageMakerTuningSensor(
         task_id='test_task',
         poke_interval=2,
         aws_conn_id='aws_test',
         job_name='test_job_name',
     )

     with self.assertRaises(AirflowException):
         failing_sensor.execute(None)

     # Exactly one describe call, made for the configured job name.
     mock_describe_job.assert_called_once_with('test_job_name')
Ejemplo n.º 4
0
    def test_calls_until_a_terminal_state(
            self, mock_describe_job, hook_init, mock_client):
        """Poll the mocked tuning job until it reports a terminal status.

        The test expects all four queued describe fixtures to be consumed and
        the hook to be built with the sensor's connection id and region.
        """
        hook_init.return_value = None

        conn_id = 'aws_test'
        region = 'us-east-1'

        mock_describe_job.side_effect = [
            DESCRIBE_TUNING_INPROGRESS_RETURN,
            DESCRIBE_TUNING_STOPPING_RETURN,
            DESCRIBE_TUNING_STOPPED_RETURN,
            DESCRIBE_TUNING_COMPELETED_RETURN,
        ]

        sensor = SageMakerTuningSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id=conn_id,
            job_name='test_job_name',
            region_name=region,
        )
        sensor.execute(None)

        # One describe call per poke; four polls before the completed fixture
        # stops the sensor.
        self.assertEqual(mock_describe_job.call_count, 4)

        # The last hook construction must reuse the same connection and region.
        hook_init.assert_called_with(aws_conn_id=conn_id, region_name=region)