def upload(src, dst, gz, session: sagemaker.Session, root='.'):
    """Upload a local file or directory to S3.

    Args:
        src: Local path to upload; must exist and be a file or directory.
        dst: Destination, resolved through ``cli_argument`` first. Either a
            full ``s3://bucket/key`` URI, or a key that is placed in the
            session's default bucket (one leading ``/`` is stripped).
        gz: If true, ``src`` must be a directory. It is packed into a
            gzip-compressed tarball that is uploaded as the single object
            ``dst``; ``dst`` must then end in ``.tar.gz`` or ``.tgz``.
        session: SageMaker session whose boto session supplies the S3 client.
        root: Archive name given to ``src`` inside the tarball (only used
            when ``gz`` is true).

    Raises:
        click.UsageError: If ``src`` does not exist or is neither a file nor a
            directory, if ``gz`` is given for a file source, or if ``gz`` is
            given and ``dst`` lacks a ``.tar.gz``/``.tgz`` suffix.
    """
    dst = cli_argument(dst, session=session)
    if not os.path.exists(src):
        raise click.UsageError("Source must exist")
    if not dst.startswith('s3://'):
        # Bare keys are resolved against the session's default bucket.
        if dst.startswith('/'):
            dst = dst[1:]
        bucket = session.default_bucket()
        dst = 's3://{}/{}'.format(bucket, dst)
    url = urlparse(dst)
    # Internal invariant: dst was normalized to an s3:// URI above.
    assert url.scheme == 's3'
    bucket = url.netloc
    key = url.path
    if key.startswith('/'):
        key = key[1:]
    if os.path.isfile(src):
        if gz:
            raise click.UsageError(
                "Option gz is only valid for source directories")
        s3 = session.boto_session.client('s3')
        s3.upload_file(src, bucket, key)
    elif os.path.isdir(src):
        if gz:
            # Exactly two accepted suffixes; no empty alternative, so a
            # destination ending in a bare "." is rejected.
            if not re.match(r".*\.(tar\.gz|tgz)$", dst, re.IGNORECASE):
                raise click.UsageError(
                    "Destination should end in .tar.gz or tgz")
            file_name = os.path.basename(dst)
            with _tmpdir() as tmp:
                archive_path = os.path.join(tmp, file_name)
                with tarfile.open(archive_path, 'w:gz') as arc:
                    arc.add(name=src, arcname=root, recursive=True)
                s3 = session.boto_session.client('s3')
                s3.upload_file(archive_path, bucket, key)
        else:
            S3Uploader.upload(local_path=src,
                              desired_s3_uri=dst,
                              sagemaker_session=session)
    else:
        raise click.UsageError("Source must be file or directory")
# Example #2
def model_create(job,
                 model_artifact,
                 name,
                 session: sagemaker.Session,
                 inference_image,
                 inference_image_path,
                 inference_image_accounts,
                 role,
                 force,
                 multimodel=False,
                 accelerator_type=None):
    """Create a SageMaker model from a training job or a model artifact.

    Exactly one of ``job`` or ``model_artifact`` must be supplied.

    Args:
        job: Training job name. Its ``ModelArtifacts.S3ModelArtifacts`` field
            (via ``training_describe``) becomes the model data.
        model_artifact: S3 URI of the model data, or a key that is placed in
            the session's default bucket (one leading ``/`` is stripped).
        name: Model name. Defaults to ``job`` when a job is given; required
            when ``model_artifact`` is given.
        session: SageMaker session used for all AWS calls.
        inference_image: Image tag passed to ``Image``.
        inference_image_path: Image path passed to ``Image``.
        inference_image_accounts: Image accounts passed to ``Image``.
        role: Role name passed to ``ensure_inference_role``.
        force: If true, an existing model with the same name is deleted
            first; otherwise an existing model is an error.
        multimodel: Create the container in ``MultiModel`` mode instead of
            ``SingleModel``.
        accelerator_type: Accepted for interface compatibility; not used
            when creating the model.

    Raises:
        click.UsageError: On invalid argument combinations, or when the model
            already exists and ``force`` is not set.
    """
    job = cli_argument(job, session=session)
    name = cli_argument(name, session=session)
    model_artifact = cli_argument(model_artifact, session=session)

    # Validate usage before any AWS side effects (image resolution, IAM role).
    if bool(job) == bool(model_artifact):
        raise click.UsageError('Specify one of job_name or model_artifact')
    if model_artifact and not name:
        raise click.UsageError('name is required if job is not provided')

    image_config = Image(tag=inference_image,
                         path=inference_image_path,
                         accounts=inference_image_accounts)
    image_uri = ecr_ensure_image(image=image_config,
                                 session=session.boto_session)
    iam = session.boto_session.client('iam')
    client = session.boto_session.client('sagemaker')
    role = ensure_inference_role(iam=iam, role_name=role)

    if job:
        model_artifact = training_describe(
            job_name=job,
            field='ModelArtifacts.S3ModelArtifacts',
            session=session)
        if not name:
            name = job
        print("Creating model [{}] from job [{}] artifact [{}]".format(
            name, job, model_artifact))
    else:
        # Bare keys are resolved against the session's default bucket.
        if not model_artifact.startswith('s3://'):
            if model_artifact.startswith('/'):
                model_artifact = model_artifact[1:]
            bucket = session.default_bucket()
            model_artifact = 's3://{}/{}'.format(bucket, model_artifact)
        print("Creating model [{}] from artifact [{}]".format(
            name, model_artifact))

    if model_exists(name=name, client=client):
        if not force:
            raise click.UsageError('Specify force if overwriting model')
        print("Deleting existing model")
        model_delete(name=name, client=client)

    model = sagemaker.Model(
        image_uri=image_uri,
        model_data=model_artifact,
        role=role,
        predictor_cls=None,
        env=None,
        name=name,
        sagemaker_session=session,
    )
    # Build the primary container definition; Mode selects single- vs
    # multi-model hosting.
    container_def = sagemaker.container_def(
        model.image_uri,
        model.model_data,
        model.env,
        container_mode='MultiModel' if multimodel else 'SingleModel')
    session.create_model(
        model.name,
        model.role,
        container_def,
        vpc_config=model.vpc_config,
        enable_network_isolation=model.enable_network_isolation(),
    )