# Imports used by these tests; the `setup_container` import path is assumed
# from the package layout and may need to be adjusted.
import os
from pathlib import Path
from unittest.mock import mock_open, patch

import pytest

from braket.jobs.local.local_job_container_setup import setup_container


def test_no_data_input(container, aws_session, creation_kwargs,
                       input_data_config):
    input_data_config.append({
        # this channel won't match any data
        "channelName": "no-data",
        "dataSource": {
            "s3DataSource": {
                "s3Uri": "s3://input_bucket/irrelevant"
            }
        },
    })
    creation_kwargs.update({"inputDataConfig": input_data_config})
    no_data_found = "No data found for channel 'no-data'"
    with pytest.raises(RuntimeError, match=no_data_found):
        setup_container(container, aws_session, **creation_kwargs)


def test_duplicate_input(container, aws_session, creation_kwargs,
                         input_data_config):
    input_data_config.append({
        # this is a duplicate channel
        "channelName": "single-file",
        "dataSource": {
            "s3DataSource": {
                "s3Uri": "s3://input_bucket/irrelevant"
            }
        },
    })
    creation_kwargs.update({"inputDataConfig": input_data_config})
    dupes_not_allowed = "Duplicate channel names not allowed for input data: single-file"
    with pytest.raises(ValueError, match=dupes_not_allowed):
        setup_container(container, aws_session, **creation_kwargs)


def test_basic_setup(container, aws_session, creation_kwargs, expected_envs):
    aws_session.parse_s3_uri.return_value = ["test_bucket", "test_location"]
    envs = setup_container(container, aws_session, **creation_kwargs)
    assert envs == expected_envs
    container.makedir.assert_any_call("/opt/ml/model")
    container.makedir.assert_any_call(
        expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"])
    assert container.makedir.call_count == 2


def test_temporary_credentials(container, aws_session, creation_kwargs,
                               expected_envs):
    aws_session.boto_session.get_credentials.return_value.token = "Test Token"
    expected_envs["AWS_SESSION_TOKEN"] = "Test Token"
    aws_session.parse_s3_uri.return_value = ["test_bucket", "test_location"]
    envs = setup_container(container, aws_session, **creation_kwargs)
    assert envs == expected_envs
    container.makedir.assert_any_call("/opt/ml/model")
    container.makedir.assert_any_call(
        expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"])
    assert container.makedir.call_count == 2


def test_input(container, aws_session, creation_kwargs, input_data_config):
    creation_kwargs.update({"inputDataConfig": input_data_config})
    setup_container(container, aws_session, **creation_kwargs)
    download_locations = [
        call[0][1] for call in aws_session.download_from_s3.call_args_list
    ]
    expected_downloads = [
        Path("single-file", "file-1.txt"),
        Path("directory-no-slash", "file-1.txt"),
        Path("directory-no-slash", "file-2.txt"),
        Path("directory-slash", "file-1.txt"),
        Path("directory-slash", "file-2.txt"),
        Path("directory-prefix", "input-dir", "file-1.txt"),
        Path("directory-prefix", "input-dir", "file-2.txt"),
        Path("files-prefix", "file-1.txt"),
        Path("files-prefix", "file-2.txt"),
    ]

    for download, expected_download in zip(download_locations,
                                           expected_downloads):
        assert download.endswith(str(expected_download))


def test_compressed_script_mode(container, aws_session, creation_kwargs,
                                expected_envs, compressed_script_mode_config):
    creation_kwargs["algorithmSpecification"] = compressed_script_mode_config
    expected_envs[
        "AMZN_BRAKET_SCRIPT_S3_URI"] = "s3://amazon-braket-jobs/job-path/my_archive.gzip"
    expected_envs["AMZN_BRAKET_SCRIPT_COMPRESSION_TYPE"] = "gzip"
    aws_session.parse_s3_uri.return_value = ["test_bucket", "test_location"]
    envs = setup_container(container, aws_session, **creation_kwargs)
    assert envs == expected_envs
    container.makedir.assert_any_call("/opt/ml/model")
    container.makedir.assert_any_call(
        expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"])
    assert container.makedir.call_count == 2


# `tempfile` and `json` appear to be injected by @patch decorators that were
# omitted from this excerpt.
def test_hyperparameters(tempfile, json, container, aws_session,
                         creation_kwargs, expected_envs):
    with patch("builtins.open", mock_open()):
        tempfile.return_value.__enter__.return_value = "temporaryDir"
        creation_kwargs["hyperParameters"] = {"test": "hyper"}
        expected_envs[
            "AMZN_BRAKET_HP_FILE"] = "/opt/braket/input/config/hyperparameters.json"
        aws_session.parse_s3_uri.return_value = [
            "test_bucket", "test_location"
        ]
        envs = setup_container(container, aws_session, **creation_kwargs)
        assert envs == expected_envs
        container.makedir.assert_any_call("/opt/ml/model")
        container.makedir.assert_any_call(
            expected_envs["AMZN_BRAKET_CHECKPOINT_DIR"])
        assert container.makedir.call_count == 2
        container.copy_to.assert_called_with(
            os.path.join("temporaryDir", "hyperparameters.json"),
            "/opt/ml/input/config/hyperparameters.json",
        )
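These tests rely on pytest fixtures (`container`, `aws_session`, `creation_kwargs`, `input_data_config`, and `expected_envs`) defined elsewhere in the suite; `expected_envs` would hold the full environment dictionary that `setup_container` is expected to return, including `AMZN_BRAKET_CHECKPOINT_DIR`. A minimal sketch of what such fixtures might look like follows; the mock wiring and every field value are illustrative assumptions, not the suite's actual conftest.

# Illustrative fixture sketch only; the real suite's fixtures and values differ.
import pytest
from unittest.mock import MagicMock


@pytest.fixture
def container():
    # Stand-in for the local job container; setup_container is expected to call
    # makedir and copy_to on it.
    return MagicMock()


@pytest.fixture
def aws_session():
    # Stand-in AwsSession; individual tests override parse_s3_uri and
    # boto_session.get_credentials(...).token as needed.
    return MagicMock()


@pytest.fixture
def creation_kwargs():
    # Hypothetical minimal job-creation arguments consumed by setup_container.
    return {
        "jobName": "test-job",
        "algorithmSpecification": {
            "scriptModeConfig": {
                "entryPoint": "my_module:start_here",
                "s3Uri": "s3://amazon-braket-jobs/job-path/my_script.py",
            }
        },
        "checkpointConfig": {"s3Uri": "s3://test_bucket/test_location/checkpoints"},
    }


@pytest.fixture
def input_data_config():
    # Hypothetical channel list; tests append extra channels to it.
    return [
        {
            "channelName": "single-file",
            "dataSource": {
                "s3DataSource": {"s3Uri": "s3://input_bucket/single-file/file-1.txt"}
            },
        },
    ]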
Example #8
    @classmethod
    def create(
        cls,
        device: str,
        source_module: str,
        entry_point: str = None,
        image_uri: str = None,
        job_name: str = None,
        code_location: str = None,
        role_arn: str = None,
        hyperparameters: Dict[str, Any] = None,
        input_data: Union[str, Dict, S3DataSourceConfig] = None,
        output_data_config: OutputDataConfig = None,
        checkpoint_config: CheckpointConfig = None,
        aws_session: AwsSession = None,
    ) -> LocalQuantumJob:
        """Creates and runs job by setting up and running the customer script in a local
         docker container.

         Args:
            device (str): ARN of the AWS device that is primarily accessed
                during the execution of this job.

            source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
                tarred and uploaded. If `source_module` is an S3 URI, it must point to a
                tar.gz file. Otherwise, source_module may be a file or directory.

            entry_point (str): A str that specifies the entry point of the job, relative to
                the source module. The entry point must be in the format
                `importable.module` or `importable.module:callable`. For example,
                `source_module.submodule:start_here` indicates the `start_here` function
                contained in `source_module.submodule`. If `source_module` is an S3 URI,
                `entry_point` must be given. Default: the name of `source_module`.

            image_uri (str): A str that specifies the ECR image to use for executing the job.
                The `image_uris.retrieve_image()` function may be used to retrieve ECR image
                URIs for the containers supported by Braket. Default: `<Braket base image_uri>`.

            job_name (str): A str that specifies the name with which the job is created.
                Default: f'{image_uri_type}-{timestamp}'.

            code_location (str): The S3 prefix URI where custom code will be uploaded.
                Default: f's3://{default_bucket_name}/jobs/{job_name}/script'.

            role_arn (str): This field is currently not used for local jobs. Local jobs will use
                the current role's credentials. This may be subject to change.

            hyperparameters (Dict[str, Any]): Hyperparameters accessible to the job.
                The hyperparameters are made accessible as a Dict[str, str] to the job.
                For convenience, this accepts other types for keys and values, but `str()`
                is called to convert them before being passed on. Default: None.

            input_data (Union[str, Dict, S3DataSourceConfig]): Information about the training
                data. A dictionary maps channel names to local paths or S3 URIs. Contents found
                at any local paths will be uploaded to S3 at
                f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local
                path, S3 URI, or S3DataSourceConfig is provided, it is given the default
                channel name "input".
                Default: {}.

            output_data_config (OutputDataConfig): Specifies the location for the output of the job.
                Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data',
                kmsKeyId=None).

            checkpoint_config (CheckpointConfig): Configuration that specifies the location where
                checkpoint data is stored.
                Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
                s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').

            aws_session (AwsSession): AwsSession for connecting to AWS Services.
                Default: AwsSession()

        Returns:
            LocalQuantumJob: The representation of a local Braket Job.
        """
        create_job_kwargs = prepare_quantum_job(
            device=device,
            source_module=source_module,
            entry_point=entry_point,
            image_uri=image_uri,
            job_name=job_name,
            code_location=code_location,
            role_arn=role_arn,
            hyperparameters=hyperparameters,
            input_data=input_data,
            output_data_config=output_data_config,
            checkpoint_config=checkpoint_config,
            aws_session=aws_session,
        )

        job_name = create_job_kwargs["jobName"]
        if os.path.isdir(job_name):
            raise ValueError(
                f"A local directory called {job_name} already exists. "
                f"Please use a different job name."
            )

        session = aws_session or AwsSession()
        algorithm_specification = create_job_kwargs["algorithmSpecification"]
        if "containerImage" in algorithm_specification:
            image_uri = algorithm_specification["containerImage"]["uri"]
        else:
            image_uri = retrieve_image(Framework.BASE, session.region)

        with _LocalJobContainer(image_uri) as container:
            env_variables = setup_container(container, session, **create_job_kwargs)
            container.run_local_job(env_variables)
            container.copy_from("/opt/ml/model", job_name)
            with open(os.path.join(job_name, "log.txt"), "w") as log_file:
                log_file.write(container.run_log)
            if "checkpointConfig" in create_job_kwargs:
                checkpoint_config = create_job_kwargs["checkpointConfig"]
                if "localPath" in checkpoint_config:
                    checkpoint_path = checkpoint_config["localPath"]
                    container.copy_from(checkpoint_path, os.path.join(job_name, "checkpoints"))
            run_log = container.run_log
        return LocalQuantumJob(f"local:job/{job_name}", run_log)
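
For reference, a hypothetical call to the create method above could look like the following; the device ARN, script name, entry point, and hyperparameters are placeholder values, and the import path is assumed from the package layout.

# Hypothetical usage sketch; placeholder values throughout.
from braket.jobs.local import LocalQuantumJob  # import path assumed

job = LocalQuantumJob.create(
    device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
    source_module="algorithm_script.py",
    entry_point="algorithm_script:start_here",
    hyperparameters={"shots": "100"},
)
print(job.run_log)  # assuming the returned job exposes the captured container log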