Example #1
	def __init__(self,config_dict):
		# Create boto3 session for getting cloudwatch metrics
		session = Session(aws_access_key_id=config_dict['redshift_connection']['aws_access_key_id']
			,aws_secret_access_key=config_dict['redshift_connection']['aws_secret_access_key']
			,region_name=config_dict['redshift_connection']['region_name'])
		self.cw = session.client('cloudwatch')
		self.name_space = 'AWS/Redshift'
		self.metric_name = ['CPUUtilization',
				  'NetworkReceiveThroughput',
				  'NetworkTransmitThroughput',
				  'PercentageDiskSpaceUsed',
				  'ReadIOPS',
				  'ReadLatency',
				  'ReadThroughput',
				  'WriteIOPS',
				  'WriteLatency',
				  'WriteThroughput']
		self.period = 60
		self.statistics = ['Average']
		self.unit = ['Percent',
			'Bytes/Second',
			'Bytes/Second',
			'Percent',
			'Count/Second',
			'Seconds',
			'Bytes/Second',
			'Count/Second',
			'Seconds',
			'Bytes/Second']
		self.log_identifier = 'cw_metrics'
		self.cluster_name = config_dict['redshift_connection']['cluster_name']
		self.num_nodes = config_dict['redshift_connection']['num_nodes_cluster']
		self.post_db = PostDB(db_queue=None,database_config=config_dict['database'])
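A minimal sketch of how the attributes above might be used to pull one metric series; the five-minute window and the helper name are assumptions, not part of the original class.

from datetime import datetime, timedelta

def fetch_metric(cw, cluster_name, metric, unit, period=60):
    # Query CloudWatch for the average of `metric` over the last five minutes,
    # using the same namespace, period, statistic, and unit as configured above.
    end = datetime.utcnow()
    start = end - timedelta(minutes=5)
    resp = cw.get_metric_statistics(
        Namespace='AWS/Redshift',
        MetricName=metric,
        Dimensions=[{'Name': 'ClusterIdentifier', 'Value': cluster_name}],
        StartTime=start,
        EndTime=end,
        Period=period,
        Statistics=['Average'],
        Unit=unit)
    return resp['Datapoints']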
Example #2
    def test_get_available_partitions(self):
        bc_session = mock.Mock()
        bc_session.get_available_partitions.return_value = ['foo']
        session = Session(botocore_session=bc_session)

        partitions = session.get_available_partitions()
        self.assertEqual(partitions, ['foo'])
def main():

    if len(sys.argv) != 3:
        sys_exit(1)

    accepted_param = ['start', 'history', 'show', 'stop']
    valid_param = False

    for arg in sys.argv:
        if arg in accepted_param:
            valid_param = True
            break

    if not valid_param:
        sys_exit(1)

    default_profile='aater-flux7'
    spot_fleet_id = sys.argv[2]
    session = Session(profile_name=default_profile)
    client = session.client('ec2')
    resource = DemoEc2(client)
    ret = False

    if sys.argv[1] == 'start':
        ret = resource.create_aws_spot_fleet()
    elif sys.argv[1] == 'show':
        ret = resource.describe_aws_spot_fleet(spot_fleet_id)
    elif sys.argv[1] == 'history':
        ret = resource.history_aws_spot_fleet(spot_fleet_id)
    elif sys.argv[1] == 'stop':
        ret = resource.terminate_aws_spot_fleet(spot_fleet_id)

    print(ret)
Example #4
    def __init__(self, access_key, access_secret, region, bucket='ypanbucket'):
        session = Session(aws_access_key_id=access_key,
                            aws_secret_access_key=access_secret,
                            region_name=region)

        self.client = session.client('s3')
        self.bucket = bucket
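A hedged usage sketch for the wrapper above; the function name and URL format are assumptions, not part of the source.

def upload_file_to_bucket(client, bucket, local_path, key):
    # Upload a local file with the client built in __init__ and return a plain HTTPS URL.
    client.upload_file(local_path, bucket, key)
    return "https://{0}.s3.amazonaws.com/{1}".format(bucket, key)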
    def reducer_final(self):

        import boto3
        import decimal

        # http://github.com/boto/boto/issues/1531
        from boto3.dynamodb.types import DYNAMODB_CONTEXT
        DYNAMODB_CONTEXT.traps[decimal.Inexact] = 0
        DYNAMODB_CONTEXT.traps[decimal.Rounded] = 0

        from boto3.session import Session

        session = Session(aws_access_key_id=AWS_ACCESS_KEY_ID,
                          aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                          region_name='us-east-1')

        dynamodb = session.resource('dynamodb', region_name='us-east-1', endpoint_url="http://{}:8000".format(DYNAMODB_ENDPOINT_URL))
        table = dynamodb.Table('Delay2_1')

        for k in self.to_insert:
            airport, carrier = k
            delay = self.to_insert[k]
            #print carrier, delay, self.to_insert[k]

            try:
                response = table.put_item(
                    Item={
                        'delay': decimal.Decimal(delay),
                        'origin_airport': airport,
                        'carrier': carrier
                    }
                )
            except Exception as e:
                print(e)
Example #6
 def __init__(self, cfg):
     self.cfg = cfg
     aws = Session(aws_access_key_id=self.cfg.aws_key,
                   aws_secret_access_key=self.cfg.aws_secret,
                   region_name=self.cfg.aws_region)
     self.s3_client = aws.client('s3')
     self.bucket = aws.resource('s3').create_bucket(Bucket=self.cfg.s3_bucket)
def put_from_manifest(
        s3_bucket, s3_connection_host, s3_ssenc, s3_base_path,
        aws_access_key_id, aws_secret_access_key, manifest,
        bufsize, compress_data, concurrency=None, incremental_backups=False):
    """
    Uploads files listed in a manifest to amazon S3
    to support larger than 5GB files multipart upload is used (chunks of 60MB)
    files are uploaded compressed with lzop, the .lzo suffix is appended
    """
    bucket = get_bucket(
        s3_bucket, aws_access_key_id,
        aws_secret_access_key, s3_connection_host)

    # Create a boto3 session
    session = Session(aws_access_key_id=aws_access_key_id,
                      aws_secret_access_key=aws_secret_access_key,
                      region_name='us-east-1')
    client = session.client('s3')
    event_system = client.meta.events
    config = TransferConfig(
        multipart_threshold=MULTI_PART_UPLOAD_THRESHOLD,
        max_concurrency=4)
    transfer = S3Transfer(client, config)
    boto3.set_stream_logger('botocore', logging.INFO)

    with open(manifest, 'r') as manifest_fp:
        files = manifest_fp.read().splitlines()
    for f in files:
        file_path = s3_base_path + f
        print("boto3, upload file {0} to {1}: {2}".format(f, s3_bucket, file_path))
        transfer.upload_file(f, s3_bucket, file_path)
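The docstring above mentions 60 MB chunks, but the snippet only sets the multipart threshold; a sketch of a TransferConfig that also pins the part size (the constants here are assumptions):

from boto3.s3.transfer import TransferConfig

MB = 1024 ** 2
# Stand-in threshold; the original code uses MULTI_PART_UPLOAD_THRESHOLD.
sixty_mb_config = TransferConfig(
    multipart_threshold=64 * MB,
    multipart_chunksize=60 * MB,  # 60 MB parts, matching the docstring
    max_concurrency=4)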
Example #8
def aggregate(queue, session_args, timeout=60, retry=3):
    signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore the KeyboardInterrupt
    total_bytes = 0
    total_seconds = 0
    received_count = 0
    errors = []

    session = Session(**session_args)
    sqs = session.resource('sqs')
    queue = sqs.Queue(queue)

    start = time.time()
    current = time.time()
    #print("Aggregating results")
    while retry > 0:# or int(current - start) < timeout:
        msgs = queue.receive_messages(WaitTimeSeconds=20, MaxNumberOfMessages=10)
        if len(msgs) == 0:
            #print("\tretry")
            retry -= 1
            continue
        print('.', end='', flush=True)
        for msg in msgs:
            queue.delete_messages(Entries = [{'Id':'X', 'ReceiptHandle': msg.receipt_handle}])
            received_count += 1
            msg = json.loads(msg.body)
            if 'error' in msg:
                errors.append(msg['error'])
            else:
                total_bytes += msg['bytes']
                total_seconds += (msg['read_stop'] - msg['req_start'])
        current = time.time()
    return (received_count, (total_bytes, total_seconds), errors)
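aggregate() is written to run in a worker process (it ignores SIGINT); a minimal direct-call sketch, with placeholder queue URL and session arguments:

count, (nbytes, secs), errs = aggregate(
    queue="https://sqs.us-east-1.amazonaws.com/123456789012/results",
    session_args={"profile_name": "default"})
if secs > 0:
    print("{0} messages, {1:.2f} MB/s aggregate".format(count, nbytes / secs / 1e6))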
Example #9
def _send_mail(to, subject, body, email_format='Text'):
    if settings.DEBUG:
        print((to, subject, body))
    session = Session(aws_access_key_id=settings.SES_ACCESS_ID,
                      aws_secret_access_key=settings.SES_SECRET_KEY,
                      region_name='us-east-1')
    conn = session.client('ses')
    resp = conn.send_email(
        Source=settings.SENDER_EMAIL,
        Destination={'ToAddresses': [to]},
        Message={
            'Subject': {
                'Data': subject,
            },
            'Body': {
                email_format: {
                    'Data': body,
                },
            },
        },
        ReplyToAddresses=[settings.SUPPORT_EMAIL],
        ReturnPath=settings.ADMINS[0][1]
    )
    if not resp.get('MessageId'):
        rollbar.report_message('Got bad response from SES: %s' % repr(resp), 'error')
    def __init__(self, pipeline_id, region=None, access_key_id=None, secret_access_key=None):
        self.pipeline_id = pipeline_id

        if not region:
            region = getattr(settings, 'AWS_REGION', None)
        self.aws_region = region

        if not access_key_id:
            access_key_id = getattr(settings, 'AWS_ACCESS_KEY_ID', None)
        self.aws_access_key_id = access_key_id

        if not secret_access_key:
            secret_access_key = getattr(settings, 'AWS_SECRET_ACCESS_KEY', None)
        self.aws_secret_access_key = secret_access_key

        if self.aws_access_key_id is None:
            assert False, 'Please provide AWS_ACCESS_KEY_ID'

        if self.aws_secret_access_key is None:
            assert False, 'Please provide AWS_SECRET_ACCESS_KEY'

        if self.aws_region is None:
            assert False, 'Please provide AWS_REGION'

        boto_session = Session(
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            region_name=self.aws_region,
        )
        self.client = boto_session.client('elastictranscoder')
def getAwsAmi(aws_key, aws_secret, region):
    if aws_key == '':
        session = Session(region_name=region)
    else:
        session = Session(region_name=region, aws_access_key_id=aws_key, aws_secret_access_key=aws_secret)

    ec2 = session.resource('ec2')
    image_list = ec2.images.filter(Owners=['amazon'],
                                   Filters=[{'Name': 'virtualization-type',
                                             'Values': ['hvm']},
                                            {'Name': 'architecture',
                                             'Values': ['x86_64']},
                                            {'Name': 'root-device-type',
                                             'Values': ['ebs']}])

    threeMonthAgo = datetime.datetime.today() - datetime.timedelta(3*365/12)
    image_list = [m for m in image_list if imageFilter(m, threeMonthAgo)]
    image_list = sorted(image_list, key=getImageCreationDate, reverse=True)

    #for image in image_list:
    #    print json.dumps(image, cls=JSONEncoder)

    #get standard image
    image_list = [m for m in image_list if ("-nat-" not in m.name)]
    return image_list[0]
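imageFilter and getImageCreationDate are not shown in this snippet; one plausible sketch of them (assumptions, not the original helpers):

def getImageCreationDate(image):
    # ec2.Image.creation_date is an ISO-8601 string, which sorts lexicographically.
    return image.creation_date

def imageFilter(image, cutoff):
    # Keep only images created after `cutoff` (a naive datetime).
    created = datetime.datetime.strptime(image.creation_date[:19], "%Y-%m-%dT%H:%M:%S")
    return created > cutoff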
Example #12
def main(argv):
    args = get_args().parse_args(argv)
    session = Session(profile_name=args.profile, region_name=args.region)
    ec2 = session.resource("ec2")
    s3 = session.resource("s3")
    cff = session.resource("cloudformation")
    try:
        stack = get_stack(args.stack_name, cff)
        if stack is None:
            print "Stack in no longer active"
            return 0
        stack_vars = get_stack_outputvars(stack, ec2)
        cleanup_ec2(stack, stack_vars, ec2, args.yes, exclude_nat=True)
        cleanup_s3(stack, stack_vars, s3, args.yes)
        if args.remove_stack is True:
            print "Removing stack"
            if confirm_oprn("Stack "+args.stack_name, 1, args.yes) is False:
                return -1
            cleanup_ec2(stack, stack_vars, ec2, args.yes, exclude_nat=False)
            stack.delete()
        else:
            print "Stack was not deleted."
            print "Re-run the same command with --remove-stack"
    except botocore.exceptions.NoCredentialsError as ex:
        print(ex)
        print("Missing ~/.aws/credentials file?")
        print("http://boto3.readthedocs.org/en/latest/guide/configuration.html")
        return -1

    return 0
 def __init__(self,config_area):
     self.aws_region = aws_config.get(config_area, "aws_region")
     self.aws_access_key_id = aws_config.get(config_area, "aws_access_key_id")
     self.aws_secret_access_key = aws_config.get(config_area, "aws_secret_access_key")
     session = Session(aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, region_name=self.aws_region)
     self.ec2 = session.resource('ec2',config=Config(signature_version='s3v4'))
     self.ec2_client = session.client('ec2',config=Config(signature_version='s3v4'))
Example #14
    def create(self, request):

        serializer = self.get_serializer(data=request.data)

        upload_files_user = request.FILES.get('upload_image', None)

        if serializer.is_valid():

            if upload_files_user is not None:

                session = Session(aws_access_key_id='AKIAJYDV7TEBJS6JWEEQ',
                  aws_secret_access_key='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')

                s3 = session.resource('s3')
                bucket = s3.Bucket('walladog')
                key_file = request.data['username'] + ".jpeg"
                bucket.put_object(ACL='public-read', Key=key_file, Body=upload_files_user, ContentType='image/jpeg')
                photo_url = "https://s3.amazonaws.com/walladog/" + key_file
                new_user = serializer.save(avatar_url=photo_url)

            else:

                new_user = serializer.save()

            serialize_bnew_user = UserListSerializer(new_user)
            return Response(serialize_bnew_user.data, status=status.HTTP_201_CREATED)
        else:
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
Example #15
def push2aws(logger, filename=None):
    if not filename:
        filename = jsonfile
    logger.info('pushing filename[%s]' % filename)

    try:
        with open(filename, 'r') as html_file:
            logger.info('Reading [%s]' % filename)
            html_source = html_file.read()
    except Exception as e:
        logger.error('Exception when opening file for jobcard[%s] - EXCEPT[%s:%s]' % (jobcard_no, type(e), e))
        raise e
        
    with open(filename, 'rb') as filehandle:
        content = filehandle.read()
    cloud_filename='media/temp/customerJSON/%s' % filename
    session = Session(aws_access_key_id=LIBTECH_AWS_ACCESS_KEY_ID,
                                    aws_secret_access_key=LIBTECH_AWS_SECRET_ACCESS_KEY)
    s3 = session.resource('s3',config=Config(signature_version='s3v4'))
    s3.Bucket(AWS_STORAGE_BUCKET_NAME).put_object(ACL='public-read',Key=cloud_filename, Body=content, ContentType='application/json')
    
    public_url='https://s3.ap-south-1.amazonaws.com/libtech-nrega1/%s' % cloud_filename
    logger.info('File written on AWS[%s]' % public_url)

    return 'SUCCESS'
def main():
    """Main method for setup.

    Setup command line parsing, the AWS connection and call the functions
    containing the actual logic.
    """
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        level=logging.INFO)
    logging.getLogger('botocore').setLevel(logging.CRITICAL)

    parser = argparse.ArgumentParser(description='Script to automate \
                                     snapshotting of EBS volumes')
    parser.add_argument('--aws-access-key-id',
                        dest='aws_access_key_id',
                        help='Specify a value here if you want to use a '
                        'different AWS_ACCESS_KEY_ID than configured in the '
                        'AWS CLI.')
    parser.add_argument('--aws-secret-access-key',
                        dest='aws_secret_access_key',
                        help='Specify a value here if you want to use a '
                        'different AWS_SECRET_ACCESS_KEY than configured in '
                        'the AWS CLI.')
    parser.add_argument('--profile',
                        dest='profile_name',
                        help='The AWS CLI profile to use. Defaults to the '
                        'default profile.')
    parser.add_argument('--region',
                        dest='region_name', default='us-east-1',
                        help='The AWS region to connect to. Defaults to the '
                        'one configured for the AWS CLI.')
    parser.add_argument('-n', '--num-backups', dest='num_backups', type=int,
                        default=14,
                        help='The number of backups for each volume to keep')
    parser.add_argument('-t', '--tag', dest='tag', default='Lifecycle:legacy',
                        help='Key and value (separated by a colon) of a tag '
                        'attached to instances whose EBS volumes should be '
                        'backed up')
    args = parser.parse_args()

    session_args = {key: value for key, value in list(vars(args).items())
                    if key in ['aws_access_key_id',
                               'aws_secret_access_key',
                               'profile_name',
                               'region_name']}
    try:
        session = Session(**session_args)
        client = session.client('ec2')
    except BotoCoreError as exc:
        logging.error("Connecting to the EC2 API failed: %s", exc)
        sys.exit(1)

    tag_key_value = args.tag.split(':')
    if len(tag_key_value) != 2:
        logging.error("Given tag key value: \"%s\" is invalid.", args.tag)
        sys.exit(1)
    tag_key = tag_key_value[0]
    tag_value = tag_key_value[1]

    make_snapshots(client, tag_key, tag_value)
    delete_old_snapshots(client, tag_key, tag_value, args.num_backups)
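make_snapshots and delete_old_snapshots are defined elsewhere in the script; a minimal sketch of what make_snapshots might do with the tag filter (an assumption, not the original implementation):

def make_snapshots(client, tag_key, tag_value):
    # Snapshot every EBS volume attached to instances tagged tag_key:tag_value.
    reservations = client.describe_instances(
        Filters=[{'Name': 'tag:' + tag_key, 'Values': [tag_value]}])['Reservations']
    for reservation in reservations:
        for instance in reservation['Instances']:
            for mapping in instance.get('BlockDeviceMappings', []):
                ebs = mapping.get('Ebs')
                if not ebs:
                    continue
                client.create_snapshot(
                    VolumeId=ebs['VolumeId'],
                    Description='Automated snapshot of ' + ebs['VolumeId'])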
Example #17
    def update(self, request, pk):

        user = get_object_or_404(UserDetail, pk=pk)
        self.check_object_permissions(request, user)
        serializer = self.get_serializer(instance=user, data=request.data, partial=True)

        upload_files_user = request.FILES.get('upload_image', None)

        if serializer.is_valid():

            update_user = serializer.save()

            if upload_files_user is not None:

                session = Session(aws_access_key_id='AKIAJYDV7TEBJS6JWEEQ',
                  aws_secret_access_key='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')

                s3 = session.resource('s3')
                bucket = s3.Bucket('walladog')
                key_file = user.user.username + ".jpeg"
                bucket.put_object(ACL='public-read', Key=key_file, Body=upload_files_user, ContentType='image/jpeg')

            serialize_bnew_user = UserListSerializer(update_user)
            return Response(serialize_bnew_user.data, status=status.HTTP_200_OK)
        else:
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
    def handle(self, *args, **options):
        AWS_ACCESS_KEY_ID = os.environ.get('AWS_S3_ACCESS_KEY_ID')
        AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_S3_SECRET_ACCESS_KEY')
        AWS_REGION_NAME = os.environ.get('AWS_S3_REGION_NAME')
        AWS_BUCKET_NAME = os.environ.get('AWS_S3_BUCKET_NAME')

        session = Session(aws_access_key_id=AWS_ACCESS_KEY_ID,
                          aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                          region_name=AWS_REGION_NAME)

        s3 = session.resource('s3')

        bucket = s3.Bucket(AWS_BUCKET_NAME)

        timestamp = datetime.utcnow().strftime('%Y_%m_%d_%H%M_UTC')
        db_file = "pgbackup_{}.dump".format(timestamp)
        db_host = settings.DATABASES['default']['HOST']
        db_port = settings.DATABASES['default']['PORT']
        db_name = settings.DATABASES['default']['NAME']
        db_user = settings.DATABASES['default']['USER']

        # See command definitions at http://www.postgresql.org/docs/9.4/static/app-pgdump.html
        pg_dump_command = "pg_dump --host={host} --port={port} --user={user} --format=c -O -x --file={file_name} {name}".format(
            host=db_host,
            port=db_port,
            user=db_user,
            file_name=db_file,
            name=db_name)
        self.stdout.write("Enter {}'s psql password".format(db_user))
        os.system(pg_dump_command)

        bucket.upload_file(db_file, db_file)
        os.system("rm {}".format(db_file))
        self.stdout.write("{} successfully uploaded to AWS S3".format(db_file))
Example #19
def locate_ami(aws_config):
    def contains(x, ys):
        for y in ys:
            if y not in x:
                return False
        return True

    with open(aws_config) as fh:
        cred = json.load(fh)
        session = Session(aws_access_key_id = cred["aws_access_key"],
                          aws_secret_access_key = cred["aws_secret_key"],
                          region_name = 'us-east-1')

        client = session.client('ec2')
        response = client.describe_images(Filters=[
                        {"Name": "owner-id", "Values": ["099720109477"]},
                        {"Name": "virtualization-type", "Values": ["hvm"]},
                        {"Name": "root-device-type", "Values": ["ebs"]},
                        {"Name": "architecture", "Values": ["x86_64"]},
                        #{"Name": "platform", "Values": ["Ubuntu"]},
                        #{"Name": "name", "Values": ["hvm-ssd"]},
                        #{"Name": "name", "Values": ["14.04"]},
                   ])

        images = response['Images']
        images = [i for i in images if contains(i['Name'], ('hvm-ssd', '14.04', 'server'))]
        images.sort(key=lambda x: x["CreationDate"], reverse=True)

        if len(images) == 0:
            print("Error: could not locate base AMI, exiting ....")
            sys.exit(1)

        print("Using {}".format(images[0]['Name']))
        return images[0]['ImageId']
def downloadFromAWS(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_S3_BUCKET_NAME, file_uuid, original_file_name):
  
  session = Session(AWS_ACCESS_KEY_ID,
                    AWS_SECRET_ACCESS_KEY)
  
  s3 = session.resource('s3')
  s3.Bucket(AWS_S3_BUCKET_NAME).download_file(Key=str(file_uuid), Filename=original_file_name)
def uploadToAWS(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_S3_BUCKET_NAME, zip_file, file_uuid):
  
  session = Session(AWS_ACCESS_KEY_ID,
                    AWS_SECRET_ACCESS_KEY)
  s3 = session.resource('s3')
  with open(secure_filename(zip_file.filename), 'rb') as data:
    s3.Bucket(AWS_S3_BUCKET_NAME).put_object(Key=str(file_uuid), Body=data, ACL='public-read')
Example #22
    def init_app(self, app, conf_key=None):
        """
        :type app: flask.Flask
        :parm str conf_key: Key of flask config.
        """
        conf_key = conf_key or self.conf_key or 'FLASK_BOTO_SQS'
        self.conf_key = conf_key
        conf = app.config[conf_key]
        if not isinstance(conf, dict):
            raise TypeError("FLASK_BOTO_SQS conf should be dict")

        close_on_teardown = conf.pop('close_on_teardown', False)

        session = Session(aws_access_key_id=conf['aws_access_key_id'],
                          aws_secret_access_key=conf['aws_secret_access_key'],
                          region_name=conf['region'])
        sqs = session.resource('sqs')

        app.extensions.setdefault('botosqs', {})
        app.extensions['botosqs'][self] = sqs

        if close_on_teardown:
            @app.teardown_appcontext
            def close_connection(exc=None):
                #TODO: close connection
                pass
Example #23
def main(argv):
    args = get_args().parse_args(argv)
    session = Session(profile_name=args.profile, region_name=args.region)
    ec2 = session.resource("ec2")
    s3 = session.resource("s3")
    cff = session.resource("cloudformation")

    get_orphaned_resources(cff, ec2, s3, args)
Example #24
def upload_AWS(zip_file, file_uuid, metadata=None):
	session = Session(AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY)
	s3 = session.resource('s3')
	s3.Bucket(AWS_S3_BUCKET_NAME).put_object(
		Key=str(file_uuid),
		Body=zip_file,
		ACL='public-read',
		Metadata=metadata or {})
Example #25
    def test_get_available_services(self):
        bc_session = self.bc_session_cls.return_value

        session = Session()
        session.get_available_services()

        self.assertTrue(bc_session.get_available_services.called,
            'Botocore session get_available_services not called')
Example #26
    def __init__(self):
        session = Session(
            aws_access_key_id=app.config.get("S3_AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=app.config.get("S3_AWS_SECRET_ACCESS_KEY"),
            region_name=app.config.get("S3_REGION"),
        )

        self._client = session.client("s3")
Example #27
    def handle(self, *args, **options):
        session = Session(
            aws_access_key_id=settings.S3_ACCESS_ID,
            aws_secret_access_key=settings.S3_SECRET_KEY,
            region_name=options["region"],
        )

        to_reprocess = []

        self.stdout.write("Processing bucket: %s" % settings.S3_LOGS_BUCKET)
        self.stdout.write("Downloading S3 manifest...")

        bucket = session.resource("s3").Bucket(name=settings.S3_LOGS_BUCKET)

        self.stdout.write("Analyzing bucket contents...")

        start_date = datetime.datetime.strptime(options["start"], "%Y-%m-%d")
        end_date = datetime.datetime.strptime(options["end"], "%Y-%m-%d")

        hits = 0
        for f in bucket.objects.all():
            hits += 1

            if hits % 500 == 0:
                self.stdout.write(" - Processed %d log listings..." % hits)

            filename = f.key.split("/")[-1]
            if not filename:
                continue

            # Ignore CF files for now
            if filename.endswith(".gz"):
                continue

            datestamp = "-".join(filename.split("-")[:-1])
            parsed_ds = datetime.datetime.strptime(datestamp, "%Y-%m-%d-%H-%M-%S")

            if parsed_ds < start_date or parsed_ds > end_date:
                continue

            to_reprocess.append(f.key)

        self.stdout.write("Finished analysis")
        self.stdout.write("%s logs need to be reprocessed" % len(to_reprocess))

        if not to_reprocess:
            return

        if options["run"]:
            lambda_client = session.client("lambda")
            for f in to_reprocess:
                blob = json.dumps(
                    {"Records": [{"s3": {"bucket": {"name": settings.S3_LOGS_BUCKET}, "object": {"key": f}}}]}
                )
                lambda_client.invoke(FunctionName=options["function"], InvocationType="Event", Payload=blob)
            self.stdout.write("Lambda invoked for each log file. See CloudWatch for output")
        else:
            self.stdout.write("No additional action was performed. Use --run to actually reprocess")
Example #28
 def checkCredentials(self):
     region = raw_input('Region for AWS S3(eg. us-east-1): ').strip()        
     keyid = raw_input('Access Key ID for AWS S3: ').strip()
     keysecret = getpass.getpass('Access Key Secret for AWS S3: ').strip()
     sess = Session(aws_access_key_id=keyid,
                      aws_secret_access_key=keysecret,
                    region_name=region)
     self.s3 = sess.resource('s3')
     self.s3bucket = self.s3.Bucket(self.bucket)
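The method above only builds the resource without exercising the keys; a sketch of an actual check (the helper name is an assumption):

from botocore.exceptions import ClientError

def credentials_work(session, bucket_name):
    # head_bucket fails fast on bad keys, missing permissions, or a missing bucket.
    try:
        session.client('s3').head_bucket(Bucket=bucket_name)
        return True
    except ClientError as exc:
        print('Credential check failed: %s' % exc)
        return False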
Example #29
    def test_get_available_resources(self):
        mock_bc_session = mock.Mock()
        loader = mock.Mock(spec=loaders.Loader)
        loader.list_available_services.return_value = ['foo', 'bar']
        mock_bc_session.get_component.return_value = loader
        session = Session(botocore_session=mock_bc_session)

        names = session.get_available_resources()
        self.assertEqual(names, ['foo', 'bar'])
def setup_s3_client(job_data):
  key_id = job_data['artifactCredentials']['accessKeyId']
  key_secret = job_data['artifactCredentials']['secretAccessKey']
  session_token = job_data['artifactCredentials']['sessionToken']
    
  session = Session(aws_access_key_id=key_id,
    aws_secret_access_key=key_secret,
    aws_session_token=session_token)
  return session.client('s3', config=botocore.client.Config(signature_version='s3v4'))
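A usage sketch pairing the client above with the job's input artifact; the field names follow the CodePipeline job-data layout, and the helper itself is an assumption.

def get_input_artifact_bytes(job_data):
  # Download the first input artifact referenced by the CodePipeline job.
  s3 = setup_s3_client(job_data)
  location = job_data['inputArtifacts'][0]['location']['s3Location']
  obj = s3.get_object(Bucket=location['bucketName'], Key=location['objectKey'])
  return obj['Body'].read()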
Example #31
from boto3.session import Session
import datetime

session = Session(aws_access_key_id='[your_key_id]',
                  aws_secret_access_key='[your_secret_key]')


def delete_from(resource_name, region, your_bucket_name, date):
    resource = session.resource(resource_name, region)
    my_bucket = resource.Bucket(your_bucket_name)
    results = []
    for obj in my_bucket.objects.all():
        if (obj.last_modified).replace(tzinfo=None) < date:
            results.append(
                my_bucket.delete_objects(
                    Delete={'Objects': [{
                        'Key': obj.key
                    }]}))
    return results


delete_from('s3', 'us-west-2', '[YOUR_BUCKET]', datetime.datetime(2019, 12, 1))  # Bucket() only exists on the 's3' resource
Example #32
import os
from boto3.session import Session
import boto3

s3_bucket_name = os.getenv("s3_bucket_name")
user_table_name = os.getenv("user_table_name")
group_table_name = os.getenv("group_table_name")
session = Session(
    aws_access_key_id=os.getenv("aws_access_key_id"),
    aws_secret_access_key=os.getenv("aws_secret_access_key"),
    region_name=os.getenv("aws_region", "ap-southeast-1"),
)

base_url = os.getenv("base_url")

s3_client = session.client("s3")
ddb_client = session.client("dynamodb")
Example #33
        'django.security.*': {
            'handlers': ['console'],
            'level': DEBUG and 'DEBUG' or 'INFO',
        },
        'django.request': {
            'handlers': ['console'],
            'level': DEBUG and 'DEBUG' or 'INFO',
        },
    },
}

# Production logging
if ENV not in ['local', 'test', 'staging', 'preview']:
    # add AWS monitoring
    boto3_session = Session(aws_access_key_id=AWS_ACCESS_KEY_ID,
                            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                            region_name=AWS_DEFAULT_REGION)

    LOGGING['formatters']['cloudwatch'] = {
        'format': '%(hostname)s %(name)-12s [%(levelname)-8s] %(message)s',
    }
    LOGGING['handlers']['watchtower'] = {
        'level': AWS_LOG_LEVEL,
        'class': 'watchtower.django.DjangoCloudWatchLogHandler',
        'boto3_session': boto3_session,
        'log_group': AWS_LOG_GROUP,
        'stream_name': AWS_LOG_STREAM,
        'filters': ['host_filter'],
        'formatter': 'cloudwatch',
    }
    LOGGING['loggers']['django.db.backends']['level'] = AWS_LOG_LEVEL
Example #34
            session = Session(aws_access_key_id=self.ACCESS_KEY_ID,
                              aws_secret_access_key=self.ACCESS_KEY_PASSWD, region_name=self.REGION_NAME)
            s3 = session.resource("s3")
            bucket = s3.Bucket(self.BUCKET_NAME)

            filename = "{source}/{target}/{base_file}".format(source=self.request.form['source'],
                                                             target=list(self.request.form['target'].values()).pop(0),
                                                             base_file=self.request.files[form_name].filename)

            """
            The key naming convention is source/target/filename
            """

            s3_object = bucket.put_object(Key=filename, Body=request.files[form_name])
            end_point = s3_object.meta.client._endpoint.host[s3_object.meta.client._endpoint.host.find("s3"):]
            s3_url = f"https://{self.BUCKET_NAME}.{end_point}/{filename}"
            print(s3_url)
            return s3_url
        except Exception as e:
            print(e)

if __name__ == '__main__':
    s3 = S3Client()
    import boto3
    from boto3.session import Session

    session = Session(aws_access_key_id=s3.ACCESS_KEY_ID,
                      aws_secret_access_key= s3.ACCESS_KEY_PASSWD, region_name= s3.REGION_NAME)
    s3_client = session.resource("s3")
    bucket = s3_client.Bucket(s3.BUCKET_NAME)
    bucket.download_file(Key="tester/01029209599/kira.png", Filename="kira.png")  # local Filename is an assumption
Example #35
def main():
    """Initializes main script from command-line call to generate
    single-subject or multi-subject workflow(s)"""
    import os
    import gc
    import sys
    import json
    from pynets.core.utils import build_args_from_config
    import itertools
    from types import SimpleNamespace
    import pkg_resources
    from pynets.core.utils import flatten
    from pynets.cli.pynets_run import build_workflow
    from multiprocessing import set_start_method, Process, Manager
    from colorama import Fore, Style

    try:
        import pynets
    except ImportError:
        print(
            "PyNets not installed! Ensure that you are referencing the correct"
            " site-packages and using Python3.6+"
        )

    if len(sys.argv) < 2:
        print("\nMissing command-line inputs! See help options with the -h"
              " flag.\n")
        sys.exit(1)

    print(f"{Fore.LIGHTBLUE_EX}\nBIDS API\n")

    print(Style.RESET_ALL)

    print(f"{Fore.LIGHTGREEN_EX}Obtaining Derivatives Layout...")

    print(Style.RESET_ALL)

    modalities = ["func", "dwi"]
    space = 'T1w'

    bids_args = get_bids_parser().parse_args()
    participant_label = bids_args.participant_label
    session_label = bids_args.session_label
    run = bids_args.run_label
    if isinstance(run, list):
        run = str(run[0]).zfill(2)
    modality = bids_args.modality
    bids_config = bids_args.config
    analysis_level = bids_args.analysis_level
    clean = bids_args.clean

    if analysis_level == "group" and participant_label is not None:
        raise ValueError(
            "Error: You have indicated a group analysis level run, but"
            " specified a participant label!"
        )

    if analysis_level == "participant" and participant_label is None:
        raise ValueError(
            "Error: You have indicated a participant analysis level run, but"
            " not specified a participant "
            "label!")

    if bids_config:
        with open(bids_config, "r") as stream:
            arg_dict = json.load(stream)
    else:
        with open(
            pkg_resources.resource_filename("pynets",
                                            "config/bids_config.json"),
            "r",
        ) as stream:
            arg_dict = json.load(stream)
        stream.close()

    # S3
    # Primary inputs
    s3 = bids_args.bids_dir.startswith("s3://")

    if not s3:
        bids_dir = bids_args.bids_dir

    # secondary inputs
    sec_s3_objs = []
    if isinstance(bids_args.ua, list):
        for i in bids_args.ua:
            if i.startswith("s3://"):
                print("Downloading user atlas: ", i, " from S3...")
                sec_s3_objs.append(i)
    if isinstance(bids_args.cm, list):
        for i in bids_args.cm:
            if i.startswith("s3://"):
                print("Downloading clustering mask: ", i, " from S3...")
                sec_s3_objs.append(i)
    if isinstance(bids_args.roi, list):
        for i in bids_args.roi:
            if i.startswith("s3://"):
                print("Downloading ROI mask: ", i, " from S3...")
                sec_s3_objs.append(i)
    if isinstance(bids_args.way, list):
        for i in bids_args.way:
            if i.startswith("s3://"):
                print("Downloading tractography waymask: ", i, " from S3...")
                sec_s3_objs.append(i)

    if bids_args.ref:
        if bids_args.ref.startswith("s3://"):
            print(
                "Downloading atlas labeling reference file: ",
                bids_args.ref,
                " from S3...",
            )
            sec_s3_objs.append(bids_args.ref)

    if s3 or len(sec_s3_objs) > 0:
        from boto3.session import Session
        from pynets.core import cloud_utils
        from pynets.core.utils import as_directory

        home = os.path.expanduser("~")
        creds = bool(cloud_utils.get_credentials())

        if s3:
            buck, remo = cloud_utils.parse_path(bids_args.bids_dir)
            os.makedirs(f"{home}/.pynets", exist_ok=True)
            os.makedirs(f"{home}/.pynets/input", exist_ok=True)
            os.makedirs(f"{home}/.pynets/output", exist_ok=True)
            bids_dir = as_directory(f"{home}/.pynets/input", remove=False)
            if (not creds) and bids_args.push_location:
                raise AttributeError(
                    """No AWS credentials found, but `--push_location` flag
                     called. Pushing will most likely fail.""")
            else:
                output_dir = as_directory(
                    f"{home}/.pynets/output", remove=False)

            # Get S3 input data if needed
            if analysis_level == "participant":
                for partic, ses in list(
                    itertools.product(participant_label, session_label)
                ):
                    if ses is not None:
                        info = "sub-" + partic + "/ses-" + ses
                    elif ses is None:
                        info = "sub-" + partic
                    cloud_utils.s3_get_data(
                        buck, remo, bids_dir, modality, info=info)
            elif analysis_level == "group":
                if len(session_label) > 1 and session_label[0] != "None":
                    for ses in session_label:
                        info = "ses-" + ses
                        cloud_utils.s3_get_data(
                            buck, remo, bids_dir, modality, info=info
                        )
                else:
                    cloud_utils.s3_get_data(buck, remo, bids_dir, modality)

        if len(sec_s3_objs) > 0:
            [access_key, secret_key] = cloud_utils.get_credentials()

            session = Session(
                aws_access_key_id=access_key, aws_secret_access_key=secret_key
            )

            s3_r = session.resource("s3")
            s3_c = cloud_utils.s3_client(service="s3")
            sec_dir = as_directory(
                home + "/.pynets/secondary_files", remove=False)
            for s3_obj in [i for i in sec_s3_objs if i is not None]:
                buck, remo = cloud_utils.parse_path(s3_obj)
                s3_c.download_file(
                    buck, remo, f"{sec_dir}/{os.path.basename(s3_obj)}")

            if isinstance(bids_args.ua, list):
                local_ua = bids_args.ua.copy()
                for i in local_ua:
                    if i.startswith("s3://"):
                        local_ua[local_ua.index(
                            i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.ua = local_ua
            if isinstance(bids_args.cm, list):
                local_cm = bids_args.cm.copy()
                for i in bids_args.cm:
                    if i.startswith("s3://"):
                        local_cm[local_cm.index(
                            i)] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.cm = local_cm
            if isinstance(bids_args.roi, list):
                local_roi = bids_args.roi.copy()
                for i in bids_args.roi:
                    if i.startswith("s3://"):
                        local_roi[
                            local_roi.index(i)
                        ] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.roi = local_roi
            if isinstance(bids_args.way, list):
                local_way = bids_args.way.copy()
                for i in bids_args.way:
                    if i.startswith("s3://"):
                        local_way[
                            local_way.index(i)
                        ] = f"{sec_dir}/{os.path.basename(i)}"
                bids_args.way = local_way

            if bids_args.ref:
                if bids_args.ref.startswith("s3://"):
                    bids_args.ref = f"{sec_dir}/" \
                                    f"{os.path.basename(bids_args.ref)}"
    else:
        output_dir = bids_args.output_dir
        if output_dir is None:
            raise ValueError("Must specify an output directory")

    intermodal_dict = {
        k: []
        for k in [
            "funcs",
            "confs",
            "dwis",
            "bvals",
            "bvecs",
            "anats",
            "masks",
            "subjs",
            "seshs",
        ]
    }
    if analysis_level == "group":
        if len(modality) > 1:
            i = 0
            for mod_ in modality:
                outs = sweep_directory(
                    bids_dir,
                    modality=mod_,
                    space=space,
                    sesh=session_label,
                    run=run
                )
                if mod_ == "func":
                    if i == 0:
                        funcs, confs, _, _, _, anats, masks, subjs, seshs =\
                            outs
                    else:
                        funcs, confs, _, _, _, _, _, _, _ = outs
                    intermodal_dict["funcs"].append(funcs)
                    intermodal_dict["confs"].append(confs)
                elif mod_ == "dwi":
                    if i == 0:
                        _, _, dwis, bvals, bvecs, anats, masks, subjs, seshs =\
                            outs
                    else:
                        _, _, dwis, bvals, bvecs, _, _, _, _ = outs
                    intermodal_dict["dwis"].append(dwis)
                    intermodal_dict["bvals"].append(bvals)
                    intermodal_dict["bvecs"].append(bvecs)
                intermodal_dict["anats"].append(anats)
                intermodal_dict["masks"].append(masks)
                intermodal_dict["subjs"].append(subjs)
                intermodal_dict["seshs"].append(seshs)
                i += 1
        else:
            intermodal_dict = None
            outs = sweep_directory(
                bids_dir,
                modality=modality[0],
                space=space,
                sesh=session_label,
                run=run
            )
            funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
    elif analysis_level == "participant":
        if len(modality) > 1:
            i = 0
            for mod_ in modality:
                outs = sweep_directory(
                    bids_dir,
                    modality=mod_,
                    space=space,
                    subj=participant_label,
                    sesh=session_label,
                    run=run
                )
                if mod_ == "func":
                    if i == 0:
                        funcs, confs, _, _, _, anats, masks, subjs, seshs =\
                            outs
                    else:
                        funcs, confs, _, _, _, _, _, _, _ = outs
                    intermodal_dict["funcs"].append(funcs)
                    intermodal_dict["confs"].append(confs)
                elif mod_ == "dwi":
                    if i == 0:
                        _, _, dwis, bvals, bvecs, anats, masks, subjs, seshs =\
                            outs
                    else:
                        _, _, dwis, bvals, bvecs, _, _, _, _ = outs
                    intermodal_dict["dwis"].append(dwis)
                    intermodal_dict["bvals"].append(bvals)
                    intermodal_dict["bvecs"].append(bvecs)
                intermodal_dict["anats"].append(anats)
                intermodal_dict["masks"].append(masks)
                intermodal_dict["subjs"].append(subjs)
                intermodal_dict["seshs"].append(seshs)
                i += 1
        else:
            intermodal_dict = None
            outs = sweep_directory(
                bids_dir,
                modality=modality[0],
                space=space,
                subj=participant_label,
                sesh=session_label,
                run=run
            )
            funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = outs
    else:
        raise ValueError(
            "Analysis level invalid. Must be `participant` or `group`. See"
            " --help."
        )

    if intermodal_dict:
        funcs, confs, dwis, bvals, bvecs, anats, masks, subjs, seshs = [
            list(set(list(flatten(i)))) for i in intermodal_dict.values()
        ]

    args_dict_all = build_args_from_config(modality, arg_dict)

    id_list = []
    for i in sorted(list(set(subjs))):
        for ses in sorted(list(set(seshs))):
            id_list.append(i + "_" + ses)

    args_dict_all["work"] = bids_args.work
    args_dict_all["output_dir"] = output_dir
    args_dict_all["plug"] = bids_args.plug
    args_dict_all["pm"] = bids_args.pm
    args_dict_all["v"] = bids_args.v
    args_dict_all["clean"] = bids_args.clean
    if funcs is not None:
        args_dict_all["func"] = sorted(funcs)
    else:
        args_dict_all["func"] = None
    if confs is not None:
        args_dict_all["conf"] = sorted(confs)
    else:
        args_dict_all["conf"] = None
    if dwis is not None:
        args_dict_all["dwi"] = sorted(dwis)
        args_dict_all["bval"] = sorted(bvals)
        args_dict_all["bvec"] = sorted(bvecs)
    else:
        args_dict_all["dwi"] = None
        args_dict_all["bval"] = None
        args_dict_all["bvec"] = None
    if anats is not None:
        args_dict_all["anat"] = sorted(anats)
    else:
        args_dict_all["anat"] = None
    if masks is not None:
        args_dict_all["m"] = sorted(masks)
    else:
        args_dict_all["m"] = None
    args_dict_all["g"] = None
    if ("dwi" in modality) and (bids_args.way is not None):
        args_dict_all["way"] = bids_args.way
    else:
        args_dict_all["way"] = None
    args_dict_all["id"] = id_list
    args_dict_all["ua"] = bids_args.ua
    args_dict_all["ref"] = bids_args.ref
    args_dict_all["roi"] = bids_args.roi
    if ("func" in modality) and (bids_args.cm is not None):
        args_dict_all["cm"] = bids_args.cm
    else:
        args_dict_all["cm"] = None

    # Mimic argparse with SimpleNamespace object
    args = SimpleNamespace(**args_dict_all)
    print(args)

    set_start_method("forkserver")
    with Manager() as mgr:
        retval = mgr.dict()
        p = Process(target=build_workflow, args=(args, retval))
        p.start()
        p.join()
        if p.is_alive():
            p.terminate()

        retcode = p.exitcode or retval.get("return_code", 0)

        pynets_wf = retval.get("workflow", None)
        work_dir = retval.get("work_dir")
        plugin_settings = retval.get("plugin_settings", None)
        execution_dict = retval.get("execution_dict", None)
        run_uuid = retval.get("run_uuid", None)

        retcode = retcode or int(pynets_wf is None)
        if retcode != 0:
            sys.exit(retcode)

        # Clean up master process before running workflow, which may create
        # forks
        gc.collect()

    mgr.shutdown()

    if bids_args.push_location:
        print(f"Pushing to s3 at {bids_args.push_location}.")
        push_buck, push_remo = cloud_utils.parse_path(bids_args.push_location)
        for id in id_list:
            cloud_utils.s3_push_data(
                push_buck,
                push_remo,
                output_dir,
                modality,
                subject=id.split("_")[0],
                session=id.split("_")[1],
                creds=creds,
            )

    sys.exit(0)

    return
Example #36
from boto3.session import Session
import logging
import json
import os

# Update the root logger to get messages at DEBUG and above
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger("botocore").setLevel(logging.CRITICAL)
logging.getLogger("boto3").setLevel(logging.CRITICAL)
logging.getLogger("urllib3.connectionpool").setLevel(logging.CRITICAL)

DATA_BUCKET = os.getenv("DATA_BUCKET")

#s3 = boto3.client("s3")
s3 = Session(profile_name="default").client("s3")


def write_to_s3(data,
                s3_key,
                bucket=DATA_BUCKET,
                encrypt_mode=None,
                kms_key=None):
    """
    :param encrypt_mode: None, AES256, or aws:kms
    :param kms_key: None or arn or alias arn
    """

    # Athena requires output each item to separate line
    # No outer bracket, no commas at the end of the line
    # {"k1": "v1"}
    # {"k1": "v2"}
Example #37
def create_app(test_config=None):
    # create and configure the app
    app = Flask(__name__, instance_relative_config=True)
    app.config.from_mapping(
        SECRET_KEY='dev',
        DATABASE=os.path.join(app.instance_path, 'flaskr.sqlite'),
    )

    boto_sess = Session(
        region_name='us-east-2',
        aws_access_key_id='AKIA2XEGOWGBGB36SBA2',
        aws_secret_access_key='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')

    app.config['DYNAMO_SESSION'] = boto_sess
    dynamo = Dynamo()
    dynamo.init_app(app)
    beerReviews = dynamo.tables['beerReviews']
    wineReviews = dynamo.tables['wineReviews']

    if test_config is None:
        # load the instance config, if it exists, when not testing
        app.config.from_pyfile('config.py', silent=True)
    else:
        # load the test config if passed in
        app.config.from_mapping(test_config)

    # ensure the instance folder exists
    try:
        os.makedirs(app.instance_path)
    except OSError:
        pass

    # a simple page that says hello
    @app.route("/")
    def home():
        return render_template("index.html")

    @app.route("/addBeer")
    def addBeer():
        return render_template("addBeer.html")

    @app.route("/searchBeer")
    def searchBeer():
        return render_template("searchBeer.html", data=[])

    @app.route("/addWine")
    def addWine():
        return render_template("addWine.html")

    @app.route("/searchWine")
    def searchWine():
        return render_template("searchWine.html", data=[])

    @app.route('/wines')
    def get_wines():
        query = '''SELECT * FROM Wine ORDER BY RANDOM() LIMIT 500'''
        wines = queryDB(query)
        if len(wines) == 0:
            abort(404)
        return jsonify({'success': True, 'wines': wines})

    @app.route('/beers')
    def get_beers():
        query = '''SELECT * FROM Beer ORDER BY RANDOM() LIMIT 500'''
        beers = queryDB(query)
        if len(beers) == 0:
            abort(404)
        return jsonify({'success': True, 'beers': beers})

    @app.route('/beers/search', methods=['POST'])
    def search_beers():
        searchName = request.form['name']
        if len(searchName) == 0:
            query = '''SELECT beer.id
                            , beer.brewery_id
                            , beer.name
                            , beer.style
                            , beer.abv
                            , brewery.name
                            , brewery.city
                            , brewery.state
                       FROM Beer beer, Brewery brewery
                       WHERE beer.brewery_id = brewery.id
                       ORDER BY RANDOM()
                       LIMIT 500'''
            beers = queryDB(query)
        else:
            query = '''SELECT beer.id
                            , beer.brewery_id
                            , beer.name
                            , beer.style
                            , beer.abv
                            , brewery.name
                            , brewery.city
                            , brewery.state
                       FROM Beer beer, Brewery brewery 
                       WHERE beer.brewery_id = brewery.id
                         AND beer.name LIKE ? 
                       ORDER BY RANDOM()
                       LIMIT 500'''
            query_params = ('%' + searchName + '%', )
            beers = queryDB(query, query_params)

        if len(beers) == 0:
            abort(404)

        return render_template("searchBeer.html", data=beers)

    @app.route('/wines/search', methods=['POST'])
    def search_wines():
        searchName = request.form['name']
        if len(searchName) == 0:
            query = '''SELECT wine.id
                            , wine.winery_id
                            , wine.name
                            , wine.variety
                            , wine.rating
                            , winery.name 
                            , winery.city
                            , winery.state
                        FROM Wine wine, Winery winery
                        WHERE wine.winery_id = winery.id
                        ORDER BY RANDOM()
                        LIMIT 500'''
            wines = queryDB(query)
        else:
            query = '''SELECT wine.id
                            , wine.winery_id
                            , wine.name
                            , wine.variety
                            , wine.rating
                            , winery.name 
                            , winery.city
                            , winery.state
                    FROM Wine wine, Winery winery
                    WHERE wine.winery_id = winery.id
                      AND wine.name LIKE ?
                    ORDER BY RANDOM()
                    LIMIT 500'''
            query_params = ('%' + searchName + '%', )
            wines = queryDB(query, query_params)

        if len(wines) == 0:
            abort(404)

        return render_template("searchWine.html", data=wines)

    @app.route('/beers/<int:beer_id>', methods=['GET'])
    def get_beer(beer_id):
        beer = queryBeerByID(beer_id)
        if len(beer) == 0:
            abort(404)
        return render_template("updateBeer.html", data=beer_id)

    @app.route('/wines/<int:wine_id>', methods=['GET'])
    def get_wine(wine_id):
        wine = queryWineByID(wine_id)
        if len(wine) == 0:
            abort(404)
        return render_template("updateWine.html", data=wine_id)

    @app.route('/deleteWine/<int:wine_id>', methods=['POST'])
    def delete_wine(wine_id):
        query = '''DELETE FROM wine WHERE id = ?'''
        query_params = (wine_id, )
        try:
            queryDB(query, query_params)
            return jsonify({'success': True, 'deleted': wine_id})

        except Exception:
            abort(422)

    @app.route('/deleteBeer/<int:beer_id>', methods=['POST'])
    def delete_beer(beer_id):
        query = '''DELETE FROM beer WHERE id = ?'''
        query_params = (beer_id, )
        try:
            queryDB(query, query_params)
            return jsonify({'success': True, 'deleted': beer_id})

        except Exception:
            abort(422)

    @app.route('/wines', methods=['POST'])
    def create_wine():
        data = request.form
        name = data.get('name') or None
        variety = data.get('variety') or None
        winery_name = data.get('winery') or None
        rating = data.get('rating') or None
        city = data.get('city') or None
        state = data.get('state') or None

        try:
            existing_wine = '''SELECT *
                                FROM wine
                                WHERE name = ?
                                AND variety = ?
                                '''
            wine_params = (
                name,
                variety,
            )
            existing_data = queryDB(existing_wine, wine_params)

            if len(existing_data) == 0:
                existing_winery = '''SELECT id
                                FROM winery
                                WHERE wname = ?
                                    AND city = ?
                                    AND state = ?
                                '''
                winery_params = (
                    winery_name,
                    city,
                    state,
                )
                existing_data = queryDB(existing_winery, winery_params)
                if len(existing_data) > 0:
                    winery_id = existing_data[0][0]
                else:
                    id_query = '''SELECT MAX(id) AS max_id FROM winery'''
                    winery_id = queryDB(id_query)[0][0] + 1

                    query = '''INSERT INTO winery (id, name, city, state, country)
                                            VALUES (?,?,?,?,?)
                                            '''
                    params = (
                        winery_id,
                        winery_name,
                        city,
                        state,
                        'US',
                    )
                    queryDB(query, params)

                id_query = '''SELECT MAX(id) AS max_id FROM wine'''
                new_wine_id = queryDB(id_query)[0][0] + 1
                query = '''INSERT INTO wine (id, name, variety, rating, winery_id)
                                    VALUES (?,?,?,?,?)
                                    '''
                params = (
                    new_wine_id,
                    name,
                    variety,
                    rating,
                    winery_id,
                )
                queryDB(query, params)

            return jsonify({'success': True, 'created': new_wine_id})
        except Exception:
            abort(422)

    @app.route('/beers', methods=['POST'])
    def create_beer():
        data = request.form
        name = data.get('name') or None
        style = data.get('style') or None
        brewery_name = data.get('brewery') or None
        abv = data.get('alcoholContent') or None
        city = data.get('city') or None
        state = data.get('state') or None

        try:
            existing_beer = '''SELECT *
                                FROM beer
                                WHERE name = ?
                                AND style = ?
                                '''
            beer_params = (
                name,
                style,
            )
            existing_data = queryDB(existing_beer, beer_params)

            if len(existing_data) == 0:
                existing_brewery = '''SELECT id
                                FROM brewery
                                WHERE name = ?
                                    AND city = ?
                                    AND state = ?
                                '''
                brewery_params = (
                    brewery_name,
                    city,
                    state,
                )
                existing_data = queryDB(existing_brewery, brewery_params)
                if len(existing_data) > 0:
                    brewery_id = existing_data[0][0]
                else:
                    id_query = '''SELECT MAX(id) AS max_id FROM brewery'''
                    brewery_id = queryDB(id_query)[0][0] + 1

                    query = '''INSERT INTO brewery (id, name, city, state, country)
                                            VALUES (?,?,?,?,?)
                                            '''
                    params = (
                        brewery_id,
                        brewery_name,
                        city,
                        state,
                        'US',
                    )
                    queryDB(query, params)

                id_query = '''SELECT MAX(id) AS max_id FROM beer'''
                new_beer_id = queryDB(id_query)[0][0] + 1
                query = '''INSERT INTO beer (id, name, style, abv, brewery_id)
                                    VALUES (?,?,?,?,?)
                                    '''
                params = (
                    new_beer_id,
                    name,
                    style,
                    abv,
                    brewery_id,
                )
                queryDB(query, params)

            else:
                # same as for wines: reuse the existing beer's id
                new_beer_id = existing_data[0][0]

            return jsonify({'success': True, 'created': new_beer_id})
        except Exception:
            abort(422)

    @app.route('/wines/<int:wine_id>', methods=['POST'])
    def edit_wine(wine_id):
        existing_data = queryWineByID(wine_id)
        if len(existing_data) == 0:
            abort(404)
        data = request.form

        name = data.get('name')
        variety = data.get('variety')
        rating = data.get('rating')
        if name or variety or rating:
            if not name:
                name = existing_data[0][0]
            if not variety:
                variety = existing_data[0][1]
            if not rating:
                rating = existing_data[0][2]
            query = '''UPDATE wine SET name = ?, 
                            variety = ?, 
                            rating = ? 
                            WHERE id = ?'''
            params = (
                name,
                variety,
                rating,
                wine_id,
            )
            queryDB(query, params)

        winery_id = existing_data[0][7]
        winery_name = data.get('winery_name')
        city = data.get('city')
        state = data.get('state')
        if winery_name or city or state:
            if not winery_name:
                winery_name = existing_data[0][3]
            if not city:
                city = existing_data[0][4]
            if not state:
                state = existing_data[0][5]
            query = '''UPDATE winery SET name = ?, 
                                city = ?, 
                                state = ? 
                                WHERE id = ?'''
            params = (
                winery_name,
                city,
                state,
                winery_id,
            )
            queryDB(query, params)

        return jsonify({'success': True, 'updated': winery_id})

    @app.route('/beers/<int:beer_id>', methods=['POST'])
    def edit_beer(beer_id):
        existing_data = queryBeerByID(beer_id)
        if len(existing_data) == 0:
            abort(404)
        data = request.form

        try:
            name = data.get('name')
            style = data.get('style')
            abv = data.get('alcoholContent')
            if name or style or abv:
                if not name:
                    name = existing_data[0][0]
                if not style:
                    style = existing_data[0][1]
                if not abv:
                    abv = existing_data[0][2]
                query = '''UPDATE beer 
                           SET name = ?, 
                                style = ?, 
                                abv = ? 
                                WHERE id = ?'''
                params = (
                    name,
                    style,
                    abv,
                    beer_id,
                )
                queryDB(query, params)

            brewery_id = existing_data[0][7]
            brewery_name = data.get('brewery_name')
            city = data.get('city')
            state = data.get('state')
            if brewery_name or city or state:
                if not brewery_name:
                    brewery_name = existing_data[0][3]
                if not city:
                    city = existing_data[0][4]
                if not state:
                    state = existing_data[0][5]
                query = '''UPDATE brewery SET name = ?, 
                                city = ?, 
                                state = ? 
                                WHERE id = ?'''
                params = (
                    brewery_name,
                    city,
                    state,
                    brewery_id,
                )
                queryDB(query, params)

            return jsonify({'success': True, 'updated': beer_id})
        except Exception:
            abort(422)

    @app.route('/localbeers')
    def localBeers():
        query = '''SELECT *
                   FROM local_beers
                   LIMIT 500'''
        local_beers = queryDB(query)
        if len(local_beers) == 0:
            abort(404)

        return render_template("localBeers.html", data=local_beers)

    @app.route('/heatmap')
    def heatMap():
        map = folium.Map(location=[38, -98], zoom_start=5)
        locations = getLocations()
        address_latlng = []
        for location in locations:
            address = geocoder.osm(location[0] + ', ' + location[1])
            if address.lat and address.lng:
                address_latlng.append([address.lat, address.lng])
        HeatMap(address_latlng).add_to(map)
        return map._repr_html_()

    @app.route('/beers/<int:beer_id>/reviews', methods=['GET'])
    def get_beer_reviews(beer_id):
        query = '''SELECT * FROM beer 
                        WHERE id = ?'''
        param = (beer_id, )
        name = queryDB(query, param)[0][1]
        reviews = beerReviews.query(
            KeyConditionExpression=Key('beer_id').eq(beer_id))['Items']
        return render_template("beerReviews.html",
                               data={
                                   'id': beer_id,
                                   'name': name,
                                   'reviews': reviews
                               })

    @app.route('/beers/<int:beer_id>/reviews/add')
    def addBeerReview(beer_id):
        return render_template("addBeerReviews.html", id=beer_id)

    @app.route('/beers/<int:beer_id>/reviews', methods=['POST'])
    def add_beer_review(beer_id):
        data = request.form
        review = {}
        review['beer_id'] = beer_id
        review['username'] = data.get('username')
        review['date'] = date.today().strftime("%Y-%m-%d")
        attributes = ['text', 'taste', 'smell', 'look', 'feel', 'overall']
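        # Only the attributes actually supplied in the form are copied onto
        # the item; DynamoDB items are schemaless, so omitted fields are
        # simply not stored rather than written as empty strings.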
        for attribute in attributes:
            if data.get(attribute):
                review[attribute] = data.get(attribute)
        try:
            beerReviews.put_item(Item=review)
            return jsonify({'success': True})

        except Exception:
            abort(422)

    @app.route('/wines/<int:wine_id>/reviews', methods=['GET'])
    def get_wine_reviews(wine_id):
        query = '''SELECT * FROM wine 
                        WHERE id = ?'''
        param = (wine_id, )
        name = queryDB(query, param)[0][1]
        reviews = wineReviews.query(
            KeyConditionExpression=Key('wine_id').eq(wine_id))['Items']
        return render_template("wineReviews.html",
                               data={
                                   'id': wine_id,
                                   'name': name,
                                   'reviews': reviews
                               })

    @app.route('/wines/<int:wine_id>/reviews/add')
    def addWineReview(wine_id):
        return render_template("addWineReviews.html", id=wine_id)

    @app.route('/wines/<int:wine_id>/reviews', methods=['POST'])
    def add_wine_review(wine_id):
        data = request.form
        review = {}
        review['wine_id'] = wine_id
        review['reviewers'] = data.get('username')
        review['date'] = date.today().strftime("%Y-%m-%d")
        attributes = ['text', 'taste', 'smell', 'look', 'feel', 'overall']
        for attribute in attributes:
            if data.get(attribute):
                review[attribute] = data.get(attribute)
        try:
            wineReviews.put_item(Item=review)
            return jsonify({'success': True})

        except Exception:
            abort(422)

    def getLocations():
        query = '''SELECT city
                        , state
                    FROM brewery
                    UNION ALL
                    SELECT city
                        , state
                    FROM winery
                    LIMIT 100'''
        locations = queryDB(query)
        return locations

    def queryDB(query, params=None):
        with sqlite3.connect(DATABASE) as conn:
            result = []
            try:
                cursor = conn.cursor()
                if params:
                    cursor.execute(query, params)
                else:
                    cursor.execute(query)
                result = cursor.fetchall()
                cursor.close()
            except Exception:
                conn.rollback()
            return result

    def queryWineByID(wine_id):
        query = '''SELECT wine.name
                        , wine.variety
                        , wine.rating
                        , winery.name  AS winery_name 
                        , winery.city
                        , winery.state
                        , wine.id
                        , wine.winery_id
                FROM Wine wine, Winery winery
                WHERE wine.winery_id = winery.id 
                    AND wine.id = ?'''
        return queryDB(query, (wine_id, ))

    def queryBeerByID(beer_id):
        query = '''SELECT beer.name
                        , beer.style
                        , beer.abv
                        , brewery.name   AS brewery_name
                        , brewery.city
                        , brewery.state
                        , beer.id
                        , beer.brewery_id
                FROM Beer beer, Brewery brewery 
                WHERE beer.brewery_id = brewery.id 
                    AND beer.id = ?'''
        return queryDB(query, (beer_id, ))

    @app.errorhandler(404)
    def not_found(error):
        return jsonify({
            "success": False,
            "error": 404,
            "message": "resource not found"
        }), 404

    @app.errorhandler(422)
    def unprocessable(error):
        return jsonify({
            "success": False,
            "error": 422,
            "message": "unprocessable"
        }), 422

    @app.errorhandler(400)
    def bad_request(error):
        return jsonify({
            "success": False,
            "error": 400,
            "message": "bad request"
        }), 400

    @app.errorhandler(500)
    def server_error(error):
        return jsonify({
            "success": False,
            "error": 500,
            "message": "internal server error"
        }), 500

    from . import db
    db.init_app(app)

    return app
Example #38
0
class Cleaner:
    def __init__(self, config):
        self.config = config
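        # Expected config keys (inferred from usage below): 'region_names',
        # 'aws_access_key_id', 'aws_secret_access_key', 'aws_account_id'
        # and 'protected_users'.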
                            
    def clean_all(self):
        print('Cleaning region specific resources...')
        for region in self.config['region_names']:
            print('\nRegion: ' + region)
            self.session = Session(
                aws_access_key_id=self.config['aws_access_key_id'],
                aws_secret_access_key=self.config['aws_secret_access_key'],
                region_name=region
            )
            self.clean_ec2()
        print('Cleaning region agnostic resources...')
        self.clean_iam()
        return True
        
    def clean_ec2(self):
        print('START ec2 clean')
        ec2 = self.session.resource('ec2')
        # clean instances
        for instance in ec2.instances.all():
            print('Terminating instance: ' + instance.id)
            try:
                instance.terminate(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print('Unable to terminate instance: ' + instance.id)
                print(e.response['Error']['Code'])
            except:
                print("Unexpected error:", sys.exc_info()[0])
        
        # clean keypairs
        for keypair in ec2.key_pairs.all():
            print('Deleting keypair: ' + keypair.name)
            try:
                keypair.delete(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print('Unable to delete keypair: ' + keypair.name)
                print(e.response['Error']['Code'])
            except:
                print("Unexpected error:", sys.exc_info()[0])
        
        # clean volumes
        for volume in ec2.volumes.all():
            print('Deleting volume: ' + volume.id)
            try:
                volume.delete(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print(e.response['Error']['Code'])
                print('Unable to delete volume: ' + volume.id)
            except:
                print("Unexpected error:", sys.exc_info()[0])
        
        # clean images
        for image in ec2.images.filter(Owners=[self.config['aws_account_id']]):
            print('Deregistering images: ' + image.id)
            try:
                image.deregister(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print(e.response['Error']['Code'])
                print('Unable to delete volume: ' + image.id)
            except:
                print("Unexpected error:", sys.exc_info()[0])
                
        # clean snapshots
        filters=[
            {
                'Name':'owner-id',
                'Values':[self.config['aws_account_id']]
            }
        ]
        for snapshot in ec2.snapshots.filter(Filters=filters):
            print('Deleting snapshot: ' + snapshot.id)
            try:
                snapshot.delete(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print(e.response['Error']['Code'])
                print('Unable to delete snapshot: ' + snapshot.id)
            except:
                print("Unexpected error:", sys.exc_info()[0]) 
        
        # clean security groups
        for security_group in ec2.security_groups.all():
            if security_group.group_name == 'default':
                continue
            print('Deleting security group: ' + security_group.id)
            for ip_permission in security_group.ip_permissions:
                print('Deleting ingress rule: ' + ip_permission['IpProtocol'])
                security_group.revoke_ingress(IpPermissions=[ip_permission])
            for ip_permission in security_group.ip_permissions_egress:
                print('Deleting egress rule: ' + ip_permission['IpProtocol'])
                security_group.revoke_egress(IpPermissions=[ip_permission])
            try:
                security_group.delete(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print(e.response['Error']['Code'])
                print('Unable to delete security group: ' + security_group.id)
            except:
                print("Unexpected error:", sys.exc_info()[0]) 
        
        # clean elastic ips
        print('Checking VPC addresses')
        for vpc_address in ec2.vpc_addresses.filter(PublicIps=[]):
            print('Releasing elastic ip: ' + vpc_address.public_ip)
            try:
                vpc_address.release(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print(e.response['Error']['Code'])
                print('Unable to release elastic ip: ' + vpc_address.public_ip)
            except:
                print("Unexpected error:", sys.exc_info()[0]) 
        
        print('Checking classic addresses')
        for classic_address in ec2.classic_addresses.filter(PublicIps=[]):
            print('Releasing elastic ip: ' + classic_address.public_ip)        
            try:
                classic_address.release(DryRun=False)
            except botocore.exceptions.ClientError as e:
                print(e.response['Error']['Code'])
                print('Unable to release elastic ip: ' + classic_address.public_ip)
            except:
                print("Unexpected error:", sys.exc_info()[0])
                
        print('END ec2 clean')
        return True

    def clean_iam(self):
        print('START iam clean')
        client = self.session.client('iam')
        iam = self.session.resource('iam')
        for user in iam.users.all():
            print(user)
            if self.config['protected_users'].count(user.name) > 0:
                print('Skipping protected user: ' + user.name)
                continue
            self.delete_user(user)
        print('END iam clean')
        return True
    
    def delete_user(self, user):
        #remove groups
        for group in user.groups.all():
           user.remove_group(GroupName=group.name)
        #remove keys
        for access_key in user.access_keys.all():
            access_key.delete()
        #remove signing certs
        for signing_certificate in user.signing_certificates.all():
            signing_certificate.delete()
        #remove inline policies
        for policy in user.policies.all():
            policy.delete()
        #remove attached policies
        for policy in user.attached_policies.all():
            user.detach_policy(PolicyArn=policy.arn)
        #delete login_profile
        try:
            user.LoginProfile().delete()
        except botocore.exceptions.ClientError as e:
            print(e.response['Error']['Code'])
            print('Unable to delete login profile: ' + user.name)
        except:
            print("Unexpected error:", sys.exc_info()[0])
        #finally delete user
        user.delete()
        return True
Example #39
0
 def get_client(self, session=None):
     if not session:
         session = Session()
     return session.client('lambda')
Example #40
0
VERBOSE_FORMATTING = (
    "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d "
    "%(task_id)s %(task_parent_id)s %(task_root_id)s "
    "%(message)s")
SIMPLE_FORMATTING = "[%(asctime)s] %(levelname)s %(task_root_id)s %(message)s"

LOG_DIRECTORY = os.getenv("LOG_DIRECTORY", BASE_DIR)
DEFAULT_LOG_FILE = os.path.join(LOG_DIRECTORY, "app.log")
LOGGING_FILE = os.getenv("DJANGO_LOG_FILE", DEFAULT_LOG_FILE)

if CW_AWS_ACCESS_KEY_ID:
    try:
        POD_NAME = ENVIRONMENT.get_value("APP_POD_NAME", default="local")
        BOTO3_SESSION = Session(
            aws_access_key_id=CW_AWS_ACCESS_KEY_ID,
            aws_secret_access_key=CW_AWS_SECRET_ACCESS_KEY,
            region_name=CW_AWS_REGION,
        )
        watchtower = BOTO3_SESSION.client("logs")
        watchtower.create_log_stream(logGroupName=CW_LOG_GROUP,
                                     logStreamName=POD_NAME)
        LOGGING_HANDLERS += ["watchtower"]
        WATCHTOWER_HANDLER = {
            "level": KOKU_LOGGING_LEVEL,
            "class": "watchtower.CloudWatchLogHandler",
            "boto3_session": BOTO3_SESSION,
            "log_group": CW_LOG_GROUP,
            "stream_name": POD_NAME,
            "formatter": LOGGING_FORMATTER,
            "use_queues": False,
            "create_log_group": False,
Example #41
0
    def __init__(
        self,
        database_alias,
        db_connection_params,
        destination_table_name,
        destination_schema_name,
        source_table_name,
        index_schema,
        index_table,
        query_file=None,
        distribution_key=None,
        sort_keys=None,
        index_column=None,
        index_sql=None,
        truncate_file=None,
        columns_to_drop=None,
        type_map=None,
        primary_key=None,
        not_null_date=False,
        full_refresh=False,
        rebuild=False,
        schema_file=None,
        data=None,
        append_only=False,
        db_template_data=None,
        include_comments=True,
    ):
        self.database_alias = database_alias
        self.db_host = db_connection_params.host
        self.db_port = db_connection_params.port
        self.db_name = db_connection_params.name
        self.db_user = db_connection_params.user
        self.db_password = db_connection_params.password
        self.db_additional_parameters = db_connection_params.additional
        self._columns = None
        self.columns_to_drop = columns_to_drop or []
        self.destination_table_name = destination_table_name
        self.destination_schema_name = destination_schema_name
        self.query_file = query_file
        self.schema_file = schema_file
        self.distribution_key = distribution_key
        self.sort_keys = sort_keys
        self.source_table_name = source_table_name
        self.index_column = index_column
        self.index_sql = index_sql
        self.truncate_file = truncate_file
        self.append_only = append_only
        self.full_refresh = full_refresh
        self.rebuild = rebuild
        self.type_map = self._clean_type_map(type_map)
        self.primary_key = (
            [primary_key] if isinstance(primary_key, str) else primary_key
        )
        self.not_null_date = not_null_date
        self.foreign_keys = []
        self.pks = []
        self._old_index_value = "notset"
        self._new_index_value = "notset"
        self._destination_table_status = None
        self.table_template_data = data
        self.db_template_data = db_template_data
        self.index_schema = index_schema
        self.index_table = index_table
        self.include_comments = include_comments

        self.date_key = datetime.datetime.strftime(
            datetime.datetime.utcnow(), "%Y%m%dT%H%M%S"
        )

        self.row_count = None
        self.upload_size = 0

        self.starttime = None
        self.endtime = None
        self.rows_inserted = None
        self.rows_deleted = None

        self.num_data_files = 0
        self.manifest_mode = False

        self.logger = logging.getLogger(f"druzhba.{database_alias}.{source_table_name}")
        self.s3 = Session().client("s3")
Example #42
0
class TableConfig(object):
    """Base class for a specific table. This class will have all methods
    that are engine agnostic--methods that only act with the host server
    or the data warehouse. All methods that interact with a specific engine
    should raise a NotImplementedError here and be overwritten with a
    subclass.

    Parameters
    ----------
    database_alias : str
        Config file name of the database
    db_connection_params : db.ConnectionParams
        DB connection parameters derived from a parsed connection string
    destination_table_name : str
        The name of the table where data should be loaded in the data
        warehouse
    destination_schema_name : str
        The name of the schema where data should be loaded in the data
        warehouse. Note: this should be "public" unless there's a good
        reason to segregate data.
    source_table_name : str
        name of the table in the source database. If a query_file is provided,
        this is purely for logging and monitoring
    query_file : str, optional
        path of a fully qualified SQL file that contains all the logic
        needed to extract a table.
    columns_to_drop : list of str, optional
        Defined by the YAML file, a list of columns that should not be
        imported into the data warehouse
    distribution_key : str, optional
        destination column name to use as the Redshift distkey
    sort_keys : list of str, optional
        destination column names to use as the Redshift sortkeys
    append_only : bool, optional
        Indicates that rows should only be inserted to this table and
        never updated or deleted. If True, primary_key has no effect.
        Default: False
    primary_key : str or list(str), optional
        Columns used to match records when updating the destination table. If
        not provided, primary keys are inferred from the source table. Has no
        effect if append_only is True
        Default: None
    index_column : str, optional
        Column used to identify new or updated rows in the source table.
        Persisted in the index table.
    index_sql : str, optional
        Custom SQL to be run against the source DB to find the current
        max index. Standard templating for the table applies and Druzhba expects
        the index to be returned in a column called `index_value`. Overrides
        index_column.
    truncate_file : str, optional
        Path to a fully qualified SQL file that contains logic to truncate
        the table when full-refreshing. Required to pass --full-refresh if a
        table is defined by a query_file.
    full_refresh : boolean, optional
        flag that forces a full deletion of a table prior to loading new data,
        rather than deleting only on matched PKs.  Setting True will conflict
        with index_column.
    rebuild : boolean, optional
        flag to rebuild the table completely. Implies full_refresh.
        Incompatible with query_file and certain conditions per-database-driver.
    type_map : dict, optional
        override type conversion from the source DB to redshift. This is
        set by YAML, and is an empty dict if no configuration is provided.
        ex: {
               'tinyint(1)': 'smallint',
               'char(35)': 'varchar(70)',
               'bigint(20) unsigned': 'bigint'
            }
    include_comments : boolean, optional
        flag to specify whether or not to ingest table and column comments
        when building or rebuilding the target table.

    Attributes
    ----------
    columns : list of str
        All columns read from the source table
    foreign_keys : list of str
        generated from the `create table` syntax, a list of foreign key
        relationships to create for a table after all tables are created
    pks : list of str
        generated from the `create table` syntax, a list of source column
        names that define the PK
    key_name : str
        Once the dump has happened, the file is uploaded to s3 at this
        key location
    old_index_value : str
        Used in the `where` clause, this is the most recent index value for
        a given table currently in the data warehouse
    new_index_value : str
        Used in the `where` clause, this is the max index value for a given
        table currently in the source database.
    data : dict
        Read from table definition in the yaml file to supply data to the Jinja
        templating
    db_template_data : dict
        Read from db definition in the yaml file to supply data to the Jinja
        templating

    Notes
    -----
    All parameters are also set as attributes
    """

    DESTINATION_TABLE_OK = "ok"
    DESTINATION_TABLE_REBUILD = "rebuild"
    DESTINATION_TABLE_DNE = "non-existent"
    DESTINATION_TABLE_INCORRECT = "incorrect"

    max_file_size = 100 * 1024 ** 2
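    # max_file_size above is 100 MiB and is passed to write_avro_file; once a
    # single Avro dump would exceed it, avro_to_s3 switches to manifest_mode
    # and the extract is split across numbered data files referenced by a
    # manifest.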

    def __init__(
        self,
        database_alias,
        db_connection_params,
        destination_table_name,
        destination_schema_name,
        source_table_name,
        index_schema,
        index_table,
        query_file=None,
        distribution_key=None,
        sort_keys=None,
        index_column=None,
        index_sql=None,
        truncate_file=None,
        columns_to_drop=None,
        type_map=None,
        primary_key=None,
        not_null_date=False,
        full_refresh=False,
        rebuild=False,
        schema_file=None,
        data=None,
        append_only=False,
        db_template_data=None,
        include_comments=True,
    ):
        self.database_alias = database_alias
        self.db_host = db_connection_params.host
        self.db_port = db_connection_params.port
        self.db_name = db_connection_params.name
        self.db_user = db_connection_params.user
        self.db_password = db_connection_params.password
        self.db_additional_parameters = db_connection_params.additional
        self._columns = None
        self.columns_to_drop = columns_to_drop or []
        self.destination_table_name = destination_table_name
        self.destination_schema_name = destination_schema_name
        self.query_file = query_file
        self.schema_file = schema_file
        self.distribution_key = distribution_key
        self.sort_keys = sort_keys
        self.source_table_name = source_table_name
        self.index_column = index_column
        self.index_sql = index_sql
        self.truncate_file = truncate_file
        self.append_only = append_only
        self.full_refresh = full_refresh
        self.rebuild = rebuild
        self.type_map = self._clean_type_map(type_map)
        self.primary_key = (
            [primary_key] if isinstance(primary_key, str) else primary_key
        )
        self.not_null_date = not_null_date
        self.foreign_keys = []
        self.pks = []
        self._old_index_value = "notset"
        self._new_index_value = "notset"
        self._destination_table_status = None
        self.table_template_data = data
        self.db_template_data = db_template_data
        self.index_schema = index_schema
        self.index_table = index_table
        self.include_comments = include_comments

        self.date_key = datetime.datetime.strftime(
            datetime.datetime.utcnow(), "%Y%m%dT%H%M%S"
        )

        self.row_count = None
        self.upload_size = 0

        self.starttime = None
        self.endtime = None
        self.rows_inserted = None
        self.rows_deleted = None

        self.num_data_files = 0
        self.manifest_mode = False

        self.logger = logging.getLogger(f"druzhba.{database_alias}.{source_table_name}")
        self.s3 = Session().client("s3")

    @classmethod
    def _clean_type_map(cls, type_map):
        if not type_map:
            return {}
        # add lower-cased keys via a new dict rather than mutating type_map
        # while iterating over it (which can raise in Python 3)
        lowered = {k.lower(): v for k, v in type_map.items()}
        return {**type_map, **lowered}

    @classmethod
    def validate_yaml_configuration(cls, yaml_config):
        """
        Validate YAML configuration. Note that this can differ slightly from runtime
        config, since full_refresh may be forced even when it would fail these checks.
        """

        table = yaml_config["source_table_name"]
        index_column = yaml_config.get("index_column")
        index_sql = yaml_config.get("index_sql")
        append_only = yaml_config.get("append_only")
        full_refresh = yaml_config.get("full_refresh")
        primary_key = yaml_config.get("primary_key")
        query_file = yaml_config.get("query_file")
        schema_file = yaml_config.get("schema_file")

        has_incremental_index = index_column or index_sql

        if not has_incremental_index and append_only:
            raise ConfigurationError("Append_only without incremental index", table)
        if full_refresh and append_only:
            raise ConfigurationError("Append_only with full_refresh", table)
        elif not has_incremental_index and not full_refresh:
            raise ConfigurationError(
                "Incremental update with no specified index", table
            )
        elif index_column and full_refresh:
            raise ConfigurationError("Full refresh with index_column", table)
        elif index_sql and full_refresh:
            raise ConfigurationError("Full refresh with index_sql", table)
        elif index_sql and index_column:
            raise ConfigurationError("index_sql and index_column", table)
        elif query_file and not primary_key and not append_only and not full_refresh:
            raise ConfigurationError(
                "incremental query_file without primary_key", table
            )
        elif query_file and not os.path.isfile(os.path.join(CONFIG_DIR, query_file)):
            raise ConfigurationError("nonexistent query_file", table)
        elif schema_file and not os.path.isfile(os.path.join(CONFIG_DIR, schema_file)):
            raise ConfigurationError("nonexistent schema_file", table)

    def validate_runtime_configuration(self):
        """
        Validate this instance's configuration state, which can differ from yaml
        configurations allowed by validate_yaml_configuration and runs after
        connecting to the database.
        """
        if self.rebuild and self.truncate_file:
            msg = (
                "Cannot rebuild a table with a truncate_file "
                "because it would not be correct to drop the table."
            )
            raise ConfigurationError(msg, self.source_table_name)
        elif self.rebuild and self.schema_file:
            msg = (
                "Cannot rebuild a table with a schema file, need "
                "support for passing in the table name to create."
            )
            raise ConfigurationError(msg, self.source_table_name)

    @property
    def s3_key_prefix(self):
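        # Produces e.g. "<s3 prefix>/<database_alias>.<source_table>.<date_key>"
        # where date_key is the UTC timestamp captured in __init__.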
        return "{}/{}.{}.{}".format(
            get_redshift().s3_config.prefix,
            self.database_alias,
            self.source_table_name,
            self.date_key,
        )

    def single_s3_data_key(self):
        """Returns the S3 path to upload a single avro file to"""
        if self.manifest_mode:
            raise TableStateError(
                "Attempted to treat a manifest upload as a single file"
            )

        return "{}.avro".format(self.s3_key_prefix)

    def manifest_s3_data_key(self):
        if not self.manifest_mode:
            raise TableStateError(
                "Attempted to treat a single file upload as a manifest"
            )

        return "{}.manifest".format(self.s3_key_prefix)

    def numbered_s3_data_key(self, file_num):
        return "{}/{:05d}.avro".format(self.s3_key_prefix, file_num)

    def next_s3_data_file_key(self):
        if self.manifest_mode:
            return self.numbered_s3_data_key(self.num_data_files)
        else:
            return self.single_s3_data_key()

    @property
    def copy_target_key(self):
        if self.manifest_mode:
            return self.manifest_s3_data_key()
        else:
            return self.single_s3_data_key()

    @property
    def copy_target_url(self):
        return _s3_url(get_redshift().s3_config.bucket, self.copy_target_key)

    def data_file_keys(self):
        if self.manifest_mode:
            for fn in range(self.num_data_files):
                yield self.numbered_s3_data_key(fn)
        else:
            yield self.single_s3_data_key()

    @property
    def connection_vars(self):
        raise NotImplementedError

    @property
    def avro_type_map(self):
        raise NotImplementedError

    def get_query_from_file(self):
        env = Environment(
            loader=FileSystemLoader(os.path.join(CONFIG_DIR)),
            autoescape=select_autoescape(["sql"]),
            undefined=StrictUndefined,
        )
        template = env.get_template(self.query_file)
        return template.render(
            db=self.db_template_data,
            table=self.table_template_data,
            run=self.run_template_data,
        )

    def get_sql_description(self, sql):
        raise NotImplementedError

    def get_query_sql(self):
        if self.query_file:
            return self.get_query_from_file()
        else:
            return self._get_query_sql() + self.where_clause()

    def _get_query_sql(self):
        raise NotImplementedError

    @property
    def columns(self):
        if not self._columns:
            _table_attributes, columns = self.get_sql_description(self.get_query_sql())
            self._columns = [column[0] for column in columns]
        return self._columns

    def query(self, sql):
        raise NotImplementedError

    def query_fetchone(self, sql):
        results = self.query(sql)
        return next(results)

    def _check_index_values(self):
        """Returns true if new index is greater than old index and defined"""
        if self.full_refresh:
            if self.index_column or self.index_sql:
                msg = (
                    "Index was found, but %s was forced. "
                    "Old index value will be ignored, but new "
                    "index value will be recorded."
                )
                self.logger.info(msg, "rebuild" if self.rebuild else "full-refresh")
            return True

        if self.append_only:
            return True

        # If there is no previous index, we're fine
        if self.old_index_value is None:
            return True

        # There's an old index but can't load a new value.
        if self.new_index_value is None and self.old_index_value is not None:
            msg = "Index expected but not found. Last value was %s. Dumping full table"
            self.logger.warning(msg, self.old_index_value)
            return False

        try:
            # old_index_value comes in as a string; new_index_value comes
            # back as the native SQL type
            if isinstance(self.new_index_value, int):
                is_inverted = int(self.old_index_value) > self.new_index_value
            elif isinstance(self.new_index_value, datetime.datetime):
                old_index_dt = datetime.datetime.strptime(
                    self.old_index_value, "%Y-%m-%d %H:%M:%S.%f"
                )
                is_inverted = old_index_dt > self.new_index_value
            else:
                self.logger.warning(
                    "Unknown type %s for index %s",
                    type(self.new_index_value),
                    self.old_index_value,
                )
                return False
        except (ValueError, TypeError) as ex:
            self.logger.warning("Could not check index: %s", str(ex))
            return False

        if is_inverted:
            self.logger.warning(
                "Index value has decreased for table %s.%s. "
                "May need to do full refresh",
                self.db_name,
                self.source_table_name,
            )
        return not is_inverted

    def row_generator(self):
        sql = self.get_query_sql()
        self._check_index_values()
        self.logger.info("Extracting %s table %s", self.db_name, self.source_table_name)
        self.logger.debug("Running SQL: %s", sql)
        return self.query(sql)

    @property
    def run_template_data(self):
        return {
            "destination_schema_name": self.destination_schema_name,
            "destination_table_name": self.destination_table_name,
            "db_name": self.db_name,
            "source_table_name": self.source_table_name,
            "index_column": self.index_column,
            "new_index_value": self.new_index_value,
            "old_index_value": self.old_index_value,
        }

    def _load_old_index_value(self):
        """Sets and gets the index_value property, retrieved from Redshift

        Returns
        -------
        index_value : variable
            Since index_value can vary from table to table, this can be many
            different types. Most common will be a datetime or int, but
            could also be a date or string. Returns None if no previous index
            value found
        """

        query = f"""
        SELECT index_value
          FROM "{self.index_schema}"."{self.index_table}"
         WHERE datastore_name = %s
           AND database_name = %s
           AND table_name = %s
        ORDER BY created_ts DESC
        LIMIT 1;
        """
        self.logger.debug("Querying Redshift for last updated index")
        with get_redshift().cursor() as cur:
            cur.execute(
                query, (self.database_alias, self.db_name, self.source_table_name)
            )
            index_value = cur.fetchone()

        if index_value is None:
            self.logger.info(
                "No index found. Dumping entire table: %s.", self.source_table_name
            )
            return index_value
        else:
            self.logger.info("Index found: %s", index_value[0])
            return index_value[0]

    @property
    def old_index_value(self):
        if self.full_refresh:
            return None
        # we use 'notset' rather than None because None is a valid output
        if self._old_index_value is "notset":
            self._old_index_value = self._load_old_index_value()
        return self._old_index_value

    def _load_new_index_value(self):
        # Abstract to support DB-specific quoting
        raise NotImplementedError

    @property
    def new_index_value(self):
        # we use 'notset' rather than None because None is a valid output
        if self._new_index_value is "notset":
            if self.index_sql:
                env = Environment(
                    autoescape=select_autoescape(["sql"]), undefined=StrictUndefined
                )
                template = env.from_string(self.index_sql)
                query = template.render(
                    db=self.db_template_data,
                    table=self.table_template_data,
                    # Cannot include the run template data
                    # here because we do not know the index values yet
                )
                self._new_index_value = self.query_fetchone(query)["index_value"]

            elif self.index_column:
                self._new_index_value = self._load_new_index_value()
            else:
                self._new_index_value = None
        if self.query_file and self.index_sql and self._new_index_value is None:
            # Handles a special case where the index_sql query returns no rows
            # but the custom sql file is expecting both old and new index values
            return 0
        return self._new_index_value

    def where_clause(self):
        """Method for filtering get_data, if tables are able to be
        sliced on some index value.

        Returns
        -------
        str
            valid SQL featuring just the WHERE clause
        """
        where_clause = "\nWHERE "
        if not self.index_column or self.full_refresh:
            # If no index_column, there is no where clause. The whole
            # source table is dumped.
            return ""

        if self.new_index_value is None:
            # Either the table is empty or the index_column is all NULL
            return ""

        if self.old_index_value:
            # This should always happen except on the initial load
            where_clause += "{} > '{}' AND ".format(
                self.index_column, self.old_index_value
            )

        where_clause += "{} <= '{}'".format(self.index_column, self.new_index_value)
        return where_clause

    def get_destination_table_columns(self):
        query = """
        SELECT "column"
          FROM pg_table_def
         WHERE schemaname = %s
           AND tablename = %s;
        """

        with get_redshift().cursor() as cur:
            self.set_search_path(cur)
            cur.execute(
                query, (self.destination_schema_name, self.destination_table_name)
            )
            results = cur.fetchall()

        return [x[0] for x in results]

    def get_destination_table_status(self):
        """Queries the data warehouse to determine if the desired
        destination table exists and if so, if it matches the expected
        configuration.

        Returns
        -------
        str
           Representing our plan for what to do with the destination table
           Includes:
            - DESTINATION_TABLE_DNE -> build it if possible
            - DESTINATION_TABLE_REBUILD -> rebuild it
            - DESTINATION_TABLE_OK -> leave it
            - DESTINATION_TABLE_INCORRECT -> error
        """
        dw_columns = set(self.get_destination_table_columns())
        source_columns = set(self.columns)
        expected = source_columns - set(self.columns_to_drop)
        unexpected_dw_columns = dw_columns - expected
        unexpected_source_columns = expected - dw_columns

        if len(dw_columns) == 0:
            self.logger.info("Destination table does not exist.")
            return self.DESTINATION_TABLE_DNE
        elif self.rebuild:
            # We're rebuilding it so we don't care if it's right,
            # so exit before we log any errors
            self.logger.info("Attempting to rebuild destination table.")
            return self.DESTINATION_TABLE_REBUILD
        elif dw_columns == expected:
            return self.DESTINATION_TABLE_OK
        elif len(unexpected_dw_columns) > 0:
            msg = (
                "Columns exist in the warehouse table that are not in "
                "the source: `%s`"
            )
            self.logger.warning(msg, "`, `".join(unexpected_dw_columns))
            return self.DESTINATION_TABLE_INCORRECT
        elif len(unexpected_source_columns) > 0:
            msg = (
                "Columns exist in the source table that are not in the "
                + "warehouse. Skipping column(s): `%s`"
            )
            self.logger.warning(msg, "`, `".join(unexpected_source_columns))

            # Copy from avro will just ignore the extra columns so we can proceed
            return self.DESTINATION_TABLE_OK
        else:
            raise RuntimeError("Unhandled case in get_destination_table_status")

    def query_description_to_avro(self, sql):
        _table_attributes, columns = self.get_sql_description(sql)
        fields = []
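        # Each resulting field is a dict like
        # {"name": <column>, "type": ["null", <avro type>]}; every column is
        # nullable, and unrecognized or decimal types fall back to string.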

        for col_desc in columns:
            col_name = col_desc[0]
            schema = {"name": col_name}
            try:
                col_type = col_desc[1].split("(")[0]
            except AttributeError:
                col_type = col_desc[1]
            if col_type in self.avro_type_map["string"]:
                schema["type"] = ["null", "string"]
            elif col_type in self.avro_type_map["int"]:
                schema["type"] = ["null", "int"]
            elif col_type in self.avro_type_map["double"]:
                schema["type"] = ["null", "double"]
            elif col_type in self.avro_type_map["long"]:
                schema["type"] = ["null", "long"]
            elif col_type in self.avro_type_map["boolean"]:
                schema["type"] = ["null", "boolean"]
            elif col_type in self.avro_type_map["decimal"]:
                # fastavro now supports decimal types, but Redshift does not
                schema["type"] = ["null", "string"]
            else:
                self.logger.warning(
                    "unmatched data type for column %s in %s table %s",
                    col_desc[0],
                    self.db_name,
                    self.source_table_name,
                )
                schema["type"] = ["null", "string"]

            fields.append(schema)
        return fields

    def set_last_updated_index(self):
        """Adds a new index to the pipeline_table_index table for updated
        tables
        """
        if self.new_index_value is None:
            return

        query = f"""
        INSERT INTO "{self.index_schema}"."{self.index_table}" VALUES
        (%s, %s, %s, %s)
        """

        if isinstance(self.new_index_value, int):
            new_index_value = str(self.new_index_value)
        elif isinstance(self.new_index_value, datetime.datetime):
            new_index_value = self.new_index_value.strftime("%Y-%m-%d %H:%M:%S.%f")
        else:
            msg = "Don't know how to handle index {} of type {}".format(
                self.new_index_value, str(type(self.new_index_value))
            )
            raise TypeError(msg)

        self.logger.info("Updating index table")
        with get_redshift().cursor() as cur:
            args = (
                self.database_alias,
                self.db_name,
                self.source_table_name,
                new_index_value,
            )
            self.logger.debug(cur.mogrify(query, args))
            cur.execute(query, args)

    def create_table_keys(self, distkey=None, sortkeys=None):
        output = ""
        distkey = distkey or self.distribution_key
        if distkey:
            output += "distkey({})\n".format(distkey)

        sortkeys = sortkeys or self.sort_keys
        if sortkeys:
            output += "compound " if len(sortkeys) > 1 else ""
            output += "sortkey({})\n".format(",".join(sortkeys))
        return output

    def query_to_redshift_create_table(self, sql, table_name):
        raise NotImplementedError

    def check_destination_table_status(self):
        """Get the source table schema, convert it to Redshift compatibility,
        and ensure the table exists as expected in the data warehouse.

        Sets self._destination_table_status if we can proceed, or raises
        if not.

        Raises
        ------
        InvalidSchemaError
            Raised when the target table has columns not recognized in the source
            table, and we're not rebuilding the table. (Unrecognized source
            columns will log but still proceed)
        MigrationError
            Raised when the target table needs to be created or rebuilt but
            this cannot be done automatically.
        """
        self.logger.info("Getting CREATE TABLE command")

        self._destination_table_status = self.get_destination_table_status()

        if self._destination_table_status in (
            self.DESTINATION_TABLE_DNE,
            self.DESTINATION_TABLE_REBUILD,
        ):
            self.logger.info("Verifying that the table can be created.")
            try:
                # Only called to see if it raises, the actual table
                # will be created later
                self.query_to_redshift_create_table(
                    self.get_query_sql(), self.destination_table_name
                )
            except NotImplementedError:
                raise MigrationError(
                    "Automatic table creation was not implemented for "
                    "this database, manual migration needed."
                )
        elif self._destination_table_status == self.DESTINATION_TABLE_INCORRECT:
            raise InvalidSchemaError(
                "Extra columns exist in redshift table. Migration needed"
            )

    def register_extract_monitor(self, starttime, endtime):
        """Adds an entry into the extract monitor for a given extract task

        Parameters
        ----------
        starttime : datetime.datetime
            datetime object generated at the beginning of the data
            extraction
        endtime : datetime.datetime
            datetime object generated at the end of the data extraction
        """

        query = """
        INSERT INTO "public"."table_extract_detail" VALUES (
            %(task_id)s, %(class_name)s, %(task_date_params)s,
            %(task_other_params)s, %(start_dt)s, %(end_dt)s, %(run_time_sec)s,
            %(manifest_path)s, %(data_path)s, %(output_exists)s, %(row_count)s,
            %(upload_size)s, %(exception)s
        );
        """
        args = {
            "task_id": "{}(alias={}, database={}, table={})".format(
                self.__class__.__name__,
                self.database_alias,
                self.db_name,
                self.source_table_name,
            ),
            "class_name": self.__class__.__name__,
            "task_date_params": None,
            "task_other_params": None,
            "start_dt": starttime.replace(microsecond=0),
            "end_dt": endtime.replace(microsecond=0),
            "run_time_sec": (endtime - starttime).total_seconds(),
            "manifest_path": self.copy_target_url,
            "data_path": "s3://{}/{}".format(
                get_redshift().s3_config.bucket, self.s3_key_prefix
            ),
            "output_exists": self.row_count > 0,
            "row_count": self.row_count,
            "upload_size": self.upload_size,
            "exception": None,
        }

        self.logger.info("Inserting record into table_extract_detail")
        with get_redshift().cursor() as cur:
            cur.execute(query, args)

    def register_load_monitor(self):
        """Adds an entry into the load monitor for a given load task

        Parameters
        ----------
        starttime : datetime
            datetime object generated at the beginning of the load
        endtime : datetime
            datetime object generated at the end of the load
        rows_inserted : int
            Total number of rows generated by the Redshift COPY command
        rows_deleted : int
            Count of rows deleted by primary key in the destination table
        load_size : int
            Size in bytes of the file in S3 used in the Redshift COPY
            command
        """

        query = """
            INSERT INTO "public"."table_load_detail" VALUES (
                %(task_id)s, %(class_name)s, %(task_date_params)s,
                %(task_other_params)s, %(target_table)s, %(start_dt)s,
                %(end_dt)s, %(run_time_sec)s, %(extract_task_update_id)s,
                %(data_path)s, %(manifest_cleaned)s, %(rows_inserted)s,
                %(rows_deleted)s, %(load_size)s, %(exception)s
            );
        """
        task_id = "{}(alias={}, database={}, table={})".format(
            self.__class__.__name__,
            self.database_alias,
            self.db_name,
            self.source_table_name,
        )
        target_table = "{}.{}".format(
            self.destination_schema_name, self.destination_table_name
        )
        args = {
            "task_id": task_id,
            "class_name": self.__class__.__name__,
            "task_date_params": None,
            "task_other_params": None,
            "target_table": target_table,
            "start_dt": self.starttime.replace(microsecond=0),
            "end_dt": self.endtime.replace(microsecond=0),
            "run_time_sec": (self.endtime - self.starttime).total_seconds(),
            "extract_task_update_id": task_id,
            "data_path": self.copy_target_url,
            "manifest_cleaned": False,
            "rows_inserted": self.rows_inserted,
            "rows_deleted": self.rows_deleted,
            "load_size": self.upload_size,
            "exception": None,
        }

        self.logger.info("Inserting record into table_load_detail")
        with get_redshift().cursor() as cur:
            cur.execute(query, args)

    def extract(self):
        """Serializes full db result set and uploads to s3

        The data will be uploaded either as a single file or as a set of files
        with a manifest
        """
        # TODO: Do we not currently execute the extract monitor?
        # starttime = datetime.datetime.utcnow()

        results_schema = self.query_description_to_avro(self.get_query_sql())
        results_iter = self.row_generator()

        done = False
        while not done:
            done = self.avro_to_s3(results_iter, results_schema)

        if self.num_data_files == 0:
            self.logger.info(
                "No data extracted; not uploading to s3 for %s table %s",
                self.db_name,
                self.source_table_name,
            )

        if self.manifest_mode:
            self.write_manifest_file()

        # endtime = datetime.datetime.utcnow()
        # self.register_extract_monitor(starttime, endtime)

    def avro_to_s3(self, results_iter, results_schema):
        """Attempts to serialize a result set to an AVRO file

        returns true if it complete writes the entire result_iter and false
        if there were records remaining when it hit the maximum file size.
        """
        with BytesIO() as f:
            complete, row_count = write_avro_file(
                f,
                results_iter,
                results_schema,
                self.destination_table_name,
                self.max_file_size,
            )

            if self.row_count is None:
                self.row_count = row_count
            else:
                self.row_count += row_count

            self.upload_size += f.tell()

            if not complete:
                self.manifest_mode = True

            if row_count > 0:
                self._upload_s3(
                    f, get_redshift().s3_config.bucket, self.next_s3_data_file_key()
                )
                self.num_data_files += 1

        return complete

    def write_manifest_file(self):
        if not self.manifest_mode:
            raise TableStateError("Cannot write manifest when not in manifest mode")

        entries = [
            {"url": _s3_url(get_redshift().s3_config.bucket, key), "mandatory": True}
            for key in self.data_file_keys()
        ]
        manifest = {"entries": entries}

        with BytesIO() as f:
            f.write(json.dumps(manifest).encode())
            self._upload_s3(
                f, get_redshift().s3_config.bucket, self.manifest_s3_data_key()
            )
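
    # For reference, the JSON written above follows the shape that a Redshift
    # COPY ... MANIFEST command expects (bucket/key values below are
    # illustrative only):
    #
    #   {
    #       "entries": [
    #           {"url": "s3://<bucket>/<data-file-key-0>", "mandatory": true},
    #           {"url": "s3://<bucket>/<data-file-key-1>", "mandatory": true}
    #       ]
    #   }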

    def _upload_s3(self, f, bucket, key):
        """
        Upload a file to S3 at the given bucket and key.

        Parameters
        ----------
        f : file-like object
            An open file handle; it is rewound before upload
        bucket : str
            Name of the destination S3 bucket
        key : str
            S3 key to write the object to
        """

        MB = 1024 ** 2
        s3_config = TransferConfig(multipart_threshold=10 * MB)
        f.seek(0)

        self.logger.info("Writing s3 file %s", _s3_url(bucket, key))

        retries = 3
        retries_remaining = retries
        while retries_remaining > 0:
            try:
                self.s3.upload_fileobj(f, bucket, key, Config=s3_config)
                self.logger.info("Wrote s3 file %s", _s3_url(bucket, key))
                return
            except KeyError:
                # retry on intermittent credential error
                retries_remaining -= 1
                if retries_remaining > 0:
                    time.sleep(3 * (retries - retries_remaining) ** 2)
                else:
                    raise
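
    # `_s3_url` is referenced above but not shown in this excerpt. A minimal
    # sketch of what such a helper might look like (an assumption for
    # illustration; the real implementation may differ):
    #
    #   def _s3_url(bucket, key):
    #       """Build an s3:// URL from a bucket name and object key."""
    #       return "s3://{}/{}".format(bucket, key)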

    def set_search_path(self, cursor):
        """This sets the search_path for a Redshift session.
        The default search_path is "'$user', public"; this replaces it with
        just the destination_schema_name.

        Parameters
        ----------
        cursor : Redshift cursor
            A cursor is passed in rather than generated because we need to
            modify the existing cursor and not create a new one
        """
        query = "SET search_path TO {};".format(self.destination_schema_name)
        self.logger.debug(query)
        cursor.execute(query)

    def get_delete_sql(self):
        if self.full_refresh:
            if self._destination_table_status == self.DESTINATION_TABLE_REBUILD:
                # We'll just drop it
                return ""
            if self.truncate_file:
                env = Environment(
                    loader=FileSystemLoader(os.path.join("datacfg")),
                    autoescape=select_autoescape(["sql"]),
                    undefined=StrictUndefined,
                )
                template = env.get_template(self.truncate_file)
                return template.render(
                    db=self.db_template_data,
                    table=self.table_template_data,
                    run=self.run_template_data,
                )
            else:
                return 'DELETE FROM "{}";'.format(self.destination_table_name)
        elif not self.append_only:
            if self.primary_key:
                # override from db yaml file
                pks = self.primary_key
            else:
                # pk column discovered from the source table
                pks = self.pks

            if not pks and self.index_column:
                raise InvalidSchemaError(
                    "Specifying an index column without primary key would "
                    "result in all records in the existing table being "
                    "deleted. If this is the desired behavior, run with "
                    "--full-refresh instead. If not, check if primary keys "
                    "can be inferred from the upstream database."
                )

            constraints = [
                '"{0}"."{2}" = "{1}"."{2}"'.format(
                    self.staging_table_name, self.destination_table_name, pk
                )
                for pk in pks
            ]
            constraint_string = " AND ".join(constraints)
            return 'DELETE FROM "{}" USING "{}" WHERE {};'.format(
                self.destination_table_name,
                self.staging_table_name,
                constraint_string,
            )
        else:
            # Should only land here when append_only
            # in which case we're not deleting
            return None
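
    # As an illustration, with database_alias "mydb", destination table
    # "users" and a single primary key column "id", the upsert branch above
    # renders:
    #
    #   DELETE FROM "users" USING "mydb_users_staging"
    #       WHERE "mydb_users_staging"."id" = "users"."id";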

    def get_grant_sql(self, cursor):
        """Get SQL statements to restore permissions
        to the staging table after a rebuild."""

        get_permissions_sql = """
        SELECT
            use.usename = CURRENT_USER          AS "owned"
            , c.relacl                          AS "permissions"
        FROM pg_class c
            LEFT JOIN pg_namespace nsp ON c.relnamespace = nsp.oid
            LEFT JOIN pg_user use ON c.relowner = use.usesysid
        WHERE
            c.relkind = 'r'
            AND nsp.nspname = '{schema}'
            AND c.relname = '{table}'
        """.format(
            schema=self.destination_schema_name, table=self.destination_table_name
        )
        cursor.execute(get_permissions_sql)
        permissions_result = cursor.fetchall()

        if len(permissions_result) == 0:
            self.logger.info(
                "No existing permissions found for %s.%s",
                self.destination_schema_name,
                self.destination_table_name,
            )
            return None
        elif len(permissions_result) > 1:
            raise MigrationError("Got multiple permissions rows for table")

        is_owner, permissions_str = permissions_result[0]
        if not is_owner:
            raise MigrationError(
                "Can't rebuild target table because it has another owner"
            )

        self.logger.info(
            "Got existing permissions for table to add to %s: %s",
            self.staging_table_name,
            permissions_str,
        )
        permissions = Permissions.parse(permissions_str)
        if permissions is None:
            raise MigrationError(
                "Couldn't parse permissions {} to rebuild target table".format(
                    permissions_str
                )
            )

        grant_template = "GRANT {grant} ON {table} TO {group}{name};"
        grant_sqls = [
            grant_template.format(
                grant=g,
                table=self.staging_table_name,
                # Should we not restore users?
                group="GROUP " if p.is_group else "",
                name=p.name,
            )
            for p in permissions
            for g in p.grants
        ]
        return "\n".join(grant_sqls)
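
    # For reference, `relacl` values use the Postgres/Redshift ACL syntax,
    # e.g. '{"group analytics=r/table_owner"}'. Assuming Permissions.parse
    # maps the privilege letter 'r' to SELECT, the template above would
    # render something like (illustrative only):
    #
    #   GRANT SELECT ON mydb_users_staging TO GROUP analytics;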

    @property
    def staging_table_name(self):
        return "{}_{}_staging".format(self.database_alias, self.destination_table_name)

    def load(self):
        """The Load phase of the pipeline. Takes a file in S3 and issues a
        Redshift COPY command to import the data into a staging table. It
        then upserts the data by deleting rows in the destination table
        that match on PK, inserts all rows from the staging table into the
        destination table, and then deletes the staging table.
        """

        self.starttime = datetime.datetime.utcnow()

        # Initializing Data
        delete_clause = self.get_delete_sql()
        staging_table = self.staging_table_name
        destination_table = self.destination_table_name
        is_normal_load = self._destination_table_status == self.DESTINATION_TABLE_OK
        is_rebuild = self._destination_table_status == self.DESTINATION_TABLE_REBUILD
        is_dne = self._destination_table_status == self.DESTINATION_TABLE_DNE

        with get_redshift().cursor() as cur:
            self.set_search_path(cur)

            # If table does not exist, create it
            if is_dne:
                create_table = self.query_to_redshift_create_table(
                    self.get_query_sql(), self.destination_table_name
                )
                cur.execute(create_table)
            elif not is_normal_load and not is_rebuild:
                raise RuntimeError(
                    "Invalid table status in redshift_copy: {}".format(
                        self._destination_table_status
                    )
                )

            # If there are no row updates, skip the copy and return
            if self.row_count == 0:
                return

            cur.execute("BEGIN TRANSACTION;")
            # Lock the table early to avoid deadlocks in many-to-one pipelines.
            query = generate_lock_query(destination_table)
            cur.execute(query)

            query = generate_drop_exists_query(staging_table)
            cur.execute(query)

            if is_rebuild:
                # Build staging table anew and grant it appropriate permissions
                self.logger.info(
                    "Creating staging table to rebuild %s", destination_table
                )
                create_staging_table = self.query_to_redshift_create_table(
                    self.get_query_sql(), staging_table
                )
                permissions_sql = self.get_grant_sql(cur)
                cur.execute(create_staging_table)
                if permissions_sql:
                    self.logger.info(
                        "Copying permissions onto %s:\n%s",
                        staging_table,
                        permissions_sql,
                    )
                    cur.execute(permissions_sql)
            else:
                # If not rebuilding, create staging with LIKE
                self.logger.info("Creating staging table %s", staging_table)
                query = generate_create_table_like_query(
                    staging_table, destination_table
                )
                cur.execute(query)

            # Issuing Copy Command
            self.logger.info("Issuing copy command")
            query = generate_copy_query(
                staging_table,
                self.copy_target_url,
                get_redshift().iam_copy_role,
                self.manifest_mode,
            )
            self.logger.debug(query)
            cur.execute(query)

            # Row delete and count logic
            if is_rebuild or (self.append_only and not self.full_refresh):
                self.rows_deleted = 0
            else:
                cur.execute(delete_clause)
                self.rows_deleted = cur.rowcount

            # Row insert and count logic
            if is_rebuild:
                self.logger.info("Swapping staging table into %s", destination_table)
                # DNE overrides rebuild, so we can assume the table exists
                query = generate_drop_query(destination_table)
                self.logger.debug(query)
                cur.execute(query)
                query = generate_rename_query(staging_table, destination_table)
                self.logger.debug(query)
                cur.execute(query)
                query = generate_count_query(destination_table)
                self.logger.debug(query)
                cur.execute(query)
                # The count query returns a single row; take the scalar value
                self.rows_inserted = cur.fetchall()[0][0]
            else:
                query = generate_insert_all_query(staging_table, destination_table)
                self.logger.debug(query)
                cur.execute(query)
                self.rows_inserted = cur.rowcount
                query = generate_drop_query(staging_table)
                self.logger.debug(query)
                cur.execute(query)
            cur.execute("END TRANSACTION;")
            self.register_and_cleanup()
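
    # The generate_* helpers used in load() are imported elsewhere and not
    # shown in this excerpt. Minimal sketches of two of them, given as
    # assumptions for illustration only (the real implementations may differ):
    #
    #   def generate_lock_query(table):
    #       return 'LOCK TABLE "{}";'.format(table)
    #
    #   def generate_copy_query(table, s3_url, iam_copy_role, manifest_mode):
    #       manifest = " MANIFEST" if manifest_mode else ""
    #       return (
    #           "COPY \"{}\" FROM '{}' IAM_ROLE '{}' "
    #           "FORMAT AS AVRO 'auto'{};"
    #       ).format(table, s3_url, iam_copy_role, manifest)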

    def register_and_cleanup(self):
        # Register in index table.
        self.set_last_updated_index()

        # Register in monitor table
        # self.endtime = datetime.datetime.utcnow()
        # self.register_load_monitor()

        # Clean up S3
        for key in self.data_file_keys():
            self.s3.delete_object(Bucket=get_redshift().s3_config.bucket, Key=key)

        if self.manifest_mode:
            self.s3.delete_object(
                Bucket=get_redshift().s3_config.bucket, Key=self.manifest_s3_data_key()
            )
Example #43
0
 def __init__(self):
     _session = Session(config.s3_config['access_key'], 
                        config.s3_config['secret_key'], 
                        region_name=config.s3_config['region'])
     self.client = _session.client('s3')
Example #44
0
class BucketWorker(Thread):
    def __init__(self, q, *args, **kwargs):
        self.q = q
        self.use_aws = CONFIG["aws_access_key"] and CONFIG["aws_secret"]

        if self.use_aws:
            self.session = Session(
                aws_access_key_id=CONFIG["aws_access_key"],
                aws_secret_access_key=CONFIG["aws_secret"]).resource("s3")
        else:
            self.session = requests.Session()
            self.session.mount(
                "http://",
                HTTPAdapter(pool_connections=ARGS.threads,
                            pool_maxsize=QUEUE_SIZE,
                            max_retries=0))

        super().__init__(*args, **kwargs)

    def run(self):
        while True:
            try:
                bucket_url = self.q.get()
                if self.use_aws:
                    self.__check_boto(bucket_url)
                else:
                    self.__check_http(bucket_url)
            except Exception as e:
                print(e)
                pass
            finally:
                self.q.task_done()

    def __check_http(self, bucket_url):
        check_response = self.session.head(S3_URL,
                                           timeout=3,
                                           headers={"Host": bucket_url})

        if not ARGS.ignore_rate_limiting and (check_response.status_code == 503
                                              and check_response.reason
                                              == "Slow Down"):
            self.q.rate_limited = True
            # add it back to the bucket for re-processing
            self.q.put(bucket_url)
        elif check_response.status_code == 307:  # valid bucket, lets check if its public
            new_bucket_url = check_response.headers["Location"]
            bucket_response = requests.request(
                "GET" if ARGS.only_interesting else "HEAD",
                new_bucket_url,
                timeout=3)

            if bucket_response.status_code == 200 and (
                    not ARGS.only_interesting or
                (ARGS.only_interesting and any(keyword in bucket_response.text
                                               for keyword in KEYWORDS))):
                cprint("Found bucket '{}'".format(new_bucket_url),
                       "green",
                       attrs=["bold"])
                self.__log(new_bucket_url)

    def __check_boto(self, bucket_url):
        bucket_name = bucket_url.replace(".s3.amazonaws.com", "")

        try:
            # just to check if the bucket exists. Throws NoSuchBucket exception if not
            self.session.meta.client.head_bucket(Bucket=bucket_name)

            if not ARGS.only_interesting or (
                    ARGS.only_interesting
                    and self.__bucket_contains_any_keywords(bucket_name)):
                owner = None
                acls = None

                try:
                    # todo: also check IAM policy as it can override ACLs
                    acl = self.session.meta.client.get_bucket_acl(
                        Bucket=bucket_name)
                    owner = acl["Owner"]["DisplayName"]
                    acls = ". ACLs = {} | {}".format(
                        self.__get_bucket_perms(acl, "AllUsers"),
                        self.__get_bucket_perms(acl, "AuthenticatedUsers"))
                except:
                    acls = ". ACLS = (could not read)"

                color = "green" if not owner else "magenta"
                cprint("Found bucket '{}'. Owned by '{}'{}".format(
                    bucket_url, owner if owner else "(unknown)", acls),
                       color,
                       attrs=["bold"])
                self.__log(bucket_url)
        except:
            pass

    def __get_bucket_perms(self, acl, group):
        group_uri = "http://acs.amazonaws.com/groups/global/%s" % group
        perms = [
            g["Permission"] for g in acl["Grants"]
            if g["Grantee"]["Type"] == "Group"
            and g["Grantee"]["URI"] == group_uri
        ]

        return "{}: {}".format(group, ", ".join(perms) if perms else "(none)")

    def __bucket_contains_any_keywords(self, bucket_name):
        try:
            objects = [
                o.key for o in self.session.Bucket(bucket_name).objects.all()
            ]
            return any(keyword in ",".join(objects) for keyword in KEYWORDS)
        except:
            return False

    def __log(self, new_bucket_url):
        global FOUND_COUNT
        FOUND_COUNT += 1

        if ARGS.log_to_file:
            with open("buckets.log", "a+") as log:
                log.write("%s%s" % (new_bucket_url, os.linesep))
Example #45
0
parser.add_argument('--consoleversion', required=False, default=1, help='console task definition version')

args = parser.parse_args()

print("Stack is " + args.stack + ".", flush=True)

# Write the SSH key out so it is usable by SSH
with open('/root/.ssh/id_rsa', 'w') as keyfile:
    keyfile.write(args.sshkey)

os.chmod('/root/.ssh/id_rsa', 0o600)


session = Session(aws_access_key_id=args.key,
                  aws_secret_access_key=args.secret,
                  region_name=args.region)

cfnClient = session.client('cloudformation')
ecsClient = session.client('ecs')
ec2Client = session.client('ec2')

def stop_task(ecsClient, cluster, task):
    try:
        print("Cleaning up after ourselves...")
        ecsClient.stop_task(
            cluster=cluster,
            task=task
            )
        print(bcolors.OKGREEN + "Task stopped and removed successfully, all done!" + bcolors.ENDC)
    except Exception as e:
Example #46
0
 def session(self):
     if not self._session:
         self._session = Session(aws_access_key_id=self.aws_access_key,
                                 aws_secret_access_key=self.aws_secret_key,
                                 region_name=self.region)
     return self._session
Example #47
0
config = Config()

# set up logging handlers from config
loggingHandlers = ['wsgi']
loggingHandlersConfig = {
    'wsgi': {
        'class': 'logging.StreamHandler',
        'stream': 'ext://flask.logging.wsgi_errors_stream',
        'formatter': 'colored'
    },
}
if not config.aws_local:
    # set up logging to AWS cloudwatch when not restricted to local AWS
    boto3_session = Session(aws_access_key_id=config.aws_access_keyid,
                            aws_secret_access_key=config.aws_secret_key,
                            region_name=config.aws_region)
    loggingHandlers.append('watchtower')
    loggingHandlersConfig['watchtower'] = {
        'level': 'DEBUG',
        'class': 'watchtower.CloudWatchLogHandler',
        'boto3_session': boto3_session,
        'log_group': 'watchtower',
        'stream_name': 'rocket2',
        'formatter': 'aws',
    }

# set up logging
loggingConfig = {
    'version': 1,
    'disable_existing_loggers': False,
Example #48
0
def S3KMSDemo():

    #
    # Create a KMS Client object
    #
    session = Session(profile_name='default', region_name='us-east-1')
    kms = session.client('kms')

    #
    # Generate a Data Key (encoded with my Master Key in KMS)
    #
    key = kms.generate_data_key(KeyId=MASTER_KEY_ARN, KeySpec='AES_256')
    keyPlain = key['Plaintext']
    keyCipher = key['CiphertextBlob']

    #
    # Encode a file with the data key
    #
    print('Initializing encryption engine')
    iv = Random.new().read(AES.block_size)
    chunksize = 64 * 1024
    encryptor = AES.new(keyPlain, AES.MODE_CBC, iv)

    print(f'KMS Plain text key = {base64.b64encode(keyPlain)} ')
    print(f'KMS Encrypted key  = {base64.b64encode(keyCipher)} ')

    in_filename = os.path.join(DIRECTORY, FILENAME)
    out_filename = in_filename + '.enc'
    filesize = os.path.getsize(in_filename)

    print('Encrypting file')
    with open(in_filename, 'rb') as infile:
        with open(out_filename, 'wb') as outfile:
            outfile.write(struct.pack('<Q', filesize))
            outfile.write(iv)

            chunk = infile.read(chunksize)
            while len(chunk) != 0:
                if len(chunk) % 16 != 0:
                    chunk += b' ' * (16 - len(chunk) % 16)
                outfile.write(encryptor.encrypt(chunk))
                chunk = infile.read(chunksize)

    #
    # Store encrypted file on S3
    # Encrypted Key will be stored as meta data
    #
    print('Storing encrypted file on S3')
    metadata = {'key': base64.b64encode(keyCipher).decode('ascii')}

    s3 = session.client('s3')
    s3.upload_file(out_filename,
                   S3_BUCKET,
                   out_filename,
                   ExtraArgs={'Metadata': metadata})
    os.remove(out_filename)

    ##
    ## Later ...
    ##

    #
    # Download Encrypted File and its metadata
    #
    print('Download file and meta data from S3')
    transfer = S3Transfer(s3)
    transfer.download_file(S3_BUCKET, out_filename, out_filename)

    #retrieve meta data
    import boto3
    s3 = boto3.resource('s3')
    object = s3.Object(S3_BUCKET, out_filename)
    #print object.metadata

    keyCipher = base64.b64decode(object.metadata['key'])

    #decrypt encrypted key
    print('Decrypt ciphered key')
    key = kms.decrypt(CiphertextBlob=keyCipher)
    keyPlain = key['Plaintext']
    print(f'KMS Plain text key = {base64.b64encode(keyPlain)}')
    print(f'KMS Encrypted key  = {base64.b64encode(keyCipher)}')

    #
    # Decrypt the file
    #
    print('Decrypt the file')

    in_filename = out_filename
    out_filename = in_filename + '.jpg'
    filesize = os.path.getsize(in_filename)

    with open(in_filename, 'rb') as infile:
        origsize = struct.unpack('<Q', infile.read(struct.calcsize('Q')))[0]
        iv = infile.read(16)
        decryptor = AES.new(keyPlain, AES.MODE_CBC, iv)

        with open(out_filename, 'wb') as outfile:
            chunk = infile.read(chunksize)
            while len(chunk) != 0:
                outfile.write(decryptor.decrypt(chunk))
                chunk = infile.read(chunksize)

            outfile.truncate(origsize)

    # Cleanup S3
    object.delete()

    print(
        f'Done.\n\nYour file {out_filename} should be identical to original file {os.path.join(DIRECTORY, FILENAME)}'
    )
Example #49
0
 def get_session(cls):
     if cls.session is None:
         cls.session = Session(aws_access_key_id=settings.s3_access_key,
                               aws_secret_access_key=settings.s3_secret_key)
     return cls.session
Example #50
0
    def __init__(self, config,
                 measure_timestamp_diff=False,
                 compression=None):
        self.logger = logs.getLogger(self.__class__.__name__)
        self.logger.setLevel(get_model_verbose_level())
        self.credentials: Credentials =\
            Credentials.getCredentials(config)

        self.endpoint = config.get('endpoint', None)

        if self.credentials is None:
            msg: str = "NO CREDENTIALS provided for {0}."\
                .format(self.endpoint)
            self._report_fatal(msg)

        if self.credentials.get_type() != AWS_TYPE:
            msg: str = "EXPECTED aws credentials for {0}: {1}"\
                .format(self.endpoint, repr(self.credentials.to_dict()))
            self._report_fatal(msg)

        aws_key: str = self.credentials.get_key()
        aws_secret_key = self.credentials.get_secret_key()
        session = Session(aws_access_key_id=aws_key,
            aws_secret_access_key=aws_secret_key,
            region_name=config.get('region'))

        session.events.unregister('before-parameter-build.s3.ListObjects',
                          set_list_objects_encoding_type_url)

        self.client = session.client(
            's3',
            endpoint_url=self.endpoint,
            config=Config(signature_version='s3v4'))

        if compression is None:
            compression = config.get('compression')

        self.cleanup_bucket = config.get('cleanup_bucket', False)
        if isinstance(self.cleanup_bucket, str):
            self.cleanup_bucket = self.cleanup_bucket.lower() == 'true'
        self.bucket_cleaned_up: bool = False

        self.endpoint = self.client._endpoint.host

        self.bucket = config['bucket']
        try:
            buckets = self.client.list_buckets()
        except Exception as exc:
            msg: str = "FAILED to list buckets for {0}: {1}"\
                .format(self.endpoint, exc)
            self._report_fatal(msg)

        if self.bucket not in [b['Name'] for b in buckets['Buckets']]:
            try:
                self.client.create_bucket(Bucket=self.bucket)
            except Exception as exc:
                msg: str = "FAILED to create bucket {0} for {1}: {2}"\
                    .format(self.bucket, self.endpoint, exc)
                self._report_fatal(msg)

        super().__init__(StorageType.storageS3,
            self.logger,
            measure_timestamp_diff,
            compression=compression)
Example #51
0
def create_session(creds: TemporaryCredentials) -> Session:
    return Session(aws_access_key_id=creds.aws_access_key_id,
                   aws_secret_access_key=creds.aws_secret_access_key,
                   aws_session_token=creds.aws_session_token)
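
# One possible way to obtain the temporary credentials passed to
# create_session() is STS assume_role. Illustrative sketch only; the role
# ARN, session name, and the TemporaryCredentials constructor used here are
# assumptions, not confirmed by the snippet above:
#
#   import boto3
#
#   sts = boto3.client('sts')
#   resp = sts.assume_role(RoleArn='arn:aws:iam::123456789012:role/example',
#                          RoleSessionName='example-session')
#   creds = TemporaryCredentials(
#       aws_access_key_id=resp['Credentials']['AccessKeyId'],
#       aws_secret_access_key=resp['Credentials']['SecretAccessKey'],
#       aws_session_token=resp['Credentials']['SessionToken'])
#   session = create_session(creds)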
Example #52
0
from boto3.session import Session

s = Session()
ec2_regions = s.get_available_regions('ec2')
dynamodb_regions = s.get_available_regions('dynamodb')
print("ec2_regions:", ec2_regions)
print("dynamodb_regions:", dynamodb_regions)
Example #53
0
def s3_client(session: Session) -> S3Client:
    return S3Client(session.client('s3'))
Example #54
0
import boto3
from boto3.session import Session

session = Session(aws_access_key_id='AKIA1235678901234567',
                  aws_secret_access_key='asdfghjkouytresxnmkiutrfvbR9wJTeSL')
Example #55
0
 def __init__(self, profile):
     session = Session(profile_name=profile)
     self.cloudwatch = session.client('cloudwatch')
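
     # Illustrative follow-up (not part of the original snippet): a client
     # created this way can fetch datapoints with get_metric_statistics.
     # The namespace, metric, and dimension values below are assumptions:
     #
     #   import datetime
     #   end = datetime.datetime.utcnow()
     #   start = end - datetime.timedelta(minutes=10)
     #   stats = self.cloudwatch.get_metric_statistics(
     #       Namespace='AWS/EC2',
     #       MetricName='CPUUtilization',
     #       Dimensions=[{'Name': 'InstanceId', 'Value': 'i-0123456789abcdef0'}],
     #       StartTime=start,
     #       EndTime=end,
     #       Period=300,
     #       Statistics=['Average'])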
Example #56
0
regions = [
    "us-west-2", "ap-southeast-1", "us-east-1", "ap-northeast-1", "eu-west-1",
    "ap-southeast-2", "sa-east-1", "eu-central-1"
]

region = regions[7]
# 194094
# 66126
# 253606
# 58592
# 62393
# 52486
# 51281
# 46789

# use ~/.aws/credentials:
session = Session(profile_name="sre-readonly", region_name=region)


def get_available_volumes():

    ec2 = session.resource("ec2", region_name=region)

    available_volumes = ec2.volumes.filter(Filters=[{
        'Name': 'status',
        'Values': ['available']
    }])
    return available_volumes


def get_metrics(volume_id):
    """Get volume idle time on an individual volume over `start_date`"""
Example #57
0
i18n.set('skip_locale_root_data', True)
i18n.set('filename_format', '{locale}.{format}')
i18n.set('load_path',
         [os.path.join(os.path.dirname(__file__), 'translations/')])

USE_L10N = True

USE_TZ = True

# Logging
AWS_ACCESS_KEY_ID = FI_AWS_CLOUDWATCH_ACCESS_KEY
AWS_SECRET_ACCESS_KEY = FI_AWS_CLOUDWATCH_SECRET_KEY  # noqa
AWS_REGION_NAME = 'us-east-1'

BOTO3_SESSION = Session(aws_access_key_id=AWS_ACCESS_KEY_ID,
                        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                        region_name=AWS_REGION_NAME)
LOGGING = {
    'version': 1,
    'disable_existing_loggers': False,
    'filters': {
        'require_debug_false': {
            '()': 'django.utils.log.RequireDebugFalse'
        }
    },
    'handlers': {
        'console': {
            'level': 'ERROR',
            'class': 'logging.StreamHandler'
        },
        'watchtower': {
Example #58
0
import boto3
from boto3.session import Session

running_instances_list = []
session = Session(aws_access_key_id='xxxxx',
                  aws_secret_access_key='xxxxxxx',
                  region_name='us-east-1')

ec2 = session.resource('ec2')

filters = [{
    'Name': 'tag:AutoOff',
    'Values': ['True']
}, {
    'Name': 'instance-state-name',
    'Values': ['running']
}]

instances = ec2.instances.filter(Filters=filters)

for instance in instances:
    running_instances_list.append(instance.id)

ec2.instances.filter(InstanceIds=running_instances_list).stop()
print("Stopped instances:", running_instances_list)
Example #59
0
def get_resource(access_key_id, secret_key):
    session = Session(access_key_id, secret_key)
    s3 = session.resource('s3')
    return s3
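
# Illustrative usage (not part of the original snippet): list bucket names
# with the returned resource. The credentials shown are placeholders.
#
#   s3 = get_resource('AKIAIOSFODNN7EXAMPLE', 'wJalr...EXAMPLEKEY')
#   for bucket in s3.buckets.all():
#       print(bucket.name)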
Example #60
0
from __future__ import (absolute_import, print_function, unicode_literals)
from acli.services.s3 import (s3_list)
from acli.config import Config
from moto import mock_s3

import pytest
from boto3.session import Session
session = Session(region_name="eu-west-1")


@pytest.yield_fixture(scope='function')
def s3_bucket():
    """S3 mock service"""
    mock = mock_s3()
    mock.start()
    s3_client = session.client('s3')
    s3_client.create_bucket(Bucket='test_bucket_1')
    yield s3_client.list_buckets()
    mock.stop()


config = Config(cli_args={'--region': 'eu-west-1',
                          '--access_key_id': 'AKIAIOSFODNN7EXAMPLE',
                          '--secret_access_key': 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY'})


def test_s3_list_service(s3_bucket):
    with pytest.raises(SystemExit):
        assert s3_list(aws_config=config)