    def test_do_sets_name(self):
        named_spec = SageMakerProcessSpec(self.REQUIRED_ARGS +
                                          ["--job_name", "job-name"])
        unnamed_spec = SageMakerProcessSpec(self.REQUIRED_ARGS)
        self.component.Do(named_spec)
        self.assertEqual("job-name", self.component._processing_job_name)

        with patch(
                "process.src.sagemaker_process_component.SageMakerComponent._generate_unique_timestamped_id",
                MagicMock(return_value="unique"),
        ):
            self.component.Do(unnamed_spec)
            self.assertEqual("unique", self.component._processing_job_name)

    def test_no_defined_image(self):
        # REQUIRED_ARGS already passes --image, so the parser normally succeeds;
        # remove it here to verify the argument is required
        no_image_args = self.REQUIRED_ARGS.copy()
        image_index = no_image_args.index("--image")
        # Cut out --image and its associated value
        no_image_args = no_image_args[:image_index] + no_image_args[image_index + 2 :]

        with self.assertRaises(SystemExit):
            SageMakerProcessSpec(no_image_args)

    def test_environment_variables(self):
        env_vars = {"key1": "val1", "key2": "val2"}

        environment_args = SageMakerProcessSpec(
            self.REQUIRED_ARGS +
            ["--environment", json.dumps(env_vars)])
        response = self.component._create_job_request(environment_args.inputs,
                                                      environment_args.outputs)

        self.assertEqual(response["Environment"], env_vars)

    def test_after_job_completed(self):
        spec = SageMakerProcessSpec(self.REQUIRED_ARGS)

        mock_out = {"out1": "val1", "out2": "val2"}
        self.component._get_job_outputs = MagicMock(return_value=mock_out)

        self.component._after_job_complete({}, {}, spec.inputs, spec.outputs)

        self.assertEqual(spec.outputs.job_name, "test-job")
        self.assertEqual(
            spec.outputs.output_artifacts, {"out1": "val1", "out2": "val2"}
        )

    def test_create_process_job(self):
        spec = SageMakerProcessSpec(self.REQUIRED_ARGS)
        request = self.component._create_job_request(spec.inputs, spec.outputs)

        self.assertEqual(
            request,
            {
                "ProcessingJobName":
                "test-job",
                "ProcessingInputs": [{
                    "InputName": "dataset-input",
                    "S3Input": {
                        "S3Uri": "s3://my-bucket/dataset.csv",
                        "LocalPath": "/opt/ml/processing/input",
                        "S3DataType": "S3Prefix",
                        "S3InputMode": "File",
                    },
                }],
                "ProcessingOutputConfig": {
                    "Outputs": [{
                        "OutputName": "training-outputs",
                        "S3Output": {
                            "S3Uri": "s3://my-bucket/outputs/train.csv",
                            "LocalPath": "/opt/ml/processing/output/train",
                            "S3UploadMode": "Continuous",
                        },
                    }]
                },
                "RoleArn":
                "arn:aws:iam::123456789012:user/Development/product_1234/*",
                "ProcessingResources": {
                    "ClusterConfig": {
                        "InstanceType": "ml.m4.xlarge",
                        "InstanceCount": 1,
                        "VolumeSizeInGB": 30,
                    }
                },
                "NetworkConfig": {
                    "EnableInterContainerTrafficEncryption": False,
                    "EnableNetworkIsolation": True,
                },
                "StoppingCondition": {
                    "MaxRuntimeInSeconds": 86400
                },
                "AppSpecification": {
                    "ImageUri": "test-image"
                },
                "Environment": {},
                "Tags": [],
            },
        )

    def test_container_entrypoint(self):
        entrypoint, arguments = ["/bin/bash"], ["arg1", "arg2"]

        container_args = SageMakerProcessSpec(self.REQUIRED_ARGS + [
            "--container_entrypoint",
            json.dumps(entrypoint),
            "--container_arguments",
            json.dumps(arguments),
        ])
        response = self.component._create_job_request(container_args.inputs,
                                                      container_args.outputs)

        self.assertEqual(response["AppSpecification"]["ContainerEntrypoint"],
                         entrypoint)
        self.assertEqual(response["AppSpecification"]["ContainerArguments"],
                         arguments)

    def test_vpc_configuration(self):
        required_vpc_args = SageMakerProcessSpec(
            self.REQUIRED_ARGS
            + [
                "--vpc_security_group_ids",
                "sg1,sg2",
                "--vpc_subnets",
                "subnet1,subnet2",
            ]
        )
        response = self.component._create_job_request(
            required_vpc_args.inputs, required_vpc_args.outputs
        )

        self.assertIn("VpcConfig", response["NetworkConfig"])
        self.assertIn("sg1", response["NetworkConfig"]["VpcConfig"]["SecurityGroupIds"])
        self.assertIn("sg2", response["NetworkConfig"]["VpcConfig"]["SecurityGroupIds"])
        self.assertIn("subnet1", response["NetworkConfig"]["VpcConfig"]["Subnets"])
        self.assertIn("subnet2", response["NetworkConfig"]["VpcConfig"]["Subnets"])

    def test_minimum_required_args(self):
        # Will raise an exception if the inputs are incorrect
        spec = SageMakerProcessSpec(self.REQUIRED_ARGS)

    # Signature assumed: hook run by SageMakerComponent after the processing
    # job has been submitted, used here to surface the CloudWatch log stream.
    def _after_submit_job_request(self, job, request, inputs, outputs):
        logging.info(
            "CloudWatch logs: https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix"
            .format(inputs.region, inputs.region, self._processing_job_name))

    def _print_logs_for_job(self):
        self._print_cloudwatch_logs("/aws/sagemaker/ProcessingJobs",
                                    self._processing_job_name)

    def _get_job_outputs(self):
        """Map the S3 outputs of a processing job to a dictionary object.

        Returns:
            dict: A dictionary of output S3 URIs.
        """
        response = self._sm_client.describe_processing_job(
            ProcessingJobName=self._processing_job_name)
        outputs = {}
        for output in response["ProcessingOutputConfig"]["Outputs"]:
            outputs[output["OutputName"]] = output["S3Output"]["S3Uri"]

        return outputs


if __name__ == "__main__":
    import sys

    spec = SageMakerProcessSpec(sys.argv[1:])

    component = SageMakerProcessComponent()
    component.Do(spec)
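
# Example invocation (a sketch: flag names are taken from the tests above and
# the values are hypothetical; the exact required set is defined by
# SageMakerProcessSpec):
#
#   python sagemaker_process_component.py \
#       --image test-image \
#       --job_name my-processing-job \
#       --environment '{"key1": "val1"}' \
#       --vpc_security_group_ids sg1,sg2 \
#       --vpc_subnets subnet1,subnet2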