Example #1
0
def _get_env_default_vars(aws_session: AwsSession,
                          **creation_kwargs) -> Dict[str, str]:
    """Collect the environment variables that need no extra logic to compute.

    Every value is read straight from the session or from the job-creation arguments;
    nothing here decides conditionally whether a variable should be present.

    Args:
        aws_session (AwsSession): Session whose ``region`` becomes ``AWS_DEFAULT_REGION``.
        **creation_kwargs: Job-creation arguments. The entries ``jobName``,
            ``outputDataConfig["s3Path"]``, ``deviceConfig["device"]`` and
            ``checkpointConfig["localPath"]`` are read; a missing entry raises ``KeyError``.

    Returns:
        (Dict[str, str]): The key/value pairs to add as environment variables to the
        running container.
    """
    name = creation_kwargs["jobName"]
    output_s3_uri = creation_kwargs["outputDataConfig"]["s3Path"]
    out_bucket, out_prefix = AwsSession.parse_s3_uri(output_s3_uri)
    return dict(
        AWS_DEFAULT_REGION=aws_session.region,
        AMZN_BRAKET_JOB_NAME=name,
        AMZN_BRAKET_DEVICE_ARN=creation_kwargs["deviceConfig"]["device"],
        # Constant value, not derived from the creation arguments.
        AMZN_BRAKET_JOB_RESULTS_DIR="/opt/braket/model",
        AMZN_BRAKET_CHECKPOINT_DIR=creation_kwargs["checkpointConfig"]["localPath"],
        AMZN_BRAKET_OUT_S3_BUCKET=out_bucket,
        # Built from the bucket only: s3://<bucket>/jobs/<job name>/tasks
        # (the output key prefix is deliberately not part of this URI).
        AMZN_BRAKET_TASK_RESULTS_S3_URI=f"s3://{out_bucket}/jobs/{name}/tasks",
        # as_posix() already returns a str and keeps "/" separators on every host OS.
        AMZN_BRAKET_JOB_RESULTS_S3_PATH=Path(out_prefix, name, "output").as_posix(),
    )
    def run_batch(
        self,
        task_specifications: List[Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram]],
        s3_destination_folder: Optional[AwsSession.S3DestinationFolder] = None,
        shots: Optional[int] = None,
        max_parallel: Optional[int] = None,
        max_connections: int = AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT,
        poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
        *aws_quantum_task_args,
        **aws_quantum_task_kwargs,
    ) -> AwsQuantumTaskBatch:
        """Execute a batch of tasks on this device in parallel.

        Args:
            task_specifications (List[Union[Circuit, Problem, OpenQasmProgram, BlackbirdProgram]]):
                The circuits, annealing problems or programs to run on the device.
            s3_destination_folder (Optional[S3DestinationFolder]): The S3 location for the
                tasks' results. When not given (or falsy), the value of the
                ``AMZN_BRAKET_TASK_RESULTS_S3_URI`` environment variable is used if set
                (`<Job Bucket>/jobs/<job name>/tasks` inside a Braket Job), otherwise
                `<default_bucket>/tasks`.
            shots (Optional[int]): Number of times to run each circuit or annealing problem.
                When None, ``self._default_shots`` is used (1000 for QPUs and 0 for simulators).
            max_parallel (Optional[int]): Maximum number of tasks to run on AWS in parallel.
                Batch creation fails if this exceeds the device's allowed concurrent tasks.
                When None, ``self._default_max_parallel`` is used (documented as 10).
            max_connections (int): Size of the Boto3 connection pool of the copied session,
                also used as the batch's thread-pool worker count. Default: 100
            poll_timeout_seconds (float): Polling timeout for `AwsQuantumTask.result()`,
                in seconds. Default: 5 days.
            poll_interval_seconds (float): Polling interval for results, in seconds.
                Default: 1 second.
            *aws_quantum_task_args: Extra positional arguments forwarded to
                `AwsQuantumTaskBatch`.
            **aws_quantum_task_kwargs: Extra keyword arguments forwarded to
                `AwsQuantumTaskBatch`.

        Returns:
            AwsQuantumTaskBatch: A batch containing all of the tasks run.

        See Also:
            `braket.aws.aws_quantum_task_batch.AwsQuantumTaskBatch`
        """
        # The batch gets its own session copy so its connection pool can be sized
        # to ``max_connections``.
        batch_session = AwsSession.copy_session(self._aws_session, max_connections=max_connections)
        device_arn = self._arn

        # Destination precedence: explicit argument -> job env var -> default bucket.
        destination = s3_destination_folder
        if not destination and "AMZN_BRAKET_TASK_RESULTS_S3_URI" in os.environ:
            destination = AwsSession.parse_s3_uri(
                os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_URI")
            )
        if not destination:
            destination = (self._aws_session.default_bucket(), "tasks")

        resolved_shots = self._default_shots if shots is None else shots
        resolved_max_parallel = (
            self._default_max_parallel if max_parallel is None else max_parallel
        )

        # NOTE(review): extra positional args are bound immediately after ``shots``; if
        # the next positional parameter of AwsQuantumTaskBatch is ``max_parallel``
        # (passed here by keyword), any positional extra would raise "got multiple
        # values" — confirm against AwsQuantumTaskBatch's signature.
        return AwsQuantumTaskBatch(
            batch_session,
            device_arn,
            task_specifications,
            destination,
            resolved_shots,
            *aws_quantum_task_args,
            max_parallel=resolved_max_parallel,
            max_workers=max_connections,
            poll_timeout_seconds=poll_timeout_seconds,
            poll_interval_seconds=poll_interval_seconds,
            **aws_quantum_task_kwargs,
        )
