Exemplo n.º 1
0
def build_sagemaker_args(
    launch_project: LaunchProject,
    account_id: str,
    aws_tag: Optional[str] = None,
) -> Dict[str, Any]:
    """Assemble the keyword arguments passed to SageMaker ``create_training_job``.

    User-supplied keys from ``resource_args["sagemaker"]`` are converted with
    ``to_camel_case``. The launcher-computed values are layered on top and take
    precedence: ``TrainingJobName``, ``AlgorithmSpecification`` and ``RoleArn``.

    Args:
        launch_project: project whose ``resource_args`` and ``run_id`` are read.
        account_id: AWS account id used to expand a bare role name.
        aws_tag: image tag to use as the training image when the user did not
            set one.

    Returns:
        The argument dict, without launcher-only keys (``EcrRepoName``,
        ``region``, ``profile``) and without top-level ``None`` values.

    Raises:
        LaunchError: if sagemaker args are missing, or if OutputDataConfig,
            ResourceConfig or StoppingCondition is absent.
    """
    user_args = launch_project.resource_args.get("sagemaker")
    if user_args is None:
        raise LaunchError(
            "No sagemaker args specified. Specify sagemaker args in resource_args"
        )

    algorithm_spec = user_args.get(
        "AlgorithmSpecification",
        user_args.get("algorithm_specification"),
    )
    # Evaluated in this order on purpose: name, then spec, then role.
    computed = {
        "TrainingJobName": user_args.get("TrainingJobName") or launch_project.run_id,
        "AlgorithmSpecification": merge_aws_tag_with_algorithm_specification(
            algorithm_spec, aws_tag
        ),
        "RoleArn": get_role_arn(user_args, account_id),
    }

    merged = {to_camel_case(key): value for key, value in user_args.items()}
    merged.update(computed)

    required_keys = (
        (
            "OutputDataConfig",
            "Sagemaker launcher requires an OutputDataConfig Sagemaker resource argument",
        ),
        (
            "ResourceConfig",
            "Sagemaker launcher requires a ResourceConfig Sagemaker resource argument",
        ),
        (
            "StoppingCondition",
            "Sagemaker launcher requires a StoppingCondition Sagemaker resource argument",
        ),
    )
    for key, message in required_keys:
        if merged.get(key) is None:
            raise LaunchError(message)

    # These keys only configure the launcher itself; SageMaker must not see them.
    launch_only_keys = {"EcrRepoName", "region", "profile"}
    return {
        key: value
        for key, value in merged.items()
        if key not in launch_only_keys and value is not None
    }
Exemplo n.º 2
0
def aws_ecr_login(region: str, registry: str) -> Optional[str]:
    """Log the local docker client in to an AWS ECR registry.

    Fetches a password with ``aws ecr get-login-password`` and hands it to
    ``docker.login`` with the username ``AWS``.

    Args:
        region: AWS region passed to the ``aws`` CLI.
        registry: registry address to log in to.

    Returns:
        Whatever ``docker.login`` returns. The caller inspects it for
        ``"Login Succeeded"``.

    Raises:
        LaunchError: if the password cannot be obtained or the docker login
            fails. The underlying exception is chained as ``__cause__``.
    """
    pw_command = ["aws", "ecr", "get-login-password", "--region", region]
    try:
        pw = run_shell(pw_command)
    except subprocess.CalledProcessError as e:
        raise LaunchError(
            "Unable to get login password. Please ensure you have AWS credentials configured"
        ) from e
    try:
        return docker.login("AWS", pw, registry)
    except Exception as e:
        raise LaunchError(f"Failed to login to ECR {registry}") from e
