def test_region():
    """AwsSession.region follows whichever of boto session / Braket client is supplied.

    When both are supplied and their regions differ, construction raises ValueError.
    """
    boto_region, braket_region = "boto-region", "braket-region"

    boto_session = Mock()
    boto_session.region_name = boto_region
    braket_client = Mock()
    braket_client.meta.region_name = braket_region

    from_boto = AwsSession(boto_session=boto_session)
    assert from_boto.region == boto_region

    from_client = AwsSession(braket_client=braket_client)
    assert from_client.region == braket_region

    mismatch_message = (
        "Boto Session region and Braket Client region must match and currently "
        "they do not: Boto Session region is 'boto-region', but "
        "Braket Client region is 'braket-region'."
    )
    with pytest.raises(ValueError, match=mismatch_message):
        AwsSession(boto_session=boto_session, braket_client=braket_client)
def test_retrieve_s3_object_body_success(boto_session):
    """retrieve_s3_object_body returns the body produced by the mocked S3 object chain.

    The mocked chain is ``resource.Object(...).get()["Body"].read().decode()``.
    Also checks that the S3 resource is requested with ``config=None`` by default
    and with the session's ``config`` when one is supplied.
    """
    bucket_name = "braket-integ-test"
    filename = "tasks/test_task_1.json"
    # Computed once and reused; the original also evaluated this as a bare,
    # discarded expression statement, which had no effect.
    expected_body = json.dumps(TEST_S3_OBJ_CONTENTS)

    mock_resource = Mock()
    boto_session.resource.return_value = mock_resource
    mock_object = Mock()
    mock_resource.Object.return_value = mock_object
    mock_body_object = Mock()
    mock_object.get.return_value = {"Body": mock_body_object}
    mock_read_object = Mock()
    mock_body_object.read.return_value = mock_read_object
    mock_read_object.decode.return_value = expected_body

    aws_session = AwsSession(boto_session=boto_session)
    return_value = aws_session.retrieve_s3_object_body(bucket_name, filename)
    assert return_value == expected_body
    boto_session.resource.assert_called_with("s3", config=None)

    config = Mock()
    AwsSession(boto_session=boto_session, config=config).retrieve_s3_object_body(
        bucket_name, filename
    )
    boto_session.resource.assert_called_with("s3", config=config)
def test_get_device(boto_session):
    """get_device returns exactly what the Braket client's get_device returns."""
    device_metadata = {"deviceArn": "arn1", "deviceName": "name1"}

    client = Mock()
    client.get_device.return_value = device_metadata

    session = AwsSession(boto_session=boto_session, braket_client=client)
    assert device_metadata == session.get_device("arn1")
def test_all_devices_price_search():
    """Every device reachable in some region yields a positive tracked price.

    For each region in ``AwsDevice.REGIONS`` and each ONLINE/OFFLINE device, a
    regular task record and a job-task record are added to a Tracker when the
    device can be fetched in that region. The combined QPU + simulator task cost
    must then be positive.
    """
    devices = AwsDevice.get_devices(statuses=["ONLINE", "OFFLINE"])
    tasks = {}
    for region in AwsDevice.REGIONS:
        session = AwsSession(boto3.Session(region_name=region))
        for device in devices:
            # Keep the try body minimal: only the availability probe is expected to raise.
            try:
                session.get_device(device.arn)
            except session.braket_client.exceptions.ResourceNotFoundException:
                # Device does not exist in this region, so there is nothing to test.
                continue

            # Reaching here means the device can create tasks in this region.
            details = {
                "shots": 100,
                "device": device.arn,
                "billed_duration": MIN_SIMULATOR_DURATION,
                "job_task": False,
                "status": "COMPLETED",
            }
            tasks[f"task:for:{device.name}:{region}"] = details.copy()
            details["job_task"] = True
            tasks[f"jobtask:for:{device.name}:{region}"] = details

    tracker = Tracker()
    tracker._resources = tasks
    assert tracker.qpu_tasks_cost() + tracker.simulator_tasks_cost() > 0
def test_copy_explicit_session(boto_session_init, aws_explicit_session):
    """copy_session forwards the explicit credentials plus the new region.

    The credentials and region are checked on the patched boto session initializer
    (``boto_session_init``).
    """
    boto_session_init.return_value = Mock()
    target_region = "us-west-2"

    AwsSession.copy_session(aws_explicit_session, target_region)

    expected_credentials = {
        "aws_access_key_id": "access key",
        "aws_secret_access_key": "secret key",
        "aws_session_token": "token",
    }
    boto_session_init.assert_called_with(region_name=target_region, **expected_credentials)
