Beispiel #1
0
    def test_main(self):
        """Run ``train.main`` with ``_utils`` and ``open`` mocked; check outputs.

        Every attribute of ``train._utils`` is a ``MagicMock`` except
        ``add_default_client_arguments``, which is the real function so the
        argument parser is still built normally. ``builtins.open`` is replaced
        with ``mock_open`` so no files are really written.

        Both patches are scoped: the real ``train._utils`` is put back when the
        ``with`` block exits, so the mock does not leak into later tests.
        """
        # Mock out all of utils except parser. Unlike assigning
        # ``train._utils = MagicMock()`` directly, patch.object restores the
        # original module attribute on exit.
        with patch.object(train, '_utils') as mock_utils:
            mock_utils.add_default_client_arguments = _utils.add_default_client_arguments

            # Set some static returns
            mock_utils.create_training_job.return_value = 'job-name'
            mock_utils.get_image_from_job.return_value = 'training-image'
            mock_utils.get_model_artifacts_from_job.return_value = 'model-artifacts'

            with patch('builtins.open', mock_open()) as file_open:
                train.main(required_args)

        # Check if correct requests were created and triggered
        mock_utils.create_training_job.assert_called()
        mock_utils.wait_for_training_job.assert_called()

        # Check the file outputs; the order in which the files are opened is
        # deliberately not asserted.
        file_open.assert_has_calls(
            [
                call('/tmp/model_artifact_url.txt', 'w'),
                call('/tmp/job_name.txt', 'w'),
                call('/tmp/training_image.txt', 'w'),
            ],
            any_order=True)

        # Writes must be in the same order as called. ``return_value`` is the
        # file handle mock_open hands out; calling ``file_open()`` here instead
        # would record an extra, spurious open() call on the mock.
        file_open.return_value.write.assert_has_calls(
            [
                call('model-artifacts'),
                call('job-name'),
                call('training-image'),
            ],
            any_order=False)
Beispiel #2
0
  def test_main_assumes_role(self):
    """Check that ``--assume_role`` is forwarded to the SageMaker client.

    Every attribute of ``train._utils`` is a ``MagicMock`` except
    ``add_default_client_arguments``, which stays real so the argument
    parser is built normally. The real ``train._utils`` is restored when the
    ``with`` block exits, so the mock does not leak into later tests.
    """
    # Local import so this test does not depend on the file-level imports.
    from unittest.mock import patch

    # Mock out all of utils except parser
    with patch.object(train, '_utils') as mock_utils:
      mock_utils.add_default_client_arguments = _utils.add_default_client_arguments

      # Set some static returns
      mock_utils.create_training_job.return_value = 'job-name'
      mock_utils.get_image_from_job.return_value = 'training-image'
      mock_utils.get_model_artifacts_from_job.return_value = 'model-artifacts'

      assume_role_args = required_args + ['--assume_role', 'my-role']

      train.main(assume_role_args)

    # 'us-west-2' presumably matches the region in required_args -- confirm.
    mock_utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
Beispiel #3
0
  def test_main(self):
    """Run ``train.main`` with ``_utils`` mocked and check calls and outputs.

    Every attribute of ``train._utils`` is a ``MagicMock`` except
    ``add_default_client_arguments``, which stays real so the argument
    parser is built normally. The real ``train._utils`` is restored when the
    ``with`` block exits, so the mock does not leak into later tests.
    """
    # Local import so this test does not depend on the file-level imports.
    from unittest.mock import patch

    # Mock out all of utils except parser
    with patch.object(train, '_utils') as mock_utils:
      mock_utils.add_default_client_arguments = _utils.add_default_client_arguments

      # Set some static returns
      mock_utils.create_training_job.return_value = 'job-name'
      mock_utils.get_image_from_job.return_value = 'training-image'
      mock_utils.get_model_artifacts_from_job.return_value = 'model-artifacts'

      train.main(required_args)

    # Check if correct requests were created and triggered
    mock_utils.create_training_job.assert_called()
    mock_utils.wait_for_training_job.assert_called()
    mock_utils.print_logs_for_job.assert_called()

    # Check the file outputs: write_output must receive these (path, value)
    # pairs consecutively and in this order. The paths presumably come from
    # required_args -- confirm.
    mock_utils.write_output.assert_has_calls([
      call('/tmp/model_artifact_url_output_path', 'model-artifacts'),
      call('/tmp/job_name_output_path', 'job-name'),
      call('/tmp/training_image_output_path', 'training-image')
    ])