Example #3
0
def _download_input_data(
    aws_session: AwsSession,
    download_dir: str,
    input_data: Dict[str, Any],
) -> None:
    """Downloads one input-data channel for a job into ``<download_dir>/<channelName>``.

    Args:
        aws_session (AwsSession): AwsSession for connecting to AWS Services.
        download_dir (str): The directory path to download to. The channel directory is
            created directly inside it (parents are not created).
        input_data (Dict[str, Any]): One of the input data in the boto3 input parameters for
            running a Braket Job. Reads ``channelName`` and
            ``dataSource["s3DataSource"]["s3Uri"]``.

    Raises:
        ValueError: If ``<download_dir>/<channelName>`` already exists, i.e. the channel
            name is duplicated.
        RuntimeError: If no object (key not ending in "/") was found under the S3 prefix.
    """
    # If s3 prefix is the full name of a directory and all keys are inside
    # that directory, the contents of said directory will be copied into a
    # directory with the same name as the channel. This behavior is the same
    # whether or not s3 prefix ends with a "/". Moreover, if s3 prefix ends
    # with a "/", this is certainly the behavior to expect, since it can only
    # match a directory.
    # If s3 prefix matches any files exactly, or matches as a prefix of any
    # files or directories, then all files and directories matching s3 prefix
    # will be copied into a directory with the same name as the channel.
    channel_name = input_data["channelName"]
    s3_uri_prefix = input_data["dataSource"]["s3DataSource"]["s3Uri"]
    bucket, prefix = AwsSession.parse_s3_uri(s3_uri_prefix)
    s3_keys = aws_session.list_keys(bucket, prefix)
    # Local paths are made relative to either the prefix itself (directory match)
    # or the prefix's parent (file / partial-name match).
    top_level = prefix if _is_dir(prefix, s3_keys) else str(Path(prefix).parent)

    channel_dir = Path(download_dir, channel_name)
    try:
        channel_dir.mkdir()
    except FileExistsError as err:
        # Chain the original error so the filesystem cause stays visible in tracebacks.
        raise ValueError(
            f"Duplicate channel names not allowed for input data: {channel_name}"
        ) from err

    found_item = False
    for s3_key in s3_keys:
        if s3_key.endswith("/"):
            # Keys ending in "/" are directory markers; there is nothing to download.
            continue
        download_path = channel_dir / Path(s3_key).relative_to(top_level)
        download_path.parent.mkdir(parents=True, exist_ok=True)
        aws_session.download_from_s3(
            AwsSession.construct_s3_uri(bucket, s3_key),
            str(download_path))
        found_item = True
    if not found_item:
        raise RuntimeError(f"No data found for channel '{channel_name}'")
Example #4
0
    def run(
        self,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: Optional[AwsSession.S3DestinationFolder] = None,
        shots: Optional[int] = None,
        poll_timeout_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
        *aws_quantum_task_args,
        **aws_quantum_task_kwargs,
    ) -> AwsQuantumTask:
        """Run a quantum task (a circuit or an annealing problem) on this device.

        Args:
            task_specification (Union[Circuit, Problem]): The circuit or annealing problem
                to run on the device.
            s3_destination_folder (AwsSession.S3DestinationFolder, optional): The S3 location
                for the task's results. When not given (or falsy), the value of the
                ``AMZN_BRAKET_TASK_RESULTS_S3_URI`` environment variable is used if set
                (`<Job Bucket>/jobs/<job name>/tasks` inside a Braket Job), otherwise
                `<default_bucket>/tasks`.
            shots (int, optional): Number of times to run the circuit or annealing problem.
                When None, ``self._default_shots`` is used (1000 for QPUs and 0 for simulators).
            poll_timeout_seconds (float): Polling timeout for `AwsQuantumTask.result()`,
                in seconds. Default: 5 days.
            poll_interval_seconds (float): Polling interval for `AwsQuantumTask.result()`,
                in seconds. Default: 1 second.
            *aws_quantum_task_args: Extra positional arguments for
                `braket.aws.aws_quantum_task.AwsQuantumTask.create()`.
            **aws_quantum_task_kwargs: Extra keyword arguments for
                `braket.aws.aws_quantum_task.AwsQuantumTask.create()`.

        Returns:
            AwsQuantumTask: An AwsQuantumTask that tracks the execution on the device.

        Examples:
            >>> circuit = Circuit().h(0).cnot(0, 1)
            >>> device = AwsDevice("arn1")
            >>> device.run(circuit, ("bucket-foo", "key-bar"))

            >>> circuit = Circuit().h(0).cnot(0, 1)
            >>> device = AwsDevice("arn2")
            >>> device.run(task_specification=circuit,
            ...     s3_destination_folder=("bucket-foo", "key-bar"))

            >>> circuit = Circuit().h(0).cnot(0, 1)
            >>> device = AwsDevice("arn3")
            >>> device.run(task_specification=circuit,
            ...     s3_destination_folder=("bucket-foo", "key-bar"), disable_qubit_rewiring=True)

            >>> problem = Problem(
            ...     ProblemType.ISING,
            ...     linear={1: 3.14},
            ...     quadratic={(1, 2): 10.08},
            ... )
            >>> device = AwsDevice("arn4")
            >>> device.run(problem, ("bucket-foo", "key-bar"),
            ...     device_parameters={
            ...         "providerLevelParameters": {"postprocessingType": "SAMPLING"}}
            ... )

        See Also:
            `braket.aws.aws_quantum_task.AwsQuantumTask.create()`
        """
        session = self._aws_session
        device_arn = self._arn

        # Destination precedence: explicit argument -> job env var -> default bucket.
        destination = s3_destination_folder
        if not destination and "AMZN_BRAKET_TASK_RESULTS_S3_URI" in os.environ:
            destination = AwsSession.parse_s3_uri(
                os.environ.get("AMZN_BRAKET_TASK_RESULTS_S3_URI")
            )
        if not destination:
            destination = (session.default_bucket(), "tasks")

        resolved_shots = self._default_shots if shots is None else shots

        # Extra positional args are bound right after ``shots`` in create()'s signature.
        return AwsQuantumTask.create(
            session,
            device_arn,
            task_specification,
            destination,
            resolved_shots,
            *aws_quantum_task_args,
            poll_timeout_seconds=poll_timeout_seconds,
            poll_interval_seconds=poll_interval_seconds,
            **aws_quantum_task_kwargs,
        )