def aws_session(boto_session, braket_client, account_id):
    """Return an AwsSession whose STS and S3 clients are replaced by mocks.

    The mocked STS ``get_caller_identity`` reports ``account_id`` as the account.
    """
    session = AwsSession(boto_session=boto_session, braket_client=braket_client)

    sts_mock = Mock()
    session._sts = sts_mock
    sts_mock.get_caller_identity.return_value = {"Account": account_id}

    session._s3 = Mock()
    return session
def test_retrieve_s3_object_body_client_error(boto_session):
    """retrieve_s3_object_body when the underlying S3 ``get`` raises a ClientError.

    NOTE(review): no assertion on the error is visible in this body; presumably a
    marker outside this view (e.g. ``xfail(raises=ClientError)``) expects the
    ClientError to propagate — confirm.
    """
    failing_object = Mock()
    failing_object.get.side_effect = ClientError(
        {"Error": {"Code": "ValidationException", "Message": "NoSuchKey"}}, "Operation"
    )

    s3_resource = Mock()
    s3_resource.Object.return_value = failing_object
    boto_session.resource.return_value = s3_resource

    AwsSession(boto_session=boto_session).retrieve_s3_object_body(
        "braket-integ-test", "tasks/test_task_1.json"
    )
def _translate_creation_args(create_job_args):
    """Translate high-level job-creation arguments into the expected request kwargs.

    ``create_job_args`` must contain ``"aws_session"`` (read directly, so a missing
    key raises KeyError). Every other key is optional: it is read through a
    ``defaultdict`` copy as ``None`` when absent, and the defaults below are applied.
    Default S3 locations live under ``<default bucket>/jobs/<job name>/...``.

    Returns a dict keyed by the camelCase request field names.
    """
    session = create_job_args["aws_session"]
    # Missing optional arguments read as None from this copy.
    args = defaultdict(lambda: None, **create_job_args)

    image_uri = args["image_uri"]
    job_name = args["job_name"] or _generate_default_job_name(image_uri)
    bucket = session.default_bucket()

    def jobs_uri(leaf):
        # Default location for one job artifact folder under the default bucket.
        return AwsSession.construct_s3_uri(bucket, "jobs", job_name, leaf)

    code_location = args["code_location"] or jobs_uri("script")
    role_arn = args["role_arn"] or session.get_default_jobs_role()
    device = args["device"]
    hyperparameters = args["hyperparameters"] or {}
    input_data = args["input_data"] or {}
    instance_config = args["instance_config"] or InstanceConfig()
    output_data_config = args["output_data_config"] or OutputDataConfig(
        s3Path=jobs_uri("data")
    )
    stopping_condition = args["stopping_condition"] or StoppingCondition()
    checkpoint_config = args["checkpoint_config"] or CheckpointConfig(
        s3Uri=jobs_uri("checkpoints")
    )

    entry_point = args["entry_point"]
    source_module = args["source_module"]
    source_is_local = not AwsSession.is_s3_uri(source_module)
    if source_is_local and not entry_point:
        # A local source module defaults its entry point to the module's file stem.
        entry_point = Path(source_module).stem

    script_mode_config = {
        "entryPoint": entry_point,
        "s3Uri": f"{code_location}/source.tar.gz",
        "compressionType": "GZIP",
    }
    algorithm_specification = {"scriptModeConfig": script_mode_config}
    if image_uri:
        algorithm_specification["containerImage"] = {"uri": image_uri}

    # ``.get`` rather than ``[]``: a missing "tags" key yields {} instead of the
    # defaultdict's None.
    tags = args.get("tags", {})

    return {
        "jobName": job_name,
        "roleArn": role_arn,
        "algorithmSpecification": algorithm_specification,
        "inputDataConfig": _process_input_data(input_data, job_name, session),
        "instanceConfig": asdict(instance_config),
        "outputDataConfig": asdict(output_data_config),
        "checkpointConfig": asdict(checkpoint_config),
        "deviceConfig": {"device": device},
        "hyperParameters": hyperparameters,
        "stoppingCondition": asdict(stopping_condition),
        "tags": tags,
    }
def test_copy_s3(aws_session):
    """copy_s3_object calls the S3 client's copy with the source and destination parts.

    The source is passed as a ``{"Bucket", "Key"}`` dict, followed by the
    destination bucket and key.
    """
    source_uri, destination_uri = "s3://here/now", "s3://there/then"

    aws_session.copy_s3_object(source_uri, destination_uri)

    src_bucket, src_key = AwsSession.parse_s3_uri(source_uri)
    dst_bucket, dst_key = AwsSession.parse_s3_uri(destination_uri)
    expected_source = {"Bucket": src_bucket, "Key": src_key}
    aws_session._s3.copy.assert_called_with(expected_source, dst_bucket, dst_key)
