def test_sensor(self, mock_describe_job, hook_init, mock_client):
        """Poll a training job through to completion with log printing off.

        The mocked describe call yields in-progress, stopping and completed
        responses in turn; the test expects exactly three describe calls.
        """
        # hook_init presumably patches the hook's __init__, which must return None.
        hook_init.return_value = None
        responses = [
            DESCRIBE_TRAINING_INPROGRESS_RESPONSE,
            DESCRIBE_TRAINING_STOPPING_RESPONSE,
            DESCRIBE_TRAINING_COMPELETED_RESPONSE,
        ]
        mock_describe_job.side_effect = responses

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=False,
        )
        sensor.execute(None)

        # One describe call per poke; polling ends once the job is completed.
        self.assertEqual(mock_describe_job.call_count, 3)

        # The hook must have been constructed with the configured connection
        # id on three consecutive occasions.
        expected_init = mock.call(aws_conn_id='aws_test')
        hook_init.assert_has_calls([expected_init] * 3)
    def test_sensor_with_log(self, mock_describe_job, mock_describe_job_with_log,
                             hook_init, mock_log_client, mock_client):
        """Poll a training job to completion with log printing enabled.

        With print_log=True the log-aware describe mock is polled three times
        (in progress, job complete, complete), and the plain describe mock is
        expected to be called exactly once.
        """
        # hook_init presumably patches the hook's __init__, which must return None.
        hook_init.return_value = None

        mock_describe_job.return_value = DESCRIBE_TRAINING_COMPELETED_RESPONSE
        log_polls = [
            (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RESPONSE, 0),
            (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RESPONSE, 0),
            (LogState.COMPLETE, DESCRIBE_TRAINING_COMPELETED_RESPONSE, 0),
        ]
        mock_describe_job_with_log.side_effect = log_polls

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=True,
        )
        sensor.execute(None)

        self.assertEqual(mock_describe_job_with_log.call_count, 3)
        self.assertEqual(mock_describe_job.call_count, 1)

        # Three consecutive hook constructions with the configured connection id.
        expected_init = mock.call(aws_conn_id='aws_test')
        hook_init.assert_has_calls([expected_init] * 3)
Exemplo n.º 3
0
 def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
     """A failed training job makes execute() raise AirflowException."""
     mock_describe_job.side_effect = [DESCRIBE_TRAINING_FAILED_RETURN]
     sensor = SageMakerTrainingSensor(
         task_id='test_task',
         poke_interval=2,
         aws_conn_id='aws_test',
         job_name='test_job_name',
     )

     with self.assertRaises(AirflowException):
         sensor.execute(None)

     # The job is described exactly once, by name, before the sensor raises.
     mock_describe_job.assert_called_once_with('test_job_name')
    def test_sensor_with_failure(self, mock_describe_job, hook_init,
                                 mock_client):
        """A failed training job makes execute() raise (log printing off)."""
        # hook_init presumably patches the hook's __init__, which must return None.
        hook_init.return_value = None
        mock_describe_job.side_effect = [DESCRIBE_TRAINING_FAILED_RESPONSE]

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=False,
        )
        with self.assertRaises(AirflowException):
            sensor.execute(None)

        # Only a single describe call, by job name, precedes the failure.
        mock_describe_job.assert_called_once_with('test_job_name')
    def test_sensor(self, mock_describe_job, hook_init, mock_client):
        """Poll a training job through to completion with log printing off.

        The describe mock returns in-progress, stopping, then completed; the
        test expects three describe calls in total.
        """
        # hook_init presumably patches the hook's __init__, which must return None.
        hook_init.return_value = None
        responses = [
            DESCRIBE_TRAINING_INPROGRESS_RESPONSE,
            DESCRIBE_TRAINING_STOPPING_RESPONSE,
            DESCRIBE_TRAINING_COMPELETED_RESPONSE,
        ]
        mock_describe_job.side_effect = responses

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=False,
        )
        sensor.execute(None)

        # One describe call per poke; polling stops once the job is completed.
        self.assertEqual(mock_describe_job.call_count, 3)

        # The most recent hook construction used the configured connection id.
        hook_init.assert_called_with(aws_conn_id='aws_test')
    def test_sensor_with_log(self, mock_describe_job, mock_describe_job_with_log,
                             hook_init, mock_log_client, mock_client):
        """Poll a training job to completion with log printing enabled.

        The log-aware describe mock is polled three times, and the plain
        describe mock is expected to be called exactly once.
        """
        # hook_init presumably patches the hook's __init__, which must return None.
        hook_init.return_value = None

        mock_describe_job.return_value = DESCRIBE_TRAINING_COMPELETED_RESPONSE
        log_polls = [
            (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RESPONSE, 0),
            (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RESPONSE, 0),
            (LogState.COMPLETE, DESCRIBE_TRAINING_COMPELETED_RESPONSE, 0),
        ]
        mock_describe_job_with_log.side_effect = log_polls

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            print_log=True,
        )
        sensor.execute(None)

        self.assertEqual(mock_describe_job_with_log.call_count, 3)
        self.assertEqual(mock_describe_job.call_count, 1)

        # The most recent hook construction used the configured connection id.
        hook_init.assert_called_with(aws_conn_id='aws_test')
Exemplo n.º 7
0
    def test_calls_until_a_terminal_state(self, mock_describe_job, hook_init,
                                          mock_client):
        """The sensor keeps polling until it sees the completed response.

        Four responses are queued (in progress, stopping, stopped, completed);
        the test expects all four to be consumed.
        """
        # hook_init presumably patches the hook's __init__, which must return None.
        hook_init.return_value = None
        responses = [
            DESCRIBE_TRAINING_INPROGRESS_RETURN,
            DESCRIBE_TRAINING_STOPPING_RETURN,
            DESCRIBE_TRAINING_STOPPED_RETURN,
            DESCRIBE_TRAINING_COMPELETED_RETURN,
        ]
        mock_describe_job.side_effect = responses

        sensor = SageMakerTrainingSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            region_name='us-east-1',
        )
        sensor.execute(None)

        # Four describe calls: polling ends only at the completed response.
        self.assertEqual(mock_describe_job.call_count, 4)

        # The hook was last constructed with both connection id and region.
        hook_init.assert_called_with(
            aws_conn_id='aws_test',
            region_name='us-east-1',
        )