Exemplo n.º 3
0
def generate_docker_image(
    api: Api,
    launch_project: LaunchProject,
    image_uri: str,
    entrypoint: EntryPoint,
    docker_args: Dict[str, Any],
    runner_type: str,
) -> str:
    """Generate a Dockerfile for the project, build it, and return the image.

    Writes a metadata file, with the W&B API key sanitized out of the entry
    command and the Dockerfile. It then builds an image tagged ``image_uri``
    from a generated build context and removes that context afterwards.

    Args:
        api: W&B API handle passed to ``generate_dockerfile``.
        launch_project: the project being launched.
        image_uri: tag to give the built image.
        entrypoint: entry point whose command the Dockerfile runs.
        docker_args: docker arguments recorded in the metadata file.
        runner_type: runner name passed to ``generate_dockerfile``.

    Returns:
        The value returned by ``docker.build``.

    Raises:
        LaunchError: if the docker client reports an error while building.
    """
    import shutil  # local import: only this function needs it

    entry_cmd = get_entry_point_command(entrypoint,
                                        launch_project.override_args)[0]
    dockerfile_str = generate_dockerfile(api, launch_project, entry_cmd,
                                         runner_type)
    create_metadata_file(
        launch_project,
        image_uri,
        sanitize_wandb_api_key(entry_cmd),
        docker_args,
        sanitize_wandb_api_key(dockerfile_str),
    )
    build_ctx_path = _create_docker_build_ctx(launch_project, dockerfile_str)
    dockerfile = os.path.join(build_ctx_path, _GENERATED_DOCKERFILE_NAME)
    try:
        image = docker.build(tags=[image_uri],
                             file=dockerfile,
                             context_path=build_ctx_path)
    except DockerError as e:
        raise LaunchError(
            "Error communicating with docker client: {}".format(e)) from e
    finally:
        # build_ctx_path is a directory (the Dockerfile lives inside it), so
        # os.remove() could never delete it and the context leaked on every
        # build. rmtree removes it. Cleanup stays best-effort and now also runs
        # when the build fails.
        try:
            shutil.rmtree(build_ctx_path)
        except OSError:
            _logger.info(
                "Temporary docker context directory %s was not deleted.",
                build_ctx_path)
    return image
Exemplo n.º 4
0
def fetch_wandb_project_run_info(entity: str, project: str, run_name: str,
                                 api: Api) -> Any:
    """Fetch the run-info record for a W&B run.

    If the record lacks ``codePath``, the run's ``wandb-metadata.json`` is
    downloaded (when available). ``codePath`` and ``cudaVersion`` are then
    filled from its ``codePath`` and ``cuda`` fields. A non-None ``args``
    entry is converted with ``util._user_args_to_dict``.

    Raises:
        LaunchError: if the run info cannot be fetched (``CommError``) or is
            None.
    """
    _logger.info("Fetching run info...")
    try:
        run_info = api.get_run_info(entity, project, run_name)
    except CommError:
        run_info = None

    if run_info is None:
        raise LaunchError(
            f"Run info is invalid or doesn't exist for {api.settings('base_url')}/{entity}/{project}/runs/{run_name}"
        )

    if run_info.get("codePath") is None:
        # TODO: the runInfo endpoint does not expose codePath yet, so recover
        # it from wandb-metadata.json when that file exists.
        metadata = api.download_url(project,
                                    "wandb-metadata.json",
                                    run=run_name,
                                    entity=entity)
        if metadata is not None:
            _, response = api.download_file(metadata["url"])
            metadata_json = response.json()
            run_info["codePath"] = metadata_json.get("codePath")
            run_info["cudaVersion"] = metadata_json.get("cuda", None)

    raw_args = run_info.get("args")
    if raw_args is not None:
        run_info["args"] = util._user_args_to_dict(raw_args)
    return run_info
Exemplo n.º 5
0
def get_role_arn(sagemaker_args: Dict[str, Any], account_id: str) -> str:
    """Resolve the IAM role ARN that SageMaker should use for the job.

    The role is read from ``RoleArn`` (or ``role_arn``). A full ARN is returned
    unchanged. A bare role name is expanded to
    ``arn:aws:iam::<account_id>:role/<name>``.

    Args:
        sagemaker_args: the ``sagemaker`` section of resource_args.
        account_id: AWS account id used to expand a bare role name.

    Returns:
        The role ARN.

    Raises:
        LaunchError: if no non-empty string role is configured.
    """
    role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
    if not isinstance(role_arn, str) or not role_arn:
        raise LaunchError(
            "AWS sagemaker requires a string RoleArn, set this by adding a `RoleArn` key to the sagemaker "
            "field of resource_args")
    # Pass through full ARNs from any partition (aws, aws-cn, aws-us-gov, ...).
    # IAM role names cannot contain ':', so "arn:..." is never a bare name.
    if role_arn.startswith("arn:"):
        return role_arn

    return f"arn:aws:iam::{account_id}:role/{role_arn}"