def test_uses_supplied_braket_client():
    """An explicitly supplied Braket client is the one exposed by the session."""
    session = Mock(region_name="foobar")
    client = Mock()

    aws_session = AwsSession(boto_session=session, braket_client=client)

    assert aws_session.braket_client == client
def test_copy_session(boto_session_init, aws_session):
    """Copying to a new region preserves the client user agent and the default bucket.

    Checks that the copy initializes a boto session for the requested region,
    keeps the Braket client's user agent, and leaves the default bucket unset.
    """
    boto_session_init.return_value = Mock()
    user_agent = "foo/bar"
    aws_session.braket_client._client_config.user_agent = user_agent

    copied = AwsSession.copy_session(aws_session, "us-west-2")

    boto_session_init.assert_called_with(region_name="us-west-2")
    assert copied.braket_client._client_config.user_agent == user_agent
    assert copied._default_bucket is None
def test_config(boto_session):
    """A supplied botocore config is passed through when the Braket client is created.

    NOTE(review): the expected ``endpoint_url="some-endpoint"`` is presumably provided
    by an environment patch or fixture outside this view — confirm.
    """
    botocore_config = Mock()

    AwsSession(boto_session=boto_session, config=botocore_config)

    boto_session.client.assert_any_call(
        "braket", config=botocore_config, endpoint_url="some-endpoint"
    )
def test_populates_user_agent(os_path_exists_mock, metadata_file_exists, initial_user_agent):
    """AwsSession appends SDK, schema and notebook-instance markers to the user agent.

    The notebook marker is ``0`` when the notebook metadata file exists and
    ``None`` otherwise; the metadata path is probed via ``os.path.exists`` (mocked).
    """
    metadata_path = "/opt/ml/metadata/resource-metadata.json"
    os_path_exists_mock.return_value = metadata_file_exists

    braket_client = Mock()
    braket_client.meta.region_name = "foobar"
    braket_client._client_config.user_agent = initial_user_agent
    boto_session = Mock()
    boto_session.region_name = "foobar"

    aws_session = AwsSession(boto_session=boto_session, braket_client=braket_client)

    notebook_marker = 0 if metadata_file_exists else None
    expected_user_agent = (
        f"{initial_user_agent} BraketSdk/{braket_sdk.__version__} "
        f"BraketSchemas/{braket_schemas.__version__} "
        f"NotebookInstance/{notebook_marker}"
    )
    os_path_exists_mock.assert_called_with(metadata_path)
    assert aws_session.braket_client._client_config.user_agent == expected_user_agent
def code_location(bucket, s3_prefix):
    """Build the S3 URI of the ``script`` folder under ``s3_prefix`` in ``bucket``."""
    script_uri = AwsSession.construct_s3_uri(bucket, s3_prefix, "script")
    return script_uri
def checkpoint_config(bucket, s3_prefix):
    """Build a CheckpointConfig with a fixed local path and ``<s3_prefix>/checkpoints`` in S3."""
    checkpoints_uri = AwsSession.construct_s3_uri(bucket, s3_prefix, "checkpoints")
    return CheckpointConfig(localPath="/opt/omega/checkpoints", s3Uri=checkpoints_uri)
def output_data_config(bucket, s3_prefix):
    """Build an OutputDataConfig pointing at ``<s3_prefix>/output`` in ``bucket``."""
    output_uri = AwsSession.construct_s3_uri(bucket, s3_prefix, "output")
    return OutputDataConfig(s3Path=output_uri)
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import random import uuid from unittest.mock import Mock, PropertyMock, patch import pytest from common_test_utils import MockS3 from braket.aws import AwsQuantumTaskBatch, AwsSession from braket.circuits import Circuit from braket.tasks import GateModelQuantumTaskResult S3_TARGET = AwsSession.S3DestinationFolder("foo", "bar") @patch("braket.aws.aws_quantum_task.AwsQuantumTask.create") def test_creation(mock_create): task_mock = Mock() type(task_mock).id = PropertyMock(side_effect=uuid.uuid4) task_mock.state.return_value = "RUNNING" mock_create.return_value = task_mock batch_size = 10 batch = AwsQuantumTaskBatch(Mock(), "foo", _circuits(batch_size), S3_TARGET, 1000,
def test_initializes_boto_client_if_required(boto_session):
    """Without a Braket client, AwsSession creates one from the boto session."""
    AwsSession(boto_session=boto_session)

    expected_kwargs = {"config": None, "endpoint_url": None}
    boto_session.client.assert_any_call("braket", **expected_kwargs)
def test_parse_s3_uri(uri, bucket, key):
    """parse_s3_uri splits ``uri`` into the expected ``(bucket, key)`` pair."""
    # The previous form ``assert bucket, key == ...`` asserted only that ``bucket``
    # was truthy; the comparison was the assertion *message*, so the parse result
    # was never checked. Unpack and compare each part explicitly instead.
    parsed_bucket, parsed_key = AwsSession.parse_s3_uri(uri)
    assert parsed_bucket == bucket
    assert parsed_key == key
def test_config(boto_session):
    """The most recent boto client call creates the Braket client with the supplied config."""
    botocore_config = Mock()

    AwsSession(boto_session=boto_session, config=botocore_config)

    boto_session.client.assert_called_with("braket", config=botocore_config)
def test_default_bucket_env_variable(boto_session, braket_client):
    """default_bucket() returns "default_bucket_env".

    NOTE(review): presumably an environment variable providing this bucket is patched
    by a fixture or decorator outside this view — confirm.
    """
    session = AwsSession(boto_session=boto_session, braket_client=braket_client)
    bucket = session.default_bucket()
    assert bucket == "default_bucket_env"
def aws_session(boto_session):
    """Return an AwsSession over ``boto_session`` with a mocked Braket client."""
    mocked_client = Mock()
    session = AwsSession(boto_session=boto_session, braket_client=mocked_client)
    return session
def test_parse_s3_uri_invalid(uri):
    """parse_s3_uri rejects ``uri`` with a ValueError naming the offending URI."""
    import re

    # ``pytest.raises(match=...)`` treats its argument as a regular expression.
    # Escape the message so characters such as ``.``, ``?``, ``[`` or ``(`` in the
    # URI are matched literally rather than as regex syntax.
    expected_message = re.escape(f"Not a valid S3 uri: {uri}")
    with pytest.raises(ValueError, match=expected_message):
        AwsSession.parse_s3_uri(uri)
def test_initializes_boto_client_if_required(boto_session):
    """Passing ``braket_client=None`` makes AwsSession create the Braket client itself."""
    AwsSession(braket_client=None, boto_session=boto_session)

    boto_session.client.assert_called_with("braket")
def test_construct_s3_uri(bucket, dirs):
    """construct_s3_uri round-trips through parse_s3_uri to the bucket and joined dirs."""
    built_uri = AwsSession.construct_s3_uri(bucket, *dirs)
    round_tripped = AwsSession.parse_s3_uri(built_uri)
    parsed_bucket, parsed_key = round_tripped
    assert (parsed_bucket, parsed_key) == (bucket, "/".join(dirs))
def test_is_s3_uri(string, valid):
    """is_s3_uri classifies ``string`` as expected by ``valid``."""
    verdict = AwsSession.is_s3_uri(string)
    assert verdict == valid
def test_copy_session_custom_default_bucket(mock_boto, aws_session):
    """A custom default bucket survives copy_session."""
    custom_bucket = "my-own-default"
    mock_boto.return_value.region_name = "us-test-1"
    aws_session._default_bucket = custom_bucket
    aws_session._custom_default_bucket = True

    copied = AwsSession.copy_session(aws_session)

    assert copied._default_bucket == custom_bucket
def source_module(request, bucket, s3_prefix):
    """Return the source module selected by ``request.param``.

    ``"local_source"`` yields the local module name ``"test-source-module"``;
    ``"s3_source"`` yields an S3 URI for ``test-source-prefix/source.tar.gz`` in
    ``bucket``. ``s3_prefix`` is accepted but not used here.

    Raises:
        ValueError: for any other ``request.param``. Previously this case fell
            through and silently returned ``None``.
    """
    if request.param == "local_source":
        return "test-source-module"
    if request.param == "s3_source":
        return AwsSession.construct_s3_uri(bucket, "test-source-prefix", "source.tar.gz")
    raise ValueError(f"Unknown source_module parameter: {request.param!r}")
def test_copy_session(boto_session_init, aws_session):
    """copy_session to a new region initializes a boto session there and resets the bucket."""
    boto_session_init.return_value = Mock()
    region = "us-west-2"

    copied = AwsSession.copy_session(aws_session, region)

    boto_session_init.assert_called_with(region_name=region)
    assert copied._default_bucket is None