Example #1
0
 def test_training_with_logs(self, mock_describe, mock_client,
                             mock_log_client, mock_check_training):
     """Wait for a training job with ``print_log=True``.

     The patched ``mock_describe`` helper is fed three successive
     ``(LogState, description, 0)`` tuples: in-progress, stopping (job
     complete) and completed. The test expects the helper to be consumed
     exactly three times. It also expects the SageMaker client's
     ``describe_training_job`` to be called only once.
     """
     mock_check_training.return_value = True
     mock_describe.side_effect = [
         (LogState.WAIT_IN_PROGRESS, DESCRIBE_TRAINING_INPROGRESS_RETURN, 0),
         (LogState.JOB_COMPLETE, DESCRIBE_TRAINING_STOPPING_RETURN, 0),
         (LogState.COMPLETE, DESCRIBE_TRAINING_COMPLETED_RETURN, 0),
     ]

     # Stand-in for the SageMaker client returned by the patched factory.
     sagemaker_client = mock.Mock()
     sagemaker_client.configure_mock(**{
         'create_training_job.return_value': test_arn_return,
         'describe_training_job.return_value': DESCRIBE_TRAINING_COMPLETED_RETURN,
     })
     # Stand-in for the log client; streams and events are replayed in order.
     logs_client = mock.Mock()
     logs_client.configure_mock(**{
         'describe_log_streams.side_effect': LIFECYCLE_LOG_STREAMS,
         'get_log_events.side_effect': STREAM_LOG_EVENTS,
     })
     mock_client.return_value = sagemaker_client
     mock_log_client.return_value = logs_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     hook.create_training_job(
         create_training_params,
         wait_for_completion=True,
         print_log=True,
         check_interval=1,
     )

     assert mock_describe.call_count == 3
     assert sagemaker_client.describe_training_job.call_count == 1
Example #2
0
 def test_training_throws_error_when_failed_with_wait(
         self, mock_client, mock_check_training):
     """Waiting on a training job that fails raises ``AirflowException``.

     ``describe_training_job`` replays in-progress, stopping, failed and
     completed descriptions in that order. The test expects the exception
     after the third (failed) description, leaving the trailing completed
     description unused.
     """
     mock_check_training.return_value = True
     describe_sequence = [
         DESCRIBE_TRAINING_INPROGRESS_RETURN,
         DESCRIBE_TRAINING_STOPPING_RETURN,
         DESCRIBE_TRAINING_FAILED_RETURN,
         DESCRIBE_TRAINING_COMPLETED_RETURN,
     ]

     sagemaker_client = mock.Mock()
     sagemaker_client.create_training_job.return_value = test_arn_return
     sagemaker_client.describe_training_job.side_effect = describe_sequence
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     with pytest.raises(AirflowException):
         hook.create_training_job(
             create_training_params,
             wait_for_completion=True,
             print_log=False,
             check_interval=1,
         )

     # Polling must stop once the failed description has been seen.
     assert sagemaker_client.describe_training_job.call_count == 3
Example #3
0
 def test_training_ends_with_wait(self, mock_client, mock_check_training):
     """Wait for a training job that eventually reports completion.

     ``describe_training_job`` replays in-progress, stopping and then the
     completed description twice. The test expects all four descriptions
     to be consumed.
     """
     mock_check_training.return_value = True
     # NOTE(review): `COMPELETED` is spelled differently from the
     # DESCRIBE_TRAINING_COMPLETED_RETURN used by other tests in this
     # listing; confirm which name the enclosing module defines.
     describe_sequence = [
         DESCRIBE_TRAINING_INPROGRESS_RETURN,
         DESCRIBE_TRAINING_STOPPING_RETURN,
         DESCRIBE_TRAINING_COMPELETED_RETURN,
         DESCRIBE_TRAINING_COMPELETED_RETURN,
     ]

     sagemaker_client = mock.Mock()
     sagemaker_client.create_training_job.return_value = test_arn_return
     sagemaker_client.describe_training_job.side_effect = describe_sequence
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id_1')
     hook.create_training_job(
         create_training_params,
         wait_for_completion=True,
         print_log=False,
         check_interval=1,
     )

     assert sagemaker_client.describe_training_job.call_count == 4
Example #4
0
 def test_create_training_job(self, mock_client, mock_check_training):
     """Create a training job without waiting or log printing.

     The params must be forwarded unchanged to the client's
     ``create_training_job``, and the client's response returned as-is.
     """
     mock_check_training.return_value = True
     sagemaker_client = mock.Mock()
     sagemaker_client.create_training_job.return_value = test_arn_return
     mock_client.return_value = sagemaker_client

     hook = SageMakerHook(aws_conn_id='sagemaker_test_conn_id')
     result = hook.create_training_job(
         create_training_params,
         wait_for_completion=False,
         print_log=False,
     )

     sagemaker_client.create_training_job.assert_called_once_with(
         **create_training_params
     )
     assert result == test_arn_return