Exemplo n.º 6
0
def get_aws_credentials(sagemaker_args: Dict[str, Any]) -> Tuple[str, str]:
    """Return ``(access_key, secret_key)`` for AWS.

    The keys come from ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``. If
    either is missing and ``~/.aws/credentials`` exists, both are read from that
    file instead. The profile used is ``sagemaker_args["profile"]``, or
    ``default``.

    Raises:
        LaunchError: if the selected profile lacks the keys, or if no
            credentials could be found at all.
    """
    access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    credentials_path = os.path.expanduser("~/.aws/credentials")
    # The parentheses matter. Without them the check parsed as
    # `A or (B and C)`, so a missing access key forced reading a nonexistent
    # file and raised a misleading "unable to get credentials from file" error.
    if (access_key is None or secret_key is None) and os.path.exists(
            credentials_path):
        profile = sagemaker_args.get("profile") or "default"
        config = configparser.ConfigParser()
        config.read(credentials_path)
        try:
            access_key = config.get(profile, "aws_access_key_id")
            secret_key = config.get(profile, "aws_secret_access_key")
        except (configparser.NoOptionError,
                configparser.NoSectionError) as e:
            raise LaunchError(
                "Unable to get aws credentials from ~/.aws/credentials. "
                "Please set aws credentials in environment variables, or "
                "check your credentials in ~/.aws/credentials. Use resource "
                "args to specify the profile using 'profile'") from e

    if access_key is None or secret_key is None:
        raise LaunchError("AWS credentials not found")
    return access_key, secret_key
Exemplo n.º 7
0
def load_backend(
    backend_name: str, api: Api, backend_config: Dict[str, Any]
) -> AbstractRunner:
    """Instantiate the runner registered under ``backend_name``.

    Args:
        backend_name: key into ``WANDB_RUNNERS``.
        api: API handle passed to the runner constructor.
        backend_config: configuration passed to the runner constructor.

    Raises:
        LaunchError: if no runner is registered under ``backend_name``. The
            message lists the available runner names.
    """
    if backend_name not in WANDB_RUNNERS:
        available = ",".join(WANDB_RUNNERS)
        raise LaunchError(
            f"Resource name not among available resources. Available resources: {available} "
        )

    # Static backends
    runner_factory = WANDB_RUNNERS[backend_name]
    return runner_factory(api, backend_config)
Exemplo n.º 8
0
def get_region(sagemaker_args: Dict[str, Any]) -> str:
    """Determine the AWS region for the SageMaker job.

    Resolution order:
      1. ``region`` in the sagemaker resource args;
      2. the ``AWS_DEFAULT_REGION`` environment variable;
      3. the ``region`` option of the selected profile in ``~/.aws/config``.
         The profile is ``sagemaker_args["profile"]``, or ``default``.

    Raises:
        LaunchError: if the config file exists but has no region for the
            profile, if no region can be found, or if the resolved region is
            not a string.
    """
    region = sagemaker_args.get("region")
    if region is None:
        region = os.environ.get("AWS_DEFAULT_REGION")
    config_path = os.path.expanduser("~/.aws/config")
    if region is None and os.path.exists(config_path):
        config = configparser.ConfigParser()
        config.read(config_path)
        profile = sagemaker_args.get("profile") or "default"
        # ~/.aws/config names non-default profiles "[profile NAME]", unlike
        # ~/.aws/credentials which uses "[NAME]". Try the standard form first
        # and keep the bare name as a fallback for existing setups.
        if profile == "default":
            candidate_sections = [profile]
        else:
            candidate_sections = [f"profile {profile}", profile]
        for section in candidate_sections:
            try:
                region = config.get(section, "region")
                break
            except (configparser.NoOptionError, configparser.NoSectionError):
                continue
        else:
            raise LaunchError(
                "Unable to determine default region from ~/.aws/config. "
                "Please specify region in resource args or specify config "
                "section as 'profile'")

    if region is None:
        raise LaunchError(
            "AWS region not specified and ~/.aws/config not found. Configure AWS"
        )
    # A real check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(region, str):
        raise LaunchError(
            f"AWS region must be a string, got {type(region).__name__}")
    return region
