def _process_channel(
    location: str, job_name: str, aws_session: AwsSession, channel_name: str
) -> S3DataSourceConfig:
    """Convert a location to an S3DataSourceConfig, uploading local data to S3 if necessary.

    Args:
        location (str): Local prefix or S3 prefix.
        job_name (str): Job name.
        aws_session (AwsSession): AwsSession to be used for uploading local data.
        channel_name (str): Name of the channel.

    Returns:
        S3DataSourceConfig: S3DataSourceConfig for the channel.
    """
    if AwsSession.is_s3_uri(location):
        return S3DataSourceConfig(location)
    else:
        # a local prefix "path/to/prefix" is mapped to
        # s3://bucket/jobs/job-name/data/channel-name/prefix
        location_name = Path(location).name
        s3_prefix = AwsSession.construct_s3_uri(
            aws_session.default_bucket(), "jobs", job_name, "data", channel_name, location_name
        )
        aws_session.upload_local_data(location, s3_prefix)
        return S3DataSourceConfig(s3_prefix)
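# Illustrative sketch (an addition, not part of the original module): how
# `_process_channel` treats the two kinds of locations. The bucket, job, and path
# names below are hypothetical; `session` stands for an existing AwsSession.
#
#     # An S3 URI is wrapped as-is, with no upload:
#     config = _process_channel("s3://my-bucket/train", "my-job", session, "train")
#
#     # A local prefix is uploaded under the account's default bucket first:
#     # ./data/train -> s3://<default-bucket>/jobs/my-job/data/train/train
#     config = _process_channel("data/train", "my-job", session, "train")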
def prepare_quantum_job(
    device: str,
    source_module: str,
    entry_point: str = None,
    image_uri: str = None,
    job_name: str = None,
    code_location: str = None,
    role_arn: str = None,
    hyperparameters: Dict[str, Any] = None,
    input_data: Union[str, Dict, S3DataSourceConfig] = None,
    instance_config: InstanceConfig = None,
    distribution: str = None,
    stopping_condition: StoppingCondition = None,
    output_data_config: OutputDataConfig = None,
    copy_checkpoints_from_job: str = None,
    checkpoint_config: CheckpointConfig = None,
    aws_session: AwsSession = None,
    tags: Dict[str, str] = None,
) -> Dict:
    """Creates a job by invoking the Braket CreateJob API.

    Args:
        device (str): ARN for the AWS device which is primarily accessed for the
            execution of this job.

        source_module (str): Path (absolute, relative, or an S3 URI) to a python module
            to be tarred and uploaded. If `source_module` is an S3 URI, it must point to
            a tar.gz file. Otherwise, source_module may be a file or directory.

        entry_point (str): A str that specifies the entry point of the job, relative to
            the source module. The entry point must be in the format
            `importable.module` or `importable.module:callable`. For example,
            `source_module.submodule:start_here` indicates the `start_here` function
            contained in `source_module.submodule`. If source_module is an S3 URI,
            entry point must be given. Default: source_module's name.

        image_uri (str): A str that specifies the ECR image to use for executing the
            job. The `image_uris.retrieve_image()` function may be used to retrieve the
            ECR image URIs for the containers supported by Braket.
            Default: `<Braket base image_uri>`.

        job_name (str): A str that specifies the name with which the job is created.
            Default: f'{image_uri_type}-{timestamp}'.

        code_location (str): The S3 prefix URI where custom code will be uploaded.
            Default: f's3://{default_bucket_name}/jobs/{job_name}/script'.

        role_arn (str): A str providing the IAM role ARN used to execute the script.
            Default: IAM role returned by AwsSession's `get_default_jobs_role()`.

        hyperparameters (Dict[str, Any]): Hyperparameters accessible to the job.
            The hyperparameters are made accessible as a Dict[str, str] to the job.
            For convenience, this accepts other types for keys and values, but `str()`
            is called to convert them before being passed on. Default: None.

        input_data (Union[str, Dict, S3DataSourceConfig]): Information about the
            training data. A dictionary maps channel names to local paths or S3 URIs.
            Contents found at any local paths will be uploaded to S3 at
            f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a
            single local path, S3 URI, or S3DataSourceConfig is provided, it will be
            given the default channel name "input". Default: {}.

        instance_config (InstanceConfig): Configuration of the instances to be used to
            execute the job. Default: InstanceConfig(instanceType='ml.m5.large',
            instanceCount=1, volumeSizeInGB=30, volumeKmsKey=None).

        distribution (str): A str that specifies how the job should be distributed. If
            set to "data_parallel", the hyperparameters for the job will be set to use
            data parallelism features for PyTorch or TensorFlow. Default: None.

        stopping_condition (StoppingCondition): The maximum length of time, in seconds,
            and the maximum number of tasks that a job can run before being forcefully
            stopped. Default: StoppingCondition(maxRuntimeInSeconds=5 * 24 * 60 * 60).

        output_data_config (OutputDataConfig): Specifies the location for the output of
            the job. Default: OutputDataConfig(
            s3Path=f's3://{default_bucket_name}/jobs/{job_name}/data', kmsKeyId=None).
        copy_checkpoints_from_job (str): A str that specifies the job ARN whose
            checkpoint you want to use in the current job. Specifying this value will
            copy over the checkpoint data from `copy_checkpoints_from_job`'s
            checkpoint_config s3Uri to the current job's checkpoint_config s3Uri,
            making it available at checkpoint_config.localPath during the job
            execution. Default: None.

        checkpoint_config (CheckpointConfig): Configuration that specifies the location
            where checkpoint data is stored.
            Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
            s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').

        aws_session (AwsSession): AwsSession for connecting to AWS Services.
            Default: AwsSession().

        tags (Dict[str, str]): Dict specifying the key-value pairs for tagging this
            job. Default: {}.

    Returns:
        Dict: The keyword arguments for the Braket CreateJob request.

    Raises:
        ValueError: Raises ValueError if the parameters are not valid.
    """
    param_datatype_map = {
        "instance_config": (instance_config, InstanceConfig),
        "stopping_condition": (stopping_condition, StoppingCondition),
        "output_data_config": (output_data_config, OutputDataConfig),
        "checkpoint_config": (checkpoint_config, CheckpointConfig),
    }

    _validate_params(param_datatype_map)
    aws_session = aws_session or AwsSession()
    device_config = DeviceConfig(device)
    job_name = job_name or _generate_default_job_name(image_uri)
    role_arn = role_arn or os.getenv(
        "BRAKET_JOBS_ROLE_ARN", aws_session.get_default_jobs_role()
    )
    hyperparameters = hyperparameters or {}
    hyperparameters = {str(key): str(value) for key, value in hyperparameters.items()}
    input_data = input_data or {}
    tags = tags or {}
    default_bucket = aws_session.default_bucket()
    input_data_list = _process_input_data(input_data, job_name, aws_session)
    instance_config = instance_config or InstanceConfig()
    stopping_condition = stopping_condition or StoppingCondition()
    output_data_config = output_data_config or OutputDataConfig()
    checkpoint_config = checkpoint_config or CheckpointConfig()
    code_location = code_location or AwsSession.construct_s3_uri(
        default_bucket,
        "jobs",
        job_name,
        "script",
    )

    if AwsSession.is_s3_uri(source_module):
        _process_s3_source_module(source_module, entry_point, aws_session, code_location)
    else:
        # if entry_point is None, it will be set to its default here
        entry_point = _process_local_source_module(
            source_module, entry_point, aws_session, code_location
        )
    algorithm_specification = {
        "scriptModeConfig": {
            "entryPoint": entry_point,
            "s3Uri": f"{code_location}/source.tar.gz",
            "compressionType": "GZIP",
        }
    }
    if image_uri:
        algorithm_specification["containerImage"] = {"uri": image_uri}

    if not output_data_config.s3Path:
        output_data_config.s3Path = AwsSession.construct_s3_uri(
            default_bucket,
            "jobs",
            job_name,
            "data",
        )

    if not checkpoint_config.s3Uri:
        checkpoint_config.s3Uri = AwsSession.construct_s3_uri(
            default_bucket,
            "jobs",
            job_name,
            "checkpoints",
        )

    if copy_checkpoints_from_job:
        checkpoints_to_copy = aws_session.get_job(copy_checkpoints_from_job)[
            "checkpointConfig"
        ]["s3Uri"]
        aws_session.copy_s3_directory(checkpoints_to_copy, checkpoint_config.s3Uri)

    if distribution == "data_parallel":
        distributed_hyperparams = {
            "sagemaker_distributed_dataparallel_enabled": "true",
            "sagemaker_instance_type": instance_config.instanceType,
        }
        hyperparameters.update(distributed_hyperparams)

    create_job_kwargs = {
        "jobName": job_name,
        "roleArn": role_arn,
        "algorithmSpecification": algorithm_specification,
        "inputDataConfig": input_data_list,
        "instanceConfig": asdict(instance_config),
        "outputDataConfig": asdict(output_data_config),
        "checkpointConfig": asdict(checkpoint_config),
        "deviceConfig": asdict(device_config),
        "hyperParameters": hyperparameters,
        "stoppingCondition": asdict(stopping_condition),
        "tags": tags,
    }

    return create_job_kwargs
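# Usage sketch (an addition, not part of the original module): build the CreateJob
# keyword arguments for a local source module, then hand them to the Braket service
# via AwsSession's `create_job`. The device ARN, script names, and data paths below
# are placeholders.
#
#     session = AwsSession()
#     create_job_kwargs = prepare_quantum_job(
#         device="arn:aws:braket:::device/quantum-simulator/amazon/sv1",
#         source_module="algorithm_script.py",
#         entry_point="algorithm_script:start_here",
#         hyperparameters={"shots": 100},
#         input_data={"train": "data/train"},  # local prefix, uploaded by _process_channel
#         aws_session=session,
#     )
#     session.create_job(**create_job_kwargs)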