def _get_env_default_vars(aws_session: AwsSession, **creation_kwargs) -> Dict[str, str]:
    """Build the 'simple' environment variables for the job container.

    These are the key/value pairs that require no additional logic beyond
    reading values already present in the job creation kwargs.

    Args:
        aws_session (AwsSession): Session used to resolve the AWS region.
        **creation_kwargs: The boto3 input parameters for creating a Braket Job.

    Returns:
        Dict[str, str]: The key/value pairs to be added as environment
        variables to the running container.
    """
    name = creation_kwargs["jobName"]
    output_s3_path = creation_kwargs["outputDataConfig"]["s3Path"]
    bucket, location = AwsSession.parse_s3_uri(output_s3_path)
    results_s3_path = str(Path(location, name, "output").as_posix())
    return {
        "AWS_DEFAULT_REGION": aws_session.region,
        "AMZN_BRAKET_JOB_NAME": name,
        "AMZN_BRAKET_DEVICE_ARN": creation_kwargs["deviceConfig"]["device"],
        "AMZN_BRAKET_JOB_RESULTS_DIR": "/opt/braket/model",
        "AMZN_BRAKET_CHECKPOINT_DIR": creation_kwargs["checkpointConfig"]["localPath"],
        "AMZN_BRAKET_OUT_S3_BUCKET": bucket,
        "AMZN_BRAKET_TASK_RESULTS_S3_URI": f"s3://{bucket}/jobs/{name}/tasks",
        "AMZN_BRAKET_JOB_RESULTS_S3_PATH": results_s3_path,
    }
def run_batch(
    self,
    task_specifications: List[Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram]],
    s3_destination_folder: Optional[AwsSession.S3DestinationFolder] = None,
    shots: Optional[int] = None,
    max_parallel: Optional[int] = None,
    max_connections: int = AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT,
    poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
    poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
    *aws_quantum_task_args,
    **aws_quantum_task_kwargs,
) -> AwsQuantumTaskBatch:
    """Executes a batch of tasks in parallel.

    Args:
        task_specifications (List[Union[Circuit, Problem, OpenQasmProgram,
            BlackbirdProgram]]): The circuits or annealing problems to run
            on this device.
        s3_destination_folder (Optional[S3DestinationFolder]): The S3 location
            where task results are stored. Defaults to
            `<default_bucket>/tasks` outside of a Braket Job, and to
            `<Job Bucket>/jobs/<job name>/tasks` inside one.
        shots (Optional[int]): How many times each circuit or annealing
            problem is run. Defaults to 1000 for QPUs and 0 for simulators.
        max_parallel (Optional[int]): Upper bound on tasks run concurrently on
            AWS. Batch creation fails if this exceeds the device's maximum
            allowed concurrent tasks. Default: 10.
        max_connections (int): Size of the Boto3 connection pool, which also
            caps the batch's thread pool workers. Default: 100.
        poll_timeout_seconds (float): Polling timeout for
            `AwsQuantumTask.result()`, in seconds. Default: 5 days.
        poll_interval_seconds (float): Polling interval for results, in
            seconds. Default: 1 second.

    Returns:
        AwsQuantumTaskBatch: A batch containing all of the tasks run.

    See Also:
        `braket.aws.aws_quantum_task_batch.AwsQuantumTaskBatch`
    """
    # Resolve the destination lazily so default_bucket() is only invoked
    # when no explicit folder and no job environment override exist.
    if not s3_destination_folder:
        if "AMZN_BRAKET_TASK_RESULTS_S3_URI" in os.environ:
            s3_destination_folder = AwsSession.parse_s3_uri(
                os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_URI")
            )
        else:
            s3_destination_folder = (self._aws_session.default_bucket(), "tasks")
    batch_shots = shots if shots is not None else self._default_shots
    batch_parallel = (
        max_parallel if max_parallel is not None else self._default_max_parallel
    )
    # A dedicated session copy sized to max_connections keeps the batch's
    # thread pool from exhausting the shared connection pool.
    batch_session = AwsSession.copy_session(
        self._aws_session, max_connections=max_connections
    )
    return AwsQuantumTaskBatch(
        batch_session,
        self._arn,
        task_specifications,
        s3_destination_folder,
        batch_shots,
        *aws_quantum_task_args,
        max_parallel=batch_parallel,
        max_workers=max_connections,
        poll_timeout_seconds=poll_timeout_seconds,
        poll_interval_seconds=poll_interval_seconds,
        **aws_quantum_task_kwargs,
    )
def _download_input_data(
    aws_session: AwsSession,
    download_dir: str,
    input_data: Dict[str, Any],
) -> None:
    """Downloads input data for a job.

    Args:
        aws_session (AwsSession): AwsSession for connecting to AWS Services.
        download_dir (str): The directory path to download to.
        input_data (Dict[str, Any]): One of the input data in the boto3 input
            parameters for running a Braket Job.

    Raises:
        ValueError: If two input data entries share the same channel name.
        RuntimeError: If no S3 objects match the channel's S3 prefix.
    """
    # If s3 prefix is the full name of a directory and all keys are inside
    # that directory, the contents of said directory will be copied into a
    # directory with the same name as the channel. This behavior is the same
    # whether or not s3 prefix ends with a "/". Moreover, if s3 prefix ends
    # with a "/", this is certainly the behavior to expect, since it can only
    # match a directory.
    # If s3 prefix matches any files exactly, or matches as a prefix of any
    # files or directories, then all files and directories matching s3 prefix
    # will be copied into a directory with the same name as the channel.
    channel_name = input_data["channelName"]
    s3_uri_prefix = input_data["dataSource"]["s3DataSource"]["s3Uri"]
    bucket, prefix = AwsSession.parse_s3_uri(s3_uri_prefix)
    s3_keys = aws_session.list_keys(bucket, prefix)
    # When the prefix is itself a directory, strip only that directory;
    # otherwise strip its parent so partially-matched names are preserved.
    top_level = prefix if _is_dir(prefix, s3_keys) else str(Path(prefix).parent)
    try:
        # mkdir without exist_ok doubles as the duplicate-channel check.
        Path(download_dir, channel_name).mkdir()
    except FileExistsError as e:
        # Chain the original error so the mkdir failure stays in the traceback.
        raise ValueError(
            f"Duplicate channel names not allowed for input data: {channel_name}"
        ) from e
    found_item = False
    for s3_key in s3_keys:
        # Keys ending in "/" are directory placeholders, not downloadable
        # objects; skip them before computing any local paths.
        if s3_key.endswith("/"):
            continue
        relative_key = Path(s3_key).relative_to(top_level)
        download_path = Path(download_dir, channel_name, relative_key)
        download_path.parent.mkdir(parents=True, exist_ok=True)
        aws_session.download_from_s3(
            AwsSession.construct_s3_uri(bucket, s3_key), str(download_path)
        )
        found_item = True
    if not found_item:
        raise RuntimeError(f"No data found for channel '{channel_name}'")
def run(
    self,
    task_specification: Union[Circuit, Problem],
    s3_destination_folder: Optional[AwsSession.S3DestinationFolder] = None,
    shots: Optional[int] = None,
    poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
    poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
    *aws_quantum_task_args,
    **aws_quantum_task_kwargs,
) -> AwsQuantumTask:
    """Run a quantum task specification on this device.

    A task can be a circuit or an annealing problem.

    Args:
        task_specification (Union[Circuit, Problem]): Specification of the
            task (circuit or annealing problem) to run on this device.
        s3_destination_folder (AwsSession.S3DestinationFolder, optional): The
            S3 location where the task's results are stored. Defaults to
            `<default_bucket>/tasks` outside of a Braket Job, and to
            `<Job Bucket>/jobs/<job name>/tasks` inside one.
        shots (int, optional): How many times the circuit or annealing problem
            is run. Defaults to 1000 for QPUs and 0 for simulators.
        poll_timeout_seconds (float): Polling timeout for
            `AwsQuantumTask.result()`, in seconds. Default: 5 days.
        poll_interval_seconds (float): Polling interval for
            `AwsQuantumTask.result()`, in seconds. Default: 1 second.
        *aws_quantum_task_args: Variable length positional arguments for
            `braket.aws.aws_quantum_task.AwsQuantumTask.create()`.
        **aws_quantum_task_kwargs: Variable length keyword arguments for
            `braket.aws.aws_quantum_task.AwsQuantumTask.create()`.

    Returns:
        AwsQuantumTask: An AwsQuantumTask that tracks the execution on the
        device.

    Examples:
        >>> circuit = Circuit().h(0).cnot(0, 1)
        >>> device = AwsDevice("arn1")
        >>> device.run(circuit, ("bucket-foo", "key-bar"))

        >>> circuit = Circuit().h(0).cnot(0, 1)
        >>> device = AwsDevice("arn2")
        >>> device.run(task_specification=circuit,
        >>>     s3_destination_folder=("bucket-foo", "key-bar"))

        >>> circuit = Circuit().h(0).cnot(0, 1)
        >>> device = AwsDevice("arn3")
        >>> device.run(task_specification=circuit,
        >>>     s3_destination_folder=("bucket-foo", "key-bar"),
        >>>     disable_qubit_rewiring=True)

        >>> problem = Problem(
        >>>     ProblemType.ISING,
        >>>     linear={1: 3.14},
        >>>     quadratic={(1, 2): 10.08},
        >>> )
        >>> device = AwsDevice("arn4")
        >>> device.run(problem, ("bucket-foo", "key-bar"),
        >>>     device_parameters={
        >>>         "providerLevelParameters": {"postprocessingType": "SAMPLING"}}
        >>> )

    See Also:
        `braket.aws.aws_quantum_task.AwsQuantumTask.create()`
    """
    # Resolve the destination lazily: the job-environment override wins over
    # the account default bucket, and default_bucket() is only invoked when
    # neither an explicit folder nor the override is present.
    if not s3_destination_folder:
        if "AMZN_BRAKET_TASK_RESULTS_S3_URI" in os.environ:
            s3_destination_folder = AwsSession.parse_s3_uri(
                os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_URI")
            )
        else:
            s3_destination_folder = (self._aws_session.default_bucket(), "tasks")
    task_shots = self._default_shots if shots is None else shots
    return AwsQuantumTask.create(
        self._aws_session,
        self._arn,
        task_specification,
        s3_destination_folder,
        task_shots,
        *aws_quantum_task_args,
        poll_timeout_seconds=poll_timeout_seconds,
        poll_interval_seconds=poll_interval_seconds,
        **aws_quantum_task_kwargs,
    )