예제 #1
0
def ssh(job):
    """Open an interactive SSH shell on the instance tagged with ExpName == job.

    Scans all owned instances; for the (last) match, builds an ssh command
    that drops the user into the experiment's data directory, then retries
    the connection every second until it exits cleanly.
    """
    config = get_cloudexec_config()
    connect = None
    for instance in get_all_instances():
        exp_name = get_tag_value(instance, 'ExpName')
        if exp_name != job:
            continue
        ip_addr = instance['PublicIpAddress']
        exp_group = get_tag_value(instance, 'ExpGroup')
        # Key pairs are per-region; look up the one for this instance.
        key_name = config.aws_key_pairs[instance['Region']]
        key_path = local_ec2_key_pair_path(key_name)
        ec2_path = os.path.join(
            config.ec2_project_root,
            "data/local/{exp_group}/{exp_name}".format(exp_group=exp_group,
                                                       exp_name=exp_name))
        command = " ".join([
            "ssh", "-oStrictHostKeyChecking=no", "-oConnectTimeout=10",
            "-i", key_path, "-t", "ubuntu@" + ip_addr,
            "'cd %s; exec bash -l'" % ec2_path
        ])
        print(command)

        # Bind the current command as a default so this callable is
        # self-contained; with several matches the last definition wins.
        def connect(cmd=command):
            return os.system(cmd)

    if connect is None:
        print("Not found!")
        return
    # Retry until ssh exits with status 0 (e.g. the host is still booting).
    while connect() != 0:
        time.sleep(1)
        print("Retrying")
예제 #2
0
def get_all_instances():
    """Return every running EC2 instance owned by the configured attendee.

    Queries all configured AWS regions in parallel, then keeps only the
    instances whose 'Owner' tag matches ``config.attendee_id``.
    """
    config = get_cloudexec_config()
    with multiprocessing.Pool(10) as pool:
        per_region = pool.map(_collect_instances, config.aws_regions)
    owned = []
    for region_instances in per_region:
        for instance in region_instances:
            if get_tag_value(instance, 'Owner') == config.attendee_id:
                owned.append(instance)
    return owned
예제 #3
0
def jobs(verbose):
    """Print (ExpName, AvailabilityZone) for each non-terminated instance.

    Instances without an 'ExpName' tag are listed as '(None)'.  Output is
    sorted by experiment name.  ``verbose`` is accepted for interface
    compatibility but currently unused.
    """
    jobs = []
    for instance in get_all_instances():
        exp_name = get_tag_value(instance, 'ExpName')
        if exp_name is None:
            exp_name = '(None)'
        # BUGFIX: the original also tested `exp_name is not None` here, but
        # exp_name was just defaulted to '(None)', so that check was always
        # true; only the terminated-state filter is meaningful.
        if instance['State']['Name'] != 'terminated':
            jobs.append((exp_name, instance['Placement']['AvailabilityZone']))

    for job in sorted(jobs):
        print(*job)
예제 #4
0
def get_clients():
    """Build one boto3 EC2 client per configured AWS region.

    Each client is annotated with a ``region`` attribute so callers can tell
    which region it talks to.
    """
    config = get_cloudexec_config()
    clients = []
    for region in config.aws_regions:
        ec2 = boto3.client(
            "ec2",
            region_name=region,
            aws_access_key_id=config.aws_access_key,
            aws_secret_access_key=config.aws_access_secret,
        )
        # Stash the region on the client object for later identification.
        ec2.region = region
        clients.append(ec2)
    return clients
예제 #5
0
def _collect_instances(region):
    """Return all running EC2 instances in ``region``.

    Each returned instance dict is augmented with a 'Region' key.  Runs
    inside a multiprocessing pool, so any exception is printed here (worker
    tracebacks are otherwise easy to lose) before being re-raised.
    """
    try:
        config = get_cloudexec_config()
        client = boto3.client(
            "ec2",
            region_name=region,
            aws_access_key_id=config.aws_access_key,
            aws_secret_access_key=config.aws_access_secret,
        )
        print("Collecting instances in region", region)
        reservations = client.describe_instances(Filters=[{
            'Name': 'instance-state-name',
            'Values': ['running']
        }])['Reservations']
        # Flatten reservations -> instances (avoids the quadratic
        # sum(..., []) list concatenation of the original).
        instances = [
            instance for reservation in reservations
            for instance in reservation['Instances']
        ]
        for instance in instances:
            instance['Region'] = region
        return instances
    except Exception:
        import traceback
        traceback.print_exc()
        # Bare raise preserves the original exception and traceback intact.
        raise
