def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, mock_check_training):
    """Create a training job with ``print_log=True`` and wait for it to finish.

    ``mock_describe`` yields three polling results whose LogState progresses
    WAIT_IN_PROGRESS -> JOB_COMPLETE -> COMPLETE. The test checks that it is
    consulted once per result and that the SageMaker client's
    ``describe_training_job`` is called exactly once.
    """
    # NOTE(review): presumably mock_check_training patches the hook's training-config
    # check; returning True lets create_training_job proceed — confirm against the decorator.
    mock_check_training.return_value = True
    # One tuple per poll: (LogState, describe response, 0). The meaning of the third
    # element is defined by the patched helper, which is not visible here.
    mock_describe.side_effect = [
        (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
        (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
        (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RETURN, 0),
    ]

    # SageMaker client stub: creation returns the canned ARN response, describe
    # always reports the completed job.
    mock_session = mock.Mock()
    attrs = {
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN,
    }
    mock_session.configure_mock(**attrs)
    mock_client.return_value = mock_session

    # Logs client stub: successive calls consume the canned stream listings and events.
    mock_log_session = mock.Mock()
    log_attrs = {
        'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
        'get_log_events.side_effect': STREAM_LOG_EVENTS,
    }
    mock_log_session.configure_mock(**log_attrs)
    mock_log_client.return_value = mock_log_session

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    hook.create_training_job(
        create_training_params,
        wait_for_completion=True,
        print_log=True,
        check_interval=1,
    )

    # Plain asserts, consistent with the rest of the class; these do not rely on
    # unittest.TestCase helpers being available on ``self``.
    assert mock_describe.call_count == 3
    assert mock_session.describe_training_job.call_count == 1
def test_training_throws_error_when_failed_with_wait(self, mock_client, mock_check_training):
    """A failed training job surfaces as AirflowException while waiting.

    Successive describe responses go in-progress, stopping, failed, completed.
    The hook raises on the failed response, so the trailing completed response
    is never consumed and ``describe_training_job`` is called exactly three times.
    """
    # Responses handed out, in order, by successive describe_training_job calls.
    describe_sequence = [
        DESCRIBE_TRAINING_INPROGRESS_RETURN,
        DESCRIBE_TRAINING_STOPPING_RETURN,
        DESCRIBE_TRAINING_FAILED_RETURN,
        DESCRIBE_TRAINING_COMPLETED_RETURN,
    ]

    sagemaker_client = mock.Mock()
    sagemaker_client.create_training_job.return_value = test_arn_return
    sagemaker_client.describe_training_job.side_effect = describe_sequence
    mock_client.return_value = sagemaker_client
    # NOTE(review): presumably bypasses the hook's training-config check — confirm
    # against the patch decorator.
    mock_check_training.return_value = True

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    with pytest.raises(AirflowException):
        hook.create_training_job(
            create_training_params,
            check_interval=1,
            print_log=False,
            wait_for_completion=True,
        )

    assert sagemaker_client.describe_training_job.call_count == 3
def test_training_ends_with_wait(self, mock_client, mock_check_training):
    """Waiting on a training job polls until the completed response is seen.

    Successive describe responses go in-progress, stopping, completed,
    completed. The test asserts that all four are consumed, i.e.
    ``describe_training_job`` is called exactly four times.
    """
    # NOTE(review): presumably bypasses the hook's training-config check — confirm
    # against the patch decorator.
    mock_check_training.return_value = True
    mock_session = mock.Mock()
    # Uses DESCRIBE_TRAINING_COMPLETED_RETURN, the name the sibling tests use; the
    # previous misspelling (COMPELETED) does not match that constant.
    attrs = {
        'create_training_job.return_value': test_arn_return,
        'describe_training_job.side_effect': [
            DESCRIBE_TRAINING_INPROGRESS_RETURN,
            DESCRIBE_TRAINING_STOPPING_RETURN,
            DESCRIBE_TRAINING_COMPLETED_RETURN,
            DESCRIBE_TRAINING_COMPLETED_RETURN,
        ],
    }
    mock_session.configure_mock(**attrs)
    mock_client.return_value = mock_session

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
    hook.create_training_job(
        create_training_params,
        wait_for_completion=True,
        print_log=False,
        check_interval=1,
    )

    # Plain assert, consistent with the rest of the class (no unittest.TestCase helper).
    assert mock_session.describe_training_job.call_count == 4
def test_create_training_job(self, mock_client, mock_check_training):
    """Without waiting, the hook forwards the params and returns the client's response.

    ``create_training_params`` must be passed to the client's
    ``create_training_job`` as keyword arguments exactly once, and the client's
    return value must come back unchanged.
    """
    sagemaker_client = mock.Mock()
    sagemaker_client.create_training_job.return_value = test_arn_return
    mock_client.return_value = sagemaker_client
    # NOTE(review): presumably bypasses the hook's training-config check — confirm
    # against the patch decorator.
    mock_check_training.return_value = True

    hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
    result = hook.create_training_job(
        create_training_params,
        print_log=False,
        wait_for_completion=False,
    )

    sagemaker_client.create_training_job.assert_called_once_with(**create_training_params)
    assert result == test_arn_return