Example #1
def copy_parts(event, context):
    # Lambda handler: copies one batch of parts from the source blob to the
    # destination platform, as described by the incoming SNS message.
    topic_arn = event["Records"][0]["Sns"]["TopicArn"]
    msg = json.loads(event["Records"][0]["Sns"]["Message"])
    blobstore_handle = dss.Config.get_blobstore_handle(
        platform_to_replica[msg["source_platform"]])
    # Presigned GET URL for reading byte ranges from the source object.
    source_url = blobstore_handle.generate_presigned_GET_url(
        bucket=msg["source_bucket"], key=msg["source_key"])
    futures = []
    # Native GS client, used to create resumable upload sessions for GS destinations.
    gs = dss.Config.get_native_handle(Replica.gcp)
    # Copy the assigned parts concurrently.
    with ThreadPoolExecutor(max_workers=4) as executor:
        for part in msg["parts"]:
            logger.info(log_msg.format(part=part, **msg))
            if msg["dest_platform"] == "s3":
                # S3: address the part number within the already-opened multipart upload.
                upload_url = "{host}/{bucket}/{key}?partNumber={part_num}&uploadId={mpu_id}".format(
                    host=clients.s3.meta.endpoint_url,
                    bucket=msg["dest_bucket"],
                    key=msg["dest_key"],
                    part_num=part["id"],
                    mpu_id=msg["mpu"])
            elif msg["dest_platform"] == "gs":
                # GS: each part is staged as its own blob via a resumable upload session.
                assert len(msg["parts"]) == 1
                dest_blob_name = "{}.part{}".format(msg["dest_key"],
                                                    part["id"])
                dest_blob = gs.get_bucket(
                    msg["dest_bucket"]).blob(dest_blob_name)
                upload_url = dest_blob.create_resumable_upload_session(
                    size=part["end"] - part["start"] + 1)
            futures.append(
                executor.submit(copy_part, upload_url, source_url,
                                msg["dest_platform"], part))
    # Surface any exception raised while copying parts.
    for future in futures:
        future.result()

    # Count the parts that have already landed at the destination.
    if msg["dest_platform"] == "s3":
        mpu = resources.s3.Bucket(msg["dest_bucket"]).Object(
            msg["dest_key"]).MultipartUpload(msg["mpu"])
        parts = list(mpu.parts.all())
    elif msg["dest_platform"] == "gs":
        # GS parts are staged as "<dest_key>.partN" blobs; list those that exist.
        part_names = [
            "{}.part{}".format(msg["dest_key"], p + 1)
            for p in range(msg["total_parts"])
        ]
        parts = [
            gs.get_bucket(msg["dest_bucket"]).get_blob(p) for p in part_names
        ]
        parts = [p for p in parts if p is not None]
    logger.info("Parts complete: {}".format(len(parts)))
    logger.info("Parts outstanding: {}".format(msg["total_parts"] - len(parts)))
    # When fewer than two worker batches remain, notify the closer topic so the
    # copy can be finalized.
    if msg["total_parts"] - len(parts) < parts_per_worker[msg["dest_platform"]] * 2:
        logger.info("Calling closer")
        send_sns_msg(
            ARN(topic_arn,
                resource=sns_topics["closer"][msg["dest_platform"]]), msg)
        logger.info("Called closer")
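
The per-part worker copy_part that the executor invokes is not included in this example. As a rough sketch only (assuming the presigned source URL honours HTTP Range requests, and ignoring the request signing a real S3 UploadPart call would need), such a worker could look like this:

import requests

def copy_part(upload_url, source_url, dest_platform, part):
    # Hypothetical worker, not taken from the example above.
    # Read only this part's byte range from the presigned source URL.
    byte_range = "bytes={}-{}".format(part["start"], part["end"])
    resp = requests.get(source_url, headers={"Range": byte_range})
    resp.raise_for_status()
    # Push the bytes to the destination URL prepared by copy_parts
    # (an S3 UploadPart URL or a GS resumable upload session).
    put_resp = requests.put(upload_url, data=resp.content)
    put_resp.raise_for_status()
    return put_resp
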
Example #2
def dispatch_multipart_sync(source, dest, context):
    # Split the source blob into fixed-size parts and fan the part ranges out to
    # dss-copy-parts workers over SNS, in batches of parts_per_worker.
    parts_for_worker = []
    futures = []
    # S3 reports the object size as content_length; GS reports it as size.
    total_size = source.blob.content_length if source.platform == "s3" else source.blob.size
    all_parts = list(enumerate(range(0, total_size, part_size[dest.platform])))
    # For an S3 destination, open the multipart upload up front; a GS destination
    # needs no upload handle at this point.
    mpu = dest.blob.initiate_multipart_upload(
        Metadata=source.blob.metadata or {}) if dest.platform == "s3" else None

    with ThreadPoolExecutor(max_workers=4) as executor:
        for part_id, part_start in all_parts:
            # Part ids are 1-based; each part covers an inclusive [start, end] byte range.
            parts_for_worker.append(
                dict(id=part_id + 1,
                     start=part_start,
                     end=min(total_size - 1,
                             part_start + part_size[dest.platform] - 1),
                     total_parts=len(all_parts)))
            # Flush a batch once it is full, or when this is the last part.
            if (len(parts_for_worker) >= parts_per_worker[dest.platform]
                    or part_id == all_parts[-1][0]):
                logger.info("Invoking dss-copy-parts with %s",
                            ", ".join(str(p["id"]) for p in parts_for_worker))
                sns_msg = dict(source_platform=source.platform,
                               source_bucket=source.bucket.name,
                               source_key=(source.blob.key if source.platform == "s3"
                                           else source.blob.name),
                               dest_platform=dest.platform,
                               dest_bucket=dest.bucket.name,
                               dest_key=(dest.blob.key if dest.platform == "s3"
                                         else dest.blob.name),
                               mpu=mpu.id if mpu else None,
                               parts=parts_for_worker,
                               total_parts=len(all_parts))
                sns_arn = ARN(context.invoked_function_arn,
                              service="sns",
                              resource=sns_topics["copy_parts"])
                futures.append(executor.submit(send_sns_msg, sns_arn, sns_msg))
                parts_for_worker = []
    for future in futures:
        future.result()
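
The part ranges above are plain offset arithmetic: ids are 1-based and each part covers an inclusive [start, end] byte range. For illustration, with a hypothetical 100-byte blob and a 40-byte part size (not the real part_size table), the loop would produce:

total_size, chunk = 100, 40
[dict(id=i + 1, start=s, end=min(total_size - 1, s + chunk - 1), total_parts=3)
 for i, s in enumerate(range(0, total_size, chunk))]
# -> [{'id': 1, 'start': 0, 'end': 39, 'total_parts': 3},
#     {'id': 2, 'start': 40, 'end': 79, 'total_parts': 3},
#     {'id': 3, 'start': 80, 'end': 99, 'total_parts': 3}]
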
Example #3
import os
import typing
import logging

from dss.util.aws import ARN
from dss.util.aws.clients import stepfunctions  # type: ignore
from dss.util.aws import send_sns_msg

"""
The keys used to transfer step function invocation data over SNS to the dss-sfn-* Lambda. dss-sfn starts the step
function execution and is configured with a DLQ for resiliency.
"""
SFN_TEMPLATE_KEY = 'sfn_template'
SFN_EXECUTION_KEY = 'sfn_execution'
SFN_INPUT_KEY = 'sfn_input'

region = ARN.get_region()
stage = os.environ["DSS_DEPLOYMENT_STAGE"]
accountid = ARN.get_account_id()

sfn_sns_topic = f"dss-sfn-{stage}"
sfn_sns_topic_arn = f"arn:aws:sns:{region}:{accountid}:{sfn_sns_topic}"

logger = logging.getLogger(__name__)


def step_functions_arn(state_machine_name_template: str) -> str:
    """
    Return the ARN of the state machine whose name is derived from `state_machine_name_template` by substituting
    {stage} for the dss deployment stage.
    """
    state_machine_name = state_machine_name_template.format(stage=stage)
    # State machine ARNs have the form arn:aws:states:<region>:<account>:stateMachine:<name>.
    return f"arn:aws:states:{region}:{accountid}:stateMachine:{state_machine_name}"


# Fragment from a test case exercising the ARN helper.
def test_aws_utils(self):
    arn = ARN(service="sns", resource="my_topic")
    arn.get_region()
    arn.get_account_id()
    str(arn)
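
A helper that actually publishes the invocation message is not shown in this example. As a sketch only, assuming the three SFN_* keys are bundled into a single SNS message for the dss-sfn Lambda and that send_sns_msg accepts an ARN built the same way test_aws_utils builds one (the "dss-example-{stage}" template name is purely illustrative):

import json

def step_functions_invoke(state_machine_name_template, execution_name, execution_input):
    # Hypothetical helper: package the template, execution name and input under
    # the SFN_* keys and publish them to the dss-sfn SNS topic.
    message = {
        SFN_TEMPLATE_KEY: state_machine_name_template,
        SFN_EXECUTION_KEY: execution_name,
        SFN_INPUT_KEY: json.dumps(execution_input),
    }
    send_sns_msg(ARN(service="sns", resource=sfn_sns_topic), message)

# e.g. step_functions_invoke("dss-example-{stage}", "run-0001", {"key": "value"})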