예제 #6
0
#!/usr/bin/env python
from cloudexec import get_cloudexec_config, get_project_root
import boto3
import botocore.exceptions
import os

if __name__ == "__main__":
    config = get_cloudexec_config()

    key_names = dict()

    for region in config.aws_regions:
        ec2_client = boto3.client(
            "ec2",
            region_name=region,
            aws_access_key_id=config.aws_access_key,
            aws_secret_access_key=config.aws_access_secret,
        )

        key_name = "{attendee_id}_{region}".format(
            attendee_id=config.attendee_id, region=region)

        key_names[region] = key_name

        print("Trying to create key pair with name %s" % key_name)
        import cloudexec
        file_name = cloudexec.local_ec2_key_pair_path(key_name)

        try:
            key_pair = ec2_client.create_key_pair(KeyName=key_name)
        except botocore.exceptions.ClientError as e:
예제 #7
0
def spot_history(instance_type, duration):
    """Print the maximum recent spot price per availability zone.

    ``duration`` is a string like '3d', '12h', '2w', '1m', or '30s'.

    Raises:
        ValueError: if ``duration`` does not match ``<number><s|h|d|w|m>``.
    """
    config = get_cloudexec_config()

    # Suffix -> (seconds per unit, human-readable unit name).  Validate the
    # whole string with one regex BEFORE parsing the number, so malformed
    # input (e.g. '5' or 'abc') gets the intended error message instead of
    # an int() crash.
    unit_table = {
        's': (1, 'second'),
        'h': (3600, 'hour'),
        'd': (86400, 'day'),
        'w': (86400 * 7, 'week'),
        'm': (86400 * 30, 'month'),
    }
    match = re.match(r"^(\d+)([shdwm])$", duration)
    if match is None:
        raise ValueError(
            "Unrecognized duration: {duration}".format(duration=duration))
    num_duration = int(match.group(1))
    seconds_per_unit, unit_name = unit_table[match.group(2)]
    duration = num_duration * seconds_per_unit
    print(
        "Querying maximum spot price in each zone within the past {duration} {unit}(s)..."
        .format(duration=num_duration, unit=unit_name))

    with multiprocessing.Pool(100) as pool:
        print('Fetching the list of all availability zones...')
        zones = sum(
            pool.starmap(fetch_zones, [(x, ) for x in config.aws_regions]), [])
        print('Querying spot price in each zone...')
        results = pool.starmap(fetch_zone_prices,
                               [(instance_type, zone, duration)
                                for zone in zones])

        price_list = []

        for zone, prices, timestamps in results:
            if len(prices) > 0:
                # Order observations chronologically.
                sorted_prices_ts = sorted(zip(prices, timestamps),
                                          key=lambda x: x[1])
                cur_time = datetime.datetime.now(
                    tz=sorted_prices_ts[0][1].tzinfo)
                sorted_prices, sorted_ts = [
                    np.asarray(x) for x in zip(*sorted_prices_ts)
                ]
                cutoff = cur_time - datetime.timedelta(seconds=duration)

                valid_ids = np.where(np.asarray(sorted_ts) > cutoff)[0]
                if len(valid_ids) == 0:
                    first_idx = 0
                else:
                    # Include one observation before the cutoff: the price in
                    # effect at the start of the window.
                    first_idx = max(0, valid_ids[0] - 1)

                max_price = max(sorted_prices[first_idx:])

                price_list.append((zone, max_price))

        print("Spot pricing information for instance type {type}".format(
            type=instance_type))

        # Print cheapest zones first, then a quoted list handy for configs.
        list_string = ''
        for zone, price in sorted(price_list, key=lambda x: x[1]):
            print("Zone: {zone}, Max Price: {price}".format(zone=zone,
                                                            price=price))
            list_string += "'{}', ".format(zone)
        print(list_string)
예제 #8
0
def _copy_policy_params(job):
    """Download the latest policy snapshot from the instance running *job*.

    Finds the instance whose 'ExpName' tag equals ``job``, locates the most
    recent .pkl snapshot in its remote snapshots folder, stages it on the
    instance at /tmp/params.pkl, and scp's it into the local
    data/s3/<exp_group>/<exp_name>/snapshots directory.

    Returns the local experiment directory path on success, or False if no
    matching instance was found.  Exits the process (status 0) if the remote
    snapshots folder does not exist yet.
    """
    config = get_cloudexec_config()
    for instance in get_all_instances():
        exp_name = get_tag_value(instance, 'ExpName')
        if exp_name == job:
            ip_addr = instance['PublicIpAddress']
            exp_group = get_tag_value(instance, 'ExpGroup')
            # Key pairs are per-region; pick the one for this instance.
            key_name = config.aws_key_pairs[instance['Region']]
            key_path = local_ec2_key_pair_path(key_name)
            remote_snapshots_path = os.path.join(
                config.ec2_project_root,
                "data/local/{exp_group}/{exp_name}/snapshots".format(
                    exp_group=exp_group, exp_name=exp_name))
            ssh_prefix = [
                "ssh",
                "-oStrictHostKeyChecking=no",
                "-oConnectTimeout=10",
                "-i",
                key_path,
                "ubuntu@{ip}".format(ip=ip_addr),
            ]

            # List remote snapshot files; a non-zero exit means the folder
            # does not exist yet (experiment just launched).
            ls_command = ssh_prefix + ["ls " + remote_snapshots_path]
            try:
                pkl_files = subprocess.check_output(ls_command)
            except subprocess.CalledProcessError:
                print(
                    "The snapshots folder does not exist yet. If the experiment is just launched, wait till the "
                    "first snapshot becomes available.")
                exit(0)

            pkl_files = pkl_files.decode().splitlines()

            # Prefer 'latest.pkl'; otherwise take the numerically highest
            # snapshot (file names are assumed to be <iteration>.pkl).
            if 'latest.pkl' in pkl_files:
                pkl_file = 'latest.pkl'
            else:
                pkl_file = sorted(pkl_files,
                                  key=lambda x: int(x.split('.')[0]))[-1]

            remote_pkl_path = os.path.join(remote_snapshots_path, pkl_file)

            # Stage the snapshot at a fixed remote path so the scp below is
            # simple and does not need shell quoting of the real path.
            copy_command = [
                "ssh", "-oStrictHostKeyChecking=no", "-oConnectTimeout=10",
                "-i", key_path, "ubuntu@{ip}".format(ip=ip_addr),
                "cp {remote_path} /tmp/params.pkl".format(
                    remote_path=remote_pkl_path)
            ]
            print(" ".join(copy_command))
            subprocess.check_call(copy_command)
            local_exp_path = os.path.join(
                get_project_root(),
                "data/s3/{exp_group}/{exp_name}".format(exp_group=exp_group,
                                                        exp_name=exp_name))
            local_pkl_path = os.path.join(local_exp_path, "snapshots",
                                          pkl_file)
            os.makedirs(os.path.dirname(local_pkl_path), exist_ok=True)
            # Fetch the staged file into the local snapshots directory.
            command = [
                "scp",
                "-oStrictHostKeyChecking=no",
                "-oConnectTimeout=10",
                "-i",
                key_path,
                "ubuntu@{ip}:/tmp/params.pkl".format(ip=ip_addr),
                local_pkl_path,
            ]
            print(" ".join(command))
            subprocess.check_call(command)
            return local_exp_path
    return False
예제 #9
0
def main():
    """Validate the local cloudexec AWS configuration end to end.

    Checks, in order: the three identity settings agree; the IAM credentials
    and attendee_id are valid; every configured region has a key pair entry,
    a local key file, a matching EC2 key pair, and matching fingerprints.
    Prints a diagnostic and exits on the first failure.
    """
    import cloudexec
    import boto3
    import botocore.exceptions
    import os
    import subprocess
    config = cloudexec.get_cloudexec_config()

    # A one-element set proves all three settings hold the same value.
    assert len({
        config.attendee_id,
        config.ec2_instance_label,
        config.s3_bucket_root
    }) == 1, "attendee_id, ec2_instance_label, s3_bucket_root should have the same value"

    print("Testing attendee_id, aws_access_key, and aws_access_secret...")

    iam_client = boto3.client(
        "iam",
        region_name=config.aws_regions[0],
        aws_access_key_id=config.aws_access_key,
        aws_secret_access_key=config.aws_access_secret,
    )
    try:
        iam_client.list_access_keys(UserName=config.attendee_id)
    except botocore.exceptions.ClientError as e:
        # Map each AWS error code to the specific misconfigured setting.
        if e.response['Error']['Code'] == 'InvalidClientTokenId':
            print("aws_access_key is not set properly!")
            exit()
        elif e.response['Error']['Code'] == 'SignatureDoesNotMatch':
            print("aws_access_secret is not set properly!")
            exit()
        elif e.response['Error']['Code'] == 'AccessDenied':
            print("attendee_id is not set properly!")
            exit()
        else:
            raise e

    # Check if key pair exists

    for region in config.aws_regions:
        print("Checking key pair in region %s" % region)
        if region not in config.aws_key_pairs:
            print("Key pair in region %s is not set properly!" % region)
            exit()
        key_pair_name = config.aws_key_pairs[region]
        key_pair_path = cloudexec.local_ec2_key_pair_path(key_pair_name)
        if not os.path.exists(key_pair_path):
            print("Missing local key pair file at %s" % key_pair_path)
            exit()
        ec2_client = boto3.client(
            "ec2",
            region_name=region,
            aws_access_key_id=config.aws_access_key,
            aws_secret_access_key=config.aws_access_secret,
        )
        try:
            response = ec2_client.describe_key_pairs(
                KeyNames=[config.aws_key_pairs[region]]
            )
        except botocore.exceptions.ClientError as e:
            if e.response['Error']['Code'] == 'InvalidKeyPair.NotFound':
                print("Key pair in region %s is not set properly!" % region)
                exit()
            else:
                raise e
        remote_fingerprint = response['KeyPairs'][0]['KeyFingerprint']

        # Get local key fingerprint by piping the DER-encoded private key
        # through `openssl sha1`, mirroring how EC2 computes fingerprints.

        ps = subprocess.Popen(
            ["openssl", "pkcs8", "-in", key_pair_path,
                "-nocrypt", "-topk8", "-outform", "DER"],
            stdout=subprocess.PIPE
        )
        local_fingerprint = subprocess.check_output(
            ["openssl", "sha1", "-c"], stdin=ps.stdout)
        # BUGFIX: close our copy of the pipe and reap the first process.
        # The original leaked one pipe fd and one zombie process per region.
        ps.stdout.close()
        ps.wait()
        # Strip irrelevant information
        local_fingerprint = local_fingerprint.decode().split('= ')[-1][:-1]

        if remote_fingerprint != local_fingerprint:
            print("Local key pair file does not match EC2 record!")
            exit()

    print("Your EC2 configuration has passed all checks!")
예제 #10
0
#!/usr/bin/env python

import cloudexec
import os
import argparse
import subprocess

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('folder', type=str)
    parser.add_argument('--all', action='store_true', default=False)
    args = parser.parse_args()
    remote_dir = "s3://{bucket}/{bucket_root}/experiments".format(
        bucket=cloudexec.get_cloudexec_config().s3_bucket,
        bucket_root=cloudexec.get_cloudexec_config().s3_bucket_root)
    local_dir = os.path.join(cloudexec.get_project_root(), "data", "s3")
    if args.folder:
        remote_dir = os.path.join(remote_dir, args.folder)
        local_dir = os.path.join(local_dir, args.folder)
    s3_env = dict(
        os.environ,
        AWS_ACCESS_KEY_ID=cloudexec.get_cloudexec_config().aws_access_key,
        AWS_SECRET_ACCESS_KEY=cloudexec.get_cloudexec_config().
        aws_access_secret,
        AWS_REGION=cloudexec.get_cloudexec_config().aws_s3_region,
    )
    if not args.all:
        command = ("""
            aws s3 sync --exclude '*' {s3_periodic_sync_include_flags} --content-type "UTF-8" {remote_dir} {local_dir} 
        """.format(local_dir=local_dir,
                   remote_dir=remote_dir,