Example #1
0
    def test_sensor(self, mock_describe_job, hook_init, mock_client):
        """Sensor keeps polling describe until the job reaches its final response.

        The describe mock yields in-progress, stopping and completed responses
        in turn; the sensor is expected to consume all three and then stop.
        """
        # hook_init is presumably a patched hook __init__, which must return None.
        hook_init.return_value = None

        # Responses handed back by successive describe calls, in poll order.
        # NOTE(review): spelled COMPELETED here but COMPLETED in the other test
        # of this snippet -- confirm which name the module actually defines.
        poll_responses = [
            DESCRIBE_TRAINING_INPROGRESS_RESPONSE,
            DESCRIBE_TRAINING_STOPPING_RESPONSE,
            DESCRIBE_TRAINING_COMPELETED_RESPONSE,
        ]
        mock_describe_job.side_effect = poll_responses

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=False,
        )
        sensor.execute(None)

        # One describe call per queued response: polling ends once completed.
        self.assertEqual(mock_describe_job.call_count, len(poll_responses))

        # Each hook construction must receive the configured connection id.
        expected_hook_calls = [mock.call(aws_conn_id='aws_test')] * 3
        hook_init.assert_has_calls(expected_hook_calls)

    def test_sensor_with_log(self, mock_describe_job,
                             mock_describe_job_with_log, hook_init,
                             mock_log_client, mock_client):
        """With print_log=True, polling goes through the log-aware describe call.

        The log-aware mock walks through WAIT_IN_PROGRESS, JOB_COMPLETE and
        COMPLETE states. The test checks that it is called once per state and
        that the plain describe call is made exactly once.
        """
        # hook_init is presumably a patched hook __init__, which must return None.
        hook_init.return_value = None

        mock_describe_job.return_value = DESCRIBE_TRAINING_COMPLETED_RESPONSE

        # Each tuple is (LogState, describe response, 0) and is consumed by one poll.
        log_poll_results = (
            (LogState.WAIT_IN_PROGRESS,
             DESCRIBE_TRAINING_INPROGRESS_RESPONSE, 0),
            (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RESPONSE, 0),
            (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RESPONSE, 0),
        )
        mock_describe_job_with_log.side_effect = list(log_poll_results)

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=True,
        )
        sensor.execute(None)

        self.assertEqual(mock_describe_job_with_log.call_count,
                         len(log_poll_results))
        self.assertEqual(mock_describe_job.call_count, 1)

        hook_init.assert_has_calls([mock.call(aws_conn_id='aws_test')])
Example #3
0
    def test_sensor_with_failure(self, mock_describe_job, hook_init, mock_client):
        """Sensor raises AirflowException on the failed-training describe response."""
        # hook_init is presumably a patched hook __init__, which must return None.
        hook_init.return_value = None
        # Only a single response is queued, so a second describe call would error.
        mock_describe_job.side_effect = [DESCRIBE_TRAINING_FAILED_RESPONSE]

        sensor_kwargs = dict(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=False,
        )
        sensor = SageMakerTrainingSensor(**sensor_kwargs)

        with pytest.raises(AirflowException):
            sensor.execute(None)

        # The job must be described exactly once, by name, before the failure surfaces.
        mock_describe_job.assert_called_once_with('test_job_name')