예제 #1
0
    def test_uses_overrides_s3_location(
        self,
        mock_from_s3: MagicMock,
        mock_launch_cluster: MagicMock,
        mock_retrieve_secrets: MagicMock,
    ):
        """Config YAMLs are fetched from the bucket/folder given in ``s3_overrides``.

        With ``EMR_LAUNCHER_CONFIG_DIR`` unset, invoking the handler with
        ``s3_overrides`` must read ``cluster``, ``configurations``,
        ``instances`` and ``steps`` YAML files from the overridden S3 location
        (in any order) and launch exactly one cluster.
        """
        # Local import keeps this fix self-contained within the method.
        from unittest import mock

        calls = [
            call(bucket="Test_S3_Bucket", key="Test_S3_Folder/cluster.yaml"),
            call(bucket="Test_S3_Bucket", key="Test_S3_Folder/configurations.yaml"),
            call(bucket="Test_S3_Bucket", key="Test_S3_Folder/instances.yaml"),
            call(bucket="Test_S3_Bucket", key="Test_S3_Folder/steps.yaml"),
        ]
        mock_retrieve_secrets.side_effect = mock_retrieve_secrets_side_effect

        s3_overrides = {
            "emr_launcher_config_s3_bucket": "Test_S3_Bucket",
            "emr_launcher_config_s3_folder": "Test_S3_Folder",
        }

        # patch.dict snapshots os.environ and restores it on exit, so removing
        # the local config dir here does not leak into other tests.
        with mock.patch.dict(os.environ):
            # The test relies on no local config dir being configured;
            # presumably the handler would otherwise read from disk.
            os.environ.pop("EMR_LAUNCHER_CONFIG_DIR", None)

            handler({"s3_overrides": s3_overrides})

        mock_launch_cluster.assert_called_once()
        mock_from_s3.assert_has_calls(calls, any_order=True)
예제 #2
0
    def test_handlers_same_result(
        self,
        mock_tag_cluster: MagicMock,
        mock_launch_cluster: MagicMock,
        mock_retrieve_secrets: MagicMock,
    ):
        """Top-level legacy keys are rejected; ``additional_step_args`` launches.

        An event with ``correlation_id``/``s3_prefix`` at the top level must
        raise ``ValueError`` without calling the launcher, while the same
        values passed as ``submit-job`` step arguments launch one cluster.
        """
        mock_retrieve_secrets.side_effect = mock_retrieve_secrets_side_effect

        legacy_event = {"correlation_id": "test", "s3_prefix": "test"}
        with pytest.raises(ValueError):
            handler(legacy_event)
        assert mock_launch_cluster.call_count == 0

        submit_job_args = [
            "--correlation_id", "test",
            "--s3_prefix", "test",
            "--snapshot_type", "NOT_SET",
            "--export_date", "NOT_SET",
        ]
        handler({"additional_step_args": {"submit-job": submit_job_args}})
        assert mock_launch_cluster.call_count == 1
예제 #3
0
    def test_launches_correct_cluster_with_overrides(
            self, mock_launch_cluster: MagicMock,
            mock_retrieve_secrets: MagicMock):
        """``overrides`` are merged onto the secret-resolved default config.

        Top-level ``Name`` and ``Applications`` are replaced outright, while the
        nested ``Instances.Ec2SubnetId`` is updated in place and the other
        ``Instances`` keys of the default config are kept.
        """
        mock_retrieve_secrets.side_effect = mock_retrieve_secrets_side_effect

        cluster_name = "Test_Name"
        applications = [{"Name": "Spark"}]
        subnet_id = "Test_Subnet_Id"
        overrides = {
            "Name": cluster_name,
            "Applications": applications,
            "Instances": {"Ec2SubnetId": subnet_id},
        }

        expected = replace_secrets(get_default_config())
        expected["Name"] = cluster_name
        expected["Applications"] = applications
        expected["Instances"]["Ec2SubnetId"] = subnet_id

        handler({"overrides": overrides})

        mock_launch_cluster.assert_called_once()
        assert call(expected) == mock_launch_cluster.call_args_list[0]
예제 #4
0
    def test_handlers_same_result(
        self,
        mock_tag_cluster: MagicMock,
        mock_launch_cluster: MagicMock,
        mock_retrieve_secrets: MagicMock,
    ):
        """Legacy and ``additional_step_args`` events launch identical clusters.

        The launcher must receive exactly the same call for both event shapes,
        and the tagging mock must have been called exactly once after both
        handler invocations.
        """
        mock_retrieve_secrets.side_effect = mock_retrieve_secrets_side_effect

        handler({"correlation_id": "test", "s3_prefix": "test"})
        assert mock_launch_cluster.call_count == 1
        legacy_launch = mock_launch_cluster.call_args_list[0]

        submit_job_args = [
            "--correlation_id", "test",
            "--s3_prefix", "test",
            "--snapshot_type", "NOT_SET",
            "--export_date", "NOT_SET",
        ]
        handler({"additional_step_args": {"submit-job": submit_job_args}})

        mock_tag_cluster.assert_called_once()

        assert mock_launch_cluster.call_count == 2
        step_args_launch = mock_launch_cluster.call_args_list[1]
        assert legacy_launch == step_args_launch
예제 #5
0
    def test_launches_correct_cluster(self, mock_launch_cluster: MagicMock,
                                      mock_retrieve_secrets: MagicMock):
        """With no event, the secret-resolved default config is launched once."""
        mock_retrieve_secrets.side_effect = mock_retrieve_secrets_side_effect
        default_config = get_default_config()
        expected = replace_secrets(default_config)

        handler()

        mock_launch_cluster.assert_called_once()
        # assert_called_once guarantees exactly one recorded call to unpack.
        (launch_call,) = mock_launch_cluster.call_args_list
        assert call(expected) == launch_call
예제 #6
0
    def test_launches_correct_cluster_with_extend(
        self, mock_launch_cluster: MagicMock, mock_retrieve_secrets: MagicMock
    ):
        """``extend`` appends to a list addressed by a dotted key path.

        The dotted key ``Instances.InstanceFleets`` in the event must cause the
        given fleet to be appended to ``config["Instances"]["InstanceFleets"]``
        of the secret-resolved default config.
        """
        mock_retrieve_secrets.side_effect = mock_retrieve_secrets_side_effect

        extra_fleet = {"InstanceFleetType": "CORE", "Name": "TEST"}

        expected = replace_secrets(get_default_config())
        expected["Instances"]["InstanceFleets"].append(extra_fleet)

        handler({"extend": {"Instances.InstanceFleets": [extra_fleet]}})

        mock_launch_cluster.assert_called_once()
        assert call(expected) == mock_launch_cluster.call_args_list[0]
예제 #7
0
    def test_uses_default_s3_location(
        self,
        mock_from_s3: MagicMock,
        mock_launch_cluster: MagicMock,
        mock_retrieve_secrets: MagicMock,
    ):
        """Config YAMLs come from the S3 location named by environment variables.

        With ``EMR_LAUNCHER_CONFIG_DIR`` unset and
        ``EMR_LAUNCHER_CONFIG_S3_BUCKET``/``EMR_LAUNCHER_CONFIG_S3_FOLDER`` set,
        a bare ``handler()`` call must read ``cluster``, ``configurations``,
        ``instances`` and ``steps`` YAML files from that location (in any
        order) and launch exactly one cluster.
        """
        # Local import keeps this fix self-contained within the method.
        from unittest import mock

        calls = [
            call(bucket="s3_bucket", key="s3_folder/cluster.yaml"),
            call(bucket="s3_bucket", key="s3_folder/configurations.yaml"),
            call(bucket="s3_bucket", key="s3_folder/instances.yaml"),
            call(bucket="s3_bucket", key="s3_folder/steps.yaml"),
        ]
        mock_retrieve_secrets.side_effect = mock_retrieve_secrets_side_effect

        s3_env = {
            "EMR_LAUNCHER_CONFIG_S3_FOLDER": "s3_folder",
            "EMR_LAUNCHER_CONFIG_S3_BUCKET": "s3_bucket",
        }
        # patch.dict sets the S3 variables for the duration of the block and
        # then restores os.environ in full (including the key popped below),
        # so this test no longer leaks environment state into other tests.
        with mock.patch.dict(os.environ, s3_env):
            os.environ.pop("EMR_LAUNCHER_CONFIG_DIR", None)

            handler()

        mock_launch_cluster.assert_called_once()
        mock_from_s3.assert_has_calls(calls, any_order=True)
예제 #8
0
from emr_launcher.logger import configure_log
from emr_launcher.handler import handler

# Entry point: configure logging, then run the EMR launcher handler once.
logger = configure_log()
try:
    handler()
except Exception as e:  # top-level boundary: record any failure
    # logger.exception logs at ERROR level with the same message as before
    # (str(e)) and also appends the traceback, which logger.error(e) dropped.
    # NOTE(review): assumes configure_log() returns a stdlib logging.Logger
    # (or compatible) — confirm. The exception is still swallowed, so the
    # process exits with status 0 even when the launch fails.
    logger.exception(e)