Exemplo n.º 9
0
def launch_sagemaker_job(
    launch_project: LaunchProject,
    sagemaker_args: Dict[str, Any],
    sagemaker_client: "boto3.Client",
) -> SagemakerSubmittedRun:
    """Create the SageMaker training job and wrap it in a submitted-run handle.

    ``sagemaker_args`` is forwarded unchanged to ``create_training_job``. The
    job name is ``TrainingJobName`` from those args, falling back to the
    project's ``run_id``.

    Raises:
        LaunchError: if the response carries no ``TrainingJobArn``.
    """
    job_name = sagemaker_args.get("TrainingJobName") or launch_project.run_id
    response = sagemaker_client.create_training_job(**sagemaker_args)

    job_arn = response.get("TrainingJobArn")
    if job_arn is None:
        raise LaunchError("Unable to create training job")

    submitted = SagemakerSubmittedRun(job_name, sagemaker_client)
    wandb.termlog(f"Run job submitted with arn: {job_arn}")

    region = sagemaker_client.meta.region_name
    console_url = (
        f"https://{region}.console.aws.amazon.com/sagemaker/home"
        f"?region={region}#/jobs/{job_name}"
    )
    wandb.termlog(f"See training job status at: {console_url}")
    return submitted
Exemplo n.º 10
0
def merge_aws_tag_with_algorithm_specification(
        algorithm_specification: Optional[Dict[str, Any]],
        aws_tag: Optional[str]) -> Dict[str, Any]:
    """Return an AlgorithmSpecification that always names a training image.

    SageMaker algorithms need a training image and an input mode.
    - If the user gave no specification, a minimal one is built from ``aws_tag``
      with ``TrainingInputMode`` set to ``"File"``.
    - Otherwise the user's specification is used, with ``TrainingImage`` filled
      from ``aws_tag`` when it is unset.

    The caller's dict is never modified; a shallow copy is returned instead.

    Raises:
        LaunchError: if no training image can be determined.
    """
    if algorithm_specification is None:
        spec: Dict[str, Any] = {
            "TrainingImage": aws_tag,
            "TrainingInputMode": "File",
        }
    else:
        # Copy so the user's resource_args are not mutated as a side effect.
        spec = dict(algorithm_specification)
        if spec.get("TrainingImage") is None:
            spec["TrainingImage"] = aws_tag
    if spec["TrainingImage"] is None:
        raise LaunchError("Failed to determine tag for training image")
    return spec
Exemplo n.º 11
0
def pull_docker_image(docker_image: str) -> None:
    """Pull ``docker_image`` by running ``docker pull`` via ``docker.run``.

    Raises:
        LaunchError: if docker reports an error. The ``DockerError`` is chained
            as ``__cause__``.
    """
    try:
        docker.run(["docker", "pull", docker_image])
    except DockerError as e:
        raise LaunchError("Docker server returned error: {}".format(e)) from e
Exemplo n.º 12
0
    def run(self, launch_project: LaunchProject) -> Optional[AbstractRun]:
        """Launch ``launch_project`` as an AWS SageMaker training job.

        If the sagemaker AlgorithmSpecification already names a TrainingImage,
        that image is used as-is. Otherwise the project's image is built (or
        the user's ``docker_image`` is pulled), then tagged and pushed to the
        configured ECR repo, and used as the training image.

        Returns:
            The submitted run. Returns None if acking the run queue item fails.
            The run is waited on when the backend is configured as synchronous.

        Raises:
            LaunchError: on missing sagemaker/ECR configuration, or on ECR
                login or push failure.
        """
        _logger.info("using AWSSagemakerRunner")

        boto3 = get_module(
            "boto3", "AWSSagemakerRunner requires boto3 to be installed")

        validate_docker_installation()
        given_sagemaker_args = launch_project.resource_args.get("sagemaker")
        if given_sagemaker_args is None:
            raise LaunchError(
                "No sagemaker args specified. Specify sagemaker args in resource_args"
            )
        ecr_repo_name = given_sagemaker_args.get(
            "EcrRepoName", given_sagemaker_args.get("ecr_repo_name"))
        if ecr_repo_name is None:
            raise LaunchError(
                "AWS sagemaker requires an ECR Repo to push the container to "
                "set this by adding a `EcrRepoName` key to the sagemaker "
                "field of resource_args")

        region = get_region(given_sagemaker_args)
        access_key, secret_key = get_aws_credentials(given_sagemaker_args)
        client = boto3.client("sts",
                              aws_access_key_id=access_key,
                              aws_secret_access_key=secret_key)
        account_id = client.get_caller_identity()["Account"]

        # Look the spec up the same way build_sagemaker_args does, so the
        # snake_case spelling is honoured. Treat an explicit None as "no spec".
        user_algorithm_spec = given_sagemaker_args.get(
            "AlgorithmSpecification",
            given_sagemaker_args.get("algorithm_specification")) or {}

        # if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
        if user_algorithm_spec.get("TrainingImage") is not None:
            wandb.termwarn(
                "Launching sagemaker job with user provided ECR image, this image will not be able to swap artifacts"
            )
            # Ack the queue item here as well; previously only the build path did.
            if not self._ack_queue_item_if_needed(launch_project):
                return None
            sagemaker_client = boto3.client(
                "sagemaker",
                region_name=region,
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
            )
            sagemaker_args = build_sagemaker_args(launch_project, account_id)
            _logger.info(
                f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
            )
            run = launch_sagemaker_job(launch_project, sagemaker_args,
                                       sagemaker_client)
            if self.backend_config[PROJECT_SYNCHRONOUS]:
                run.wait()
            return run

        _logger.info("Connecting to AWS ECR Client")
        ecr_client = boto3.client(
            "ecr",
            region_name=region,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
        token = ecr_client.get_authorization_token()

        aws_registry = (token["authorizationData"][0]["proxyEndpoint"].replace(
            "https://", "") + f"/{ecr_repo_name}")

        if self.backend_config[PROJECT_DOCKER_ARGS]:
            wandb.termwarn(
                "Docker args are not supported for Sagemaker Resource. Not using docker args"
            )

        entry_point = launch_project.get_single_entry_point()

        if launch_project.docker_image:
            _logger.info("Pulling user provided docker image")
            pull_docker_image(launch_project.docker_image)
            # The pulled image is what gets tagged and pushed below. Without
            # this assignment `image` was unbound and docker.tag raised NameError.
            image = launch_project.docker_image
        else:
            # build our own image
            image_uri = construct_local_image_uri(launch_project)
            _logger.info("Building docker image")
            image = generate_docker_image(self._api, launch_project, image_uri,
                                          entry_point, {}, "sagemaker")

        _logger.info("Logging in to AWS ECR")
        login_resp = aws_ecr_login(region, aws_registry)
        if login_resp is None or "Login Succeeded" not in login_resp:
            raise LaunchError(
                f"Unable to login to ECR, response: {login_resp}")

        aws_tag = f"{aws_registry}:{launch_project.run_id}"
        docker.tag(image, aws_tag)

        wandb.termlog(f"Pushing image {image} to registry {aws_registry}")
        push_resp = docker.push(aws_registry, launch_project.run_id)
        if push_resp is None:
            raise LaunchError("Failed to push image to repository")
        if f"The push refers to repository [{aws_registry}]" not in push_resp:
            raise LaunchError(
                f"Unable to push image to ECR, response: {push_resp}")

        if not self._ack_queue_item_if_needed(launch_project):
            return None
        _logger.info("Connecting to sagemaker client")

        sagemaker_client = boto3.client(
            "sagemaker",
            region_name=region,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

        command_args = get_entry_point_command(entry_point,
                                               launch_project.override_args)
        command_args = list(
            itertools.chain(*[ca.split(" ") for ca in command_args]))
        wandb.termlog("Launching run on sagemaker with entrypoint: {}".format(
            " ".join(command_args)))

        sagemaker_args = build_sagemaker_args(launch_project, account_id,
                                              aws_tag)
        _logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
        run = launch_sagemaker_job(launch_project, sagemaker_args,
                                   sagemaker_client)
        if self.backend_config[PROJECT_SYNCHRONOUS]:
            run.wait()
        return run

    def _ack_queue_item_if_needed(self, launch_project: LaunchProject) -> bool:
        """Ack the run queue item named in ``backend_config``, if there is one.

        Returns False when acking fails (an error is printed), True otherwise.
        """
        if not self.backend_config.get("runQueueItemId"):
            return True
        try:
            self._api.ack_run_queue_item(
                self.backend_config["runQueueItemId"], launch_project.run_id)
        except CommError:
            wandb.termerror(
                "Error acking run queue item. Item lease may have ended or another process may have acked it."
            )
            return False
        return True
Exemplo n.º 13
0
    def __init__(
        self,
        uri: str,
        api: Api,
        launch_spec: Dict[str, Any],
        target_entity: str,
        target_project: str,
        name: Optional[str],
        docker_config: Dict[str, Any],
        git_info: Dict[str, str],
        overrides: Dict[str, Any],
        resource: str,
        resource_args: Dict[str, Any],
        cuda: Optional[bool],
    ):
        """Record launch settings and decide where the project's code will live.

        A bare wandb URI is prefixed with the API's ``base_url``.
        - For wandb and git URIs, ``project_dir`` is a fresh temp dir.
        - Otherwise the URI is treated as an existing local path and used
          directly.
        ``launch_spec["resource_args"]``, when truthy, replaces ``resource_args``.

        Raises:
            LaunchError: if the URI is treated as local but does not exist.
        """
        if utils.is_bare_wandb_uri(uri):
            uri = api.settings("base_url") + uri
            _logger.info(f"Updating uri with base uri: {uri}")
        self.uri = uri
        self.api = api
        self.launch_spec = launch_spec
        self.target_entity = target_entity
        self.target_project = target_project
        self.name = name
        self.build_image: bool = docker_config.get("build_image", False)
        self.python_version: Optional[str] = docker_config.get(
            "python_version")
        self.cuda_version: Optional[str] = docker_config.get("cuda_version")
        self._base_image: Optional[str] = docker_config.get("base_image")
        self.docker_image: Optional[str] = docker_config.get("docker_image")
        uid = RESOURCE_UID_MAP.get(resource, 1000)
        # An explicit user_id always wins below, so only ask docker for the
        # base image's uid when it will actually be used. This avoids a docker
        # call that could fail for no benefit.
        if self._base_image and "user_id" not in docker_config:
            uid = docker.get_image_uid(self._base_image)
            _logger.info(f"Retrieved base image uid {uid}")
        self.docker_user_id: int = docker_config.get("user_id", uid)
        self.git_version: Optional[str] = git_info.get("version")
        self.git_repo: Optional[str] = git_info.get("repo")
        self.override_args: Dict[str, Any] = overrides.get("args", {})
        self.override_config: Dict[str, Any] = overrides.get("run_config", {})
        self.resource = resource
        self.resource_args = resource_args
        self.deps_type: Optional[str] = None
        self.cuda = cuda
        self._runtime: Optional[str] = None
        self.run_id = generate_id()
        # todo: keep multiple entrypoint support?
        self._entry_points: Dict[str, EntryPoint] = {}
        if "entry_point" in overrides:
            _logger.info("Adding override entry point")
            self.add_entry_point(overrides["entry_point"])
        if utils._is_wandb_uri(self.uri):
            _logger.info(f"URI {self.uri} indicates a wandb uri")
            self.source = LaunchSource.WANDB
            self.project_dir = tempfile.mkdtemp()
        elif utils._is_git_uri(self.uri):
            _logger.info(f"URI {self.uri} indicates a git uri")
            self.source = LaunchSource.GIT
            self.project_dir = tempfile.mkdtemp()
        else:
            _logger.info(f"URI {self.uri} indicates a local uri")
            # assume local
            if not os.path.exists(self.uri):
                raise LaunchError(
                    "Assumed URI supplied is a local path but path is not valid"
                )
            self.source = LaunchSource.LOCAL
            self.project_dir = self.uri
        # resource_args embedded in the launch spec override the argument.
        if launch_spec.get("resource_args"):
            self.resource_args = launch_spec["resource_args"]

        self.aux_dir = tempfile.mkdtemp()
        self.clear_parameter_run_config_collisions()
Exemplo n.º 14
0
    def _fetch_project_local(self, internal_api: Api) -> None:
        """Fetch a non-local project (wandb run or git repo) into ``self.project_dir``.

        For a wandb run:
        - fetch its run info, code (code artifact, or git repo + diff +
          entrypoint), frozen requirements and python version;
        - adjust ``cuda``/``cuda_version`` from the original run when they were
          not specified;
        - merge the run's args into ``override_args``.
        For a git URI, clone it and default the entry point to ``main.py``.
        Returns nothing; results are stored on ``self``.

        Raises:
            ExecutionError: if a wandb run has neither a code artifact nor git info.
            LaunchError: if the entrypoint is missing and cannot be downloaded.
        """
        assert self.source != LaunchSource.LOCAL
        _logger.info("Fetching project locally...")
        if utils._is_wandb_uri(self.uri):
            source_entity, source_project, source_run_name = utils.parse_wandb_uri(
                self.uri)
            run_info = utils.fetch_wandb_project_run_info(
                source_entity, source_project, source_run_name, internal_api)
            # "codePath" can be present with a None value, in which case
            # .get(key, default) would return None and crash later. Fall back
            # to "program" on any falsy value, and look it up only when needed.
            entry_point = run_info.get("codePath") or run_info["program"]

            if run_info.get("cudaVersion"):
                original_cuda_version = ".".join(
                    run_info["cudaVersion"].split(".")[:2])

                if self.cuda is None:
                    # only set cuda on by default if cuda is None (unspecified), not False (user specifically requested cpu image)
                    wandb.termlog(
                        "Original wandb run {} was run with cuda version {}. Enabling cuda builds by default; to build on a CPU-only image, run again with --cuda=False"
                        .format(source_run_name, original_cuda_version))
                    self.cuda_version = original_cuda_version
                    self.cuda = True
                if (self.cuda and self.cuda_version
                        and self.cuda_version != original_cuda_version):
                    wandb.termlog(
                        "Specified cuda version {} differs from original cuda version {}. Running with specified version {}"
                        .format(self.cuda_version, original_cuda_version,
                                self.cuda_version))

            downloaded_code_artifact = utils.check_and_download_code_artifacts(
                source_entity,
                source_project,
                source_run_name,
                internal_api,
                self.project_dir,
            )

            if downloaded_code_artifact:
                self.build_image = True
            else:
                git_info = run_info.get("git")
                if not git_info:
                    raise ExecutionError(
                        "Reproducing a run requires either an associated git repo or a code artifact logged with `run.log_code()`"
                    )
                utils._fetch_git_repo(
                    self.project_dir,
                    git_info["remote"],
                    git_info["commit"],
                )
                patch = utils.fetch_project_diff(source_entity, source_project,
                                                 source_run_name, internal_api)

                if patch:
                    utils.apply_patch(patch, self.project_dir)
                # For cases where the entry point wasn't checked into git
                if not os.path.exists(
                        os.path.join(self.project_dir, entry_point)):
                    downloaded_entrypoint = utils.download_entry_point(
                        source_entity,
                        source_project,
                        source_run_name,
                        internal_api,
                        entry_point,
                        self.project_dir,
                    )
                    if not downloaded_entrypoint:
                        raise LaunchError(
                            f"Entrypoint: {entry_point} does not exist, "
                            "and could not be downloaded. Please specify the entrypoint for this run."
                        )
                    # if the entrypoint is downloaded and inserted into the project dir
                    # need to rebuild image with new code
                    self.build_image = True

            if entry_point.endswith("ipynb"):
                entry_point = utils.convert_jupyter_notebook_to_script(
                    entry_point, self.project_dir)

            # Download any frozen requirements
            utils.download_wandb_python_deps(
                source_entity,
                source_project,
                source_run_name,
                internal_api,
                self.project_dir,
            )

            # Specify the python runtime for jupyter2docker
            self.python_version = run_info.get("python", "3")

            if not self._entry_points:
                self.add_entry_point(entry_point)
            self.override_args = utils.merge_parameters(
                self.override_args, run_info["args"])
        else:
            assert utils._GIT_URI_REGEX.match(
                self.uri), ("Non-wandb URI %s should be a Git URI" % self.uri)

            if not self._entry_points:
                wandb.termlog(
                    "Entry point for repo not specified, defaulting to main.py"
                )
                self.add_entry_point("main.py")
            utils._fetch_git_repo(self.project_dir, self.uri, self.git_version)