def test_region():
    """AwsSession.region follows whichever of boto session / Braket client is supplied.

    Supplying both with differing regions must raise a ValueError naming both regions.
    """
    boto_region, braket_region = "boto-region", "braket-region"

    session_mock = Mock()
    session_mock.region_name = boto_region
    client_mock = Mock()
    client_mock.meta.region_name = braket_region

    from_boto = AwsSession(boto_session=session_mock)
    assert from_boto.region == boto_region

    from_client = AwsSession(braket_client=client_mock)
    assert from_client.region == braket_region

    expected_message = (
        "Boto Session region and Braket Client region must match and currently "
        "they do not: Boto Session region is 'boto-region', but "
        "Braket Client region is 'braket-region'."
    )
    with pytest.raises(ValueError, match=expected_message):
        AwsSession(boto_session=session_mock, braket_client=client_mock)
def test_retrieve_s3_object_body_success(boto_session):
    """retrieve_s3_object_body returns the decoded S3 object body.

    Also checks that the S3 resource is created with ``config=None`` by default
    and with the session's custom config when one is supplied.
    """
    bucket_name = "braket-integ-test"
    filename = "tasks/test_task_1.json"
    expected_body = json.dumps(TEST_S3_OBJ_CONTENTS)

    # Wire resource(...).Object(...).get()["Body"].read().decode() -> expected_body.
    mock_resource = Mock()
    boto_session.resource.return_value = mock_resource
    mock_object = Mock()
    mock_resource.Object.return_value = mock_object
    mock_body_object = Mock()
    mock_object.get.return_value = {"Body": mock_body_object}
    mock_read_object = Mock()
    mock_body_object.read.return_value = mock_read_object
    mock_read_object.decode.return_value = expected_body

    aws_session = AwsSession(boto_session=boto_session)
    return_value = aws_session.retrieve_s3_object_body(bucket_name, filename)
    assert return_value == expected_body
    boto_session.resource.assert_called_with("s3", config=None)

    # A custom config must be forwarded to the S3 resource.
    config = Mock()
    AwsSession(boto_session=boto_session, config=config).retrieve_s3_object_body(
        bucket_name, filename
    )
    boto_session.resource.assert_called_with("s3", config=config)
def test_get_device(boto_session):
    """get_device returns the Braket client's get_device response."""
    expected = {"deviceArn": "arn1", "deviceName": "name1"}
    client = Mock()
    client.get_device.return_value = expected

    session = AwsSession(boto_session=boto_session, braket_client=client)
    metadata = session.get_device("arn1")

    assert expected == metadata
def test_all_devices_price_search():
    """Tasks on every region-reachable device must price to a positive total cost.

    For each (region, device) pair where ``get_device`` succeeds, one regular task
    and one job task are registered on a Tracker. The tracker's combined QPU and
    simulator cost must be greater than zero.
    """
    devices = AwsDevice.get_devices(statuses=["ONLINE", "OFFLINE"])

    tasks = {}
    for region in AwsDevice.REGIONS:
        session = AwsSession(boto3.Session(region_name=region))
        for device in devices:
            # Only the lookup can raise ResourceNotFoundException; keep the try minimal.
            try:
                session.get_device(device.arn)
            except session.braket_client.exceptions.ResourceNotFoundException:
                # Device does not exist in this region, so there is nothing to test.
                continue

            # Device is reachable, so tasks can be created for it in this region.
            details = {
                "shots": 100,
                "device": device.arn,
                "billed_duration": MIN_SIMULATOR_DURATION,
                "job_task": False,
                "status": "COMPLETED",
            }
            tasks[f"task:for:{device.name}:{region}"] = details.copy()
            details["job_task"] = True
            tasks[f"jobtask:for:{device.name}:{region}"] = details

    tracker = Tracker()
    tracker._resources = tasks

    assert tracker.qpu_tasks_cost() + tracker.simulator_tasks_cost() > 0
def test_copy_explicit_session(boto_session_init, aws_explicit_session):
    """Copying an explicit-credential session re-passes those credentials plus the new region."""
    boto_session_init.return_value = Mock()
    target_region = "us-west-2"

    AwsSession.copy_session(aws_explicit_session, target_region)

    expected_kwargs = dict(
        aws_access_key_id="access key",
        aws_secret_access_key="secret key",
        aws_session_token="token",
        region_name=target_region,
    )
    boto_session_init.assert_called_with(**expected_kwargs)
def aws_session(boto_session, braket_client, account_id):
    """Return an AwsSession whose ``_sts`` and ``_s3`` attributes are mocks.

    The STS mock's ``get_caller_identity`` reports ``account_id`` as the account.
    """
    session = AwsSession(boto_session=boto_session, braket_client=braket_client)

    sts_mock = Mock()
    sts_mock.get_caller_identity.return_value = {"Account": account_id}
    session._sts = sts_mock
    session._s3 = Mock()

    return session
def test_retrieve_s3_object_body_client_error(boto_session):
    """Call retrieve_s3_object_body while the S3 object's ``get`` raises ClientError.

    NOTE(review): nothing in this body asserts the error; presumably a marker such as
    ``xfail(raises=ClientError)`` outside this view expects it to propagate — confirm.
    """
    bucket_name, filename = "braket-integ-test", "tasks/test_task_1.json"
    failure = ClientError(
        {"Error": {"Code": "ValidationException", "Message": "NoSuchKey"}}, "Operation"
    )

    s3_object = Mock()
    s3_object.get.side_effect = failure
    s3_resource = Mock()
    s3_resource.Object.return_value = s3_object
    boto_session.resource.return_value = s3_resource

    AwsSession(boto_session=boto_session).retrieve_s3_object_body(bucket_name, filename)
Esempio n. 8
0
def _translate_creation_args(create_job_args):
    """Build the expected ``create_job`` request kwargs from job-creation arguments.

    ``create_job_args`` must contain ``aws_session``. Every other key may be absent;
    a missing or falsy value is replaced with the default used here. Default S3
    locations live under ``jobs/<job_name>/`` in the session's default bucket.

    Returns:
        dict: request kwargs keyed by the service's camelCase field names.
    """
    session = create_job_args["aws_session"]
    # Missing keys read as None so the ``or`` defaults below apply uniformly.
    args = defaultdict(lambda: None, **create_job_args)

    image_uri = args["image_uri"]
    job_name = args["job_name"] or _generate_default_job_name(image_uri)
    bucket = session.default_bucket()

    def job_uri(leaf):
        # s3 URI for a sub-folder of this job's prefix in the default bucket.
        return AwsSession.construct_s3_uri(bucket, "jobs", job_name, leaf)

    code_location = args["code_location"] or job_uri("script")
    role_arn = args["role_arn"] or session.get_default_jobs_role()
    device = args["device"]
    hyperparameters = args["hyperparameters"] or {}
    input_data = args["input_data"] or {}
    instance_config = args["instance_config"] or InstanceConfig()
    output_data_config = args["output_data_config"] or OutputDataConfig(
        s3Path=job_uri("data")
    )
    stopping_condition = args["stopping_condition"] or StoppingCondition()
    checkpoint_config = args["checkpoint_config"] or CheckpointConfig(
        s3Uri=job_uri("checkpoints")
    )

    entry_point = args["entry_point"]
    source_module = args["source_module"]
    # For local sources, fall back to the module's file stem as the entry point.
    if not AwsSession.is_s3_uri(source_module):
        entry_point = entry_point or Path(source_module).stem

    script_mode_config = {
        "entryPoint": entry_point,
        "s3Uri": f"{code_location}/source.tar.gz",
        "compressionType": "GZIP",
    }
    algorithm_specification = {"scriptModeConfig": script_mode_config}
    if image_uri:
        algorithm_specification["containerImage"] = {"uri": image_uri}

    return {
        "jobName": job_name,
        "roleArn": role_arn,
        "algorithmSpecification": algorithm_specification,
        "inputDataConfig": _process_input_data(input_data, job_name, session),
        "instanceConfig": asdict(instance_config),
        "outputDataConfig": asdict(output_data_config),
        "checkpointConfig": asdict(checkpoint_config),
        "deviceConfig": {"device": device},
        "hyperParameters": hyperparameters,
        "stoppingCondition": asdict(stopping_condition),
        # .get does not trigger the defaultdict factory, so absent tags become {}.
        "tags": args.get("tags", {}),
    }
def test_copy_s3(aws_session):
    """copy_s3_object calls the S3 client's copy with the parsed source and destination."""
    src_uri, dst_uri = "s3://here/now", "s3://there/then"
    src_bucket, src_key = AwsSession.parse_s3_uri(src_uri)
    dst_bucket, dst_key = AwsSession.parse_s3_uri(dst_uri)

    aws_session.copy_s3_object(src_uri, dst_uri)

    copy_source = {"Bucket": src_bucket, "Key": src_key}
    aws_session._s3.copy.assert_called_with(copy_source, dst_bucket, dst_key)
def test_uses_supplied_braket_client():
    """A Braket client passed to AwsSession is exposed as its ``braket_client``."""
    fake_boto = Mock()
    fake_boto.region_name = "foobar"
    supplied_client = Mock()

    session = AwsSession(boto_session=fake_boto, braket_client=supplied_client)

    assert session.braket_client == supplied_client
def test_copy_session(boto_session_init, aws_session):
    """copy_session targets the new region and keeps the Braket client's user agent.

    The copy must also have no default bucket set.
    """
    boto_session_init.return_value = Mock()
    user_agent = "foo/bar"
    aws_session.braket_client._client_config.user_agent = user_agent

    copied = AwsSession.copy_session(aws_session, "us-west-2")

    boto_session_init.assert_called_with(region_name="us-west-2")
    assert copied.braket_client._client_config.user_agent == user_agent
    assert copied._default_bucket is None
def test_config(boto_session):
    """A custom config (and the endpoint URL) reach the Braket client constructor."""
    custom_config = Mock()
    AwsSession(boto_session=boto_session, config=custom_config)

    # NOTE(review): "some-endpoint" presumably comes from a fixture/env patch outside this view.
    expected_kwargs = {"config": custom_config, "endpoint_url": "some-endpoint"}
    boto_session.client.assert_any_call("braket", **expected_kwargs)
def test_populates_user_agent(os_path_exists_mock, metadata_file_exists, initial_user_agent):
    """AwsSession appends SDK, schema and notebook-instance markers to the user agent."""
    region = "foobar"
    fake_boto = Mock()
    fake_boto.region_name = region
    client = Mock()
    client.meta.region_name = region
    client._client_config.user_agent = initial_user_agent
    os_path_exists_mock.return_value = metadata_file_exists

    session = AwsSession(boto_session=fake_boto, braket_client=client)

    # Rendered as "0" when the metadata file exists, otherwise "None".
    notebook_marker = 0 if metadata_file_exists else None
    expected_user_agent = (
        f"{initial_user_agent} BraketSdk/{braket_sdk.__version__} "
        f"BraketSchemas/{braket_schemas.__version__} "
        f"NotebookInstance/{notebook_marker}"
    )
    os_path_exists_mock.assert_called_with("/opt/ml/metadata/resource-metadata.json")
    assert session.braket_client._client_config.user_agent == expected_user_agent
Esempio n. 14
0
def code_location(bucket, s3_prefix):
    """S3 URI of the ``script`` folder under ``s3_prefix`` in ``bucket``."""
    uri_parts = (bucket, s3_prefix, "script")
    return AwsSession.construct_s3_uri(*uri_parts)
Esempio n. 15
0
def checkpoint_config(bucket, s3_prefix):
    """CheckpointConfig with a fixed local path and an S3 ``checkpoints`` folder."""
    checkpoints_uri = AwsSession.construct_s3_uri(bucket, s3_prefix, "checkpoints")
    return CheckpointConfig(localPath="/opt/omega/checkpoints", s3Uri=checkpoints_uri)
Esempio n. 16
0
def output_data_config(bucket, s3_prefix):
    """OutputDataConfig pointing at the ``output`` folder under ``s3_prefix``."""
    output_uri = AwsSession.construct_s3_uri(bucket, s3_prefix, "output")
    return OutputDataConfig(s3Path=output_uri)
Esempio n. 17
0
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import random
import uuid
from unittest.mock import Mock, PropertyMock, patch

import pytest
from common_test_utils import MockS3

from braket.aws import AwsQuantumTaskBatch, AwsSession
from braket.circuits import Circuit
from braket.tasks import GateModelQuantumTaskResult

# Shared S3 destination passed to AwsQuantumTaskBatch in the tests below.
# NOTE(review): presumably ("foo", "bar") is (bucket, key prefix) — confirm.
S3_TARGET = AwsSession.S3DestinationFolder("foo", "bar")


@patch("braket.aws.aws_quantum_task.AwsQuantumTask.create")
def test_creation(mock_create):
    task_mock = Mock()
    type(task_mock).id = PropertyMock(side_effect=uuid.uuid4)
    task_mock.state.return_value = "RUNNING"
    mock_create.return_value = task_mock

    batch_size = 10
    batch = AwsQuantumTaskBatch(Mock(),
                                "foo",
                                _circuits(batch_size),
                                S3_TARGET,
                                1000,
def test_initializes_boto_client_if_required(boto_session):
    """Without a Braket client, AwsSession creates one from the boto session."""
    AwsSession(boto_session=boto_session)
    boto_session.client.assert_any_call(
        "braket",
        config=None,
        endpoint_url=None,
    )
def test_parse_s3_uri(uri, bucket, key):
    """parse_s3_uri splits ``uri`` into the expected bucket and key."""
    # Compare each part explicitly: ``assert a, b == c`` only checks that ``a`` is
    # truthy (the comparison becomes the assertion message).
    parsed_bucket, parsed_key = AwsSession.parse_s3_uri(uri)
    assert parsed_bucket == bucket
    assert parsed_key == key
Esempio n. 20
0
def test_config(boto_session):
    """The supplied config is passed to the Braket client constructor."""
    cfg = Mock()
    AwsSession(config=cfg, boto_session=boto_session)
    boto_session.client.assert_called_with("braket", config=cfg)
def test_default_bucket_env_variable(boto_session, braket_client):
    """default_bucket() returns "default_bucket_env".

    NOTE(review): presumably that value is injected via an environment variable
    patched outside this view — confirm.
    """
    bucket = AwsSession(boto_session=boto_session, braket_client=braket_client).default_bucket()
    assert bucket == "default_bucket_env"
def aws_session(boto_session):
    """AwsSession built on ``boto_session`` with a mock Braket client."""
    mock_client = Mock()
    return AwsSession(boto_session=boto_session, braket_client=mock_client)
def test_parse_s3_uri_invalid(uri):
    """parse_s3_uri rejects ``uri`` with a ValueError that names the URI."""
    import re

    # ``match`` is a regular expression searched against the message; escape the
    # URI so characters such as ``.``, ``?``, ``+`` or ``(`` are matched literally.
    expected = f"Not a valid S3 uri: {re.escape(f'{uri}')}"
    with pytest.raises(ValueError, match=expected):
        AwsSession.parse_s3_uri(uri)
def test_initializes_boto_client_if_required(boto_session):
    """With ``braket_client=None``, AwsSession creates the Braket client itself."""
    AwsSession(braket_client=None, boto_session=boto_session)
    boto_session.client.assert_called_with("braket")
def test_construct_s3_uri(bucket, dirs):
    """construct_s3_uri output round-trips through parse_s3_uri."""
    built_uri = AwsSession.construct_s3_uri(bucket, *dirs)
    parsed_bucket, parsed_key = AwsSession.parse_s3_uri(built_uri)
    assert parsed_bucket == bucket
    assert parsed_key == "/".join(dirs)
def test_is_s3_uri(string, valid):
    """is_s3_uri classifies ``string`` as ``valid`` expects."""
    result = AwsSession.is_s3_uri(string)
    assert result == valid
def test_copy_session_custom_default_bucket(mock_boto, aws_session):
    """A user-chosen default bucket is carried over by copy_session."""
    custom_bucket = "my-own-default"
    mock_boto.return_value.region_name = "us-test-1"
    aws_session._default_bucket = custom_bucket
    aws_session._custom_default_bucket = True

    assert AwsSession.copy_session(aws_session)._default_bucket == custom_bucket
Esempio n. 28
0
def source_module(request, bucket, s3_prefix):
    """Source module for the requested variant.

    ``"local_source"`` yields a local module name, and ``"s3_source"`` yields an S3
    URI to a source tarball. Any other parameter yields None.
    """
    variant = request.param
    if variant == "s3_source":
        return AwsSession.construct_s3_uri(bucket, "test-source-prefix", "source.tar.gz")
    if variant == "local_source":
        return "test-source-module"
    return None
Esempio n. 29
0
def test_copy_session(boto_session_init, aws_session):
    """copy_session builds a boto session for the requested region; the copy has no default bucket."""
    boto_session_init.return_value = Mock()
    region = "us-west-2"

    copied_session = AwsSession.copy_session(aws_session, region)

    boto_session_init.assert_called_with(region_name=region)
    assert copied_session._default_bucket is None