Example #1
    def test_watch_collision_when_missing(self, mock_logger):
        correlation_id = uuid.uuid4().hex

        # init aws redis connection
        aws.get_connection(self.mock_settings.PRIMARY_CACHE_SOURCE)

        def inline(a, b, c, d, ac):
            # call the real serializer, then write the lease key via a second
            # connection, so the pipeline's WATCH sees a change and the
            # optimistic transaction loses the race
            r = aws.original_serialize_lease_value(a, b, c, d)
            ac.setex('lease-' + correlation_id, 3600, '99:99:99:99')
            return r

        try:
            aws.original_serialize_lease_value = aws._serialize_lease_value
            another_connection = redis.StrictRedis(host='localhost',
                                                   port=6379,
                                                   db=0)
            aws._serialize_lease_value = lambda a, b, c, d: inline(
                a, b, c, d, another_connection)
            acquired = aws.acquire_lease(correlation_id, 1, 1, primary=True)
            self.assertFalse(acquired)
            self.assertEqual(
                mock.call.warn(
                    "Cannot acquire redis lease: unexpectedly lost 'pipe.watch' race"
                ), mock_logger.mock_calls[-1])

        finally:
            aws._serialize_lease_value = aws.original_serialize_lease_value
            del aws.original_serialize_lease_value
Example #2
    def create_cache(self):
        try:
            dynamodb_conn = aws.get_connection(DYNAMODB_CACHE_SOURCE)
            dynamodb_table = aws.get_arn_from_arn_string(DYNAMODB_CACHE_SOURCE).slash_resource()
            dynamodb_conn.create_table(
                TableName=dynamodb_table,
                AttributeDefinitions=[
                    {
                        AWS_DYNAMODB.AttributeName: CACHE_DATA.KEY,
                        AWS_DYNAMODB.AttributeType: AWS_DYNAMODB.STRING
                    }
                ],
                KeySchema=[
                    {
                        AWS_DYNAMODB.AttributeName: CACHE_DATA.KEY,
                        AWS_DYNAMODB.KeyType: AWS_DYNAMODB.HASH
                    }
                ],
                ProvisionedThroughput={
                    AWS_DYNAMODB.ReadCapacityUnits: 10,
                    AWS_DYNAMODB.WriteCapacityUnits: 10
                }
            )
        except Exception:
            # the table may already exist; ignore and continue
            pass
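Repeated runs hit the except branch when the table already exists, but a freshly created table is not immediately usable. A hedged follow-up for callers that need to wait, assuming dynamodb_conn is the usual boto3 DynamoDB client returned by get_connection:

# block until the table reaches ACTIVE status (built-in boto3 waiter)
dynamodb_conn.get_waiter('table_exists').wait(TableName=dynamodb_table)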
Example #3
    def execute(self, context, obj):
        logging.info('context: %s', context)

        # randomly raise an exception
        if random.uniform(0, 1.0) < 0.5:
            raise Exception()

        logging.info('action.name=%s', self.name)

        # increment the counter
        context['count'] = context.get('count', 0) + 1

        # set the started_at (user space) variable
        if context['count'] == 1:
            context['started_at'] = int(time.time())

        # when done, emit a dynamodb record
        if context['count'] > 100:
            if 'results_arn' in context:
                table_arn = context['results_arn']
                table_name = get_arn_from_arn_string(table_arn).slash_resource()
                conn = get_connection(table_arn)
                conn.put_item(
                    TableName=table_name,
                    Item={
                        'correlation_id': {AWS_DYNAMODB.STRING: context.correlation_id},
                        'count': {AWS_DYNAMODB.NUMBER: str(context['count'])},
                        'started_at': {AWS_DYNAMODB.NUMBER: str(context['started_at'])},
                        'finished_at': {AWS_DYNAMODB.NUMBER: str(int(time.time()))},
                        'flag': {AWS_DYNAMODB.STRING: context.get('flag', 'Unknown')}
                    }
                )
            return 'done'

        return 'event1'
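The record emitted above uses DynamoDB's low-level typed attribute format. For reference, a hedged sketch of reading it back with the same constants (assuming AWS_DYNAMODB.STRING and AWS_DYNAMODB.NUMBER map to the wire types 'S' and 'N', exactly as the put_item call above implies):

response = conn.get_item(
    TableName=table_name,
    Key={'correlation_id': {AWS_DYNAMODB.STRING: context.correlation_id}})
item = response.get('Item', {})
count = int(item['count'][AWS_DYNAMODB.NUMBER]) if item else None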
Example #4
    def test_add_collision_when_missing(self, mock_logger):
        correlation_id = uuid.uuid4().hex

        # init aws memcache connection
        connection = aws.get_connection(
            self.mock_settings.PRIMARY_CACHE_SOURCE)

        def inline(x, c, ac):
            # perform the real gets(), then create the lease key via a second
            # client so the subsequent memcache 'add' loses the race
            y = c.original_gets(x)
            ac.add('lease-' + correlation_id, '99:99:99:99')
            return y

        try:
            connection.original_gets = connection.gets
            another_connection = memcache.Client(['localhost:11211'],
                                                 cache_cas=True)
            connection.gets = lambda x: inline(x, connection,
                                               another_connection)
            acquired = aws.acquire_lease(correlation_id, 1, 1, primary=True)
            self.assertFalse(acquired)
            self.assertEqual(
                mock.call.warn(
                    "Cannot acquire memcache lease: unexpectedly lost 'memcache.add' race"
                ), mock_logger.mock_calls[-1])

        finally:
            connection.gets = connection.original_gets
            del connection.original_gets
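Both lease tests above rely on the same trick: wrap the read primitive so that a second connection writes the contested key between the read and the compare-and-set write, which forces the optimistic path to lose. A minimal, self-contained sketch of that pattern against a toy in-memory cache (every name below is hypothetical and exists only to illustrate the shape of the tests above):

import unittest


class ToyCache(object):
    """In-memory stand-in for memcache/redis with an add-if-missing primitive."""

    def __init__(self, data=None):
        self.data = data if data is not None else {}

    def gets(self, key):
        return self.data.get(key)

    def add(self, key, value):
        if key in self.data:          # 'add' succeeds only when the key is absent
            return False
        self.data[key] = value
        return True


def acquire_toy_lease(cache, key, value):
    if cache.gets(key) is not None:   # someone already holds the lease
        return False
    return cache.add(key, value)      # may still lose the read-then-add race


class ToyLeaseTest(unittest.TestCase):

    def test_lost_add_race(self):
        cache = ToyCache()
        other = ToyCache(cache.data)  # second client sharing the same backing store

        original_gets = cache.gets

        def racy_gets(key):
            result = original_gets(key)
            other.add(key, 'someone-else')   # lands between the read and the add
            return result

        cache.gets = racy_gets
        self.assertFalse(acquire_toy_lease(cache, 'lease-abc', 'me'))


if __name__ == '__main__':
    unittest.main()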
logging.basicConfig(
    format='[%(levelname)s] %(asctime)-15s %(message)s',
    level=int(args.log_level) if args.log_level.isdigit() else args.log_level,
    datefmt='%Y-%m-%d %H:%M:%S')

logging.getLogger('boto3').setLevel(args.boto_log_level)
logging.getLogger('botocore').setLevel(args.boto_log_level)

validate_config()

# setup connections to AWS
dynamodb_table_arn = getattr(settings, args.dynamodb_table_arn)
logging.info('DynamoDB table ARN: %s', dynamodb_table_arn)
logging.info('DynamoDB endpoint: %s', settings.ENDPOINTS.get(AWS.DYNAMODB))
if get_arn_from_arn_string(dynamodb_table_arn).service != AWS.DYNAMODB:
    logging.fatal("%s is not a DynamoDB ARN", dynamodb_table_arn)
    sys.exit(1)
dynamodb_conn = get_connection(dynamodb_table_arn, disable_chaos=True)
dynamodb_table = get_arn_from_arn_string(dynamodb_table_arn).slash_resource()
logging.info('DynamoDB table: %s', dynamodb_table)

if 'RESULTS' in args.dynamodb_table_arn:
    # create a dynamodb table for examples/tracer
    response = dynamodb_conn.create_table(
        TableName=dynamodb_table,
        AttributeDefinitions=[
            {
                AWS_DYNAMODB.AttributeName: 'correlation_id',
                AWS_DYNAMODB.AttributeType: AWS_DYNAMODB.STRING
            },
        ],
        KeySchema=[{
            AWS_DYNAMODB.AttributeName: 'correlation_id',
    start_state_machines(args.machine_name, [context] * args.num_machines,
                         current_state=current_state,
                         current_event=current_event)
    sys.exit(0)

# checkpoint specified, so start with a context saved to the kinesis stream
if args.checkpoint_shard_id and args.checkpoint_sequence_number:

    # setup connections to AWS
    kinesis_stream_arn = getattr(settings, args.kinesis_stream_arn)
    logging.info('Kinesis stream ARN: %s', kinesis_stream_arn)
    logging.info('Kinesis endpoint: %s', settings.ENDPOINTS.get(AWS.KINESIS))
    if get_arn_from_arn_string(kinesis_stream_arn).service != AWS.KINESIS:
        logging.fatal("%s is not a Kinesis ARN", kinesis_stream_arn)
        sys.exit(1)
    kinesis_conn = get_connection(kinesis_stream_arn)
    kinesis_stream = get_arn_from_arn_string(
        kinesis_stream_arn).slash_resource()
    logging.info('Kinesis stream: %s', kinesis_stream)

    # create a shard iterator for the specified shard and sequence number
    shard_iterator = kinesis_conn.get_shard_iterator(
        StreamName=kinesis_stream,
        ShardId=args.checkpoint_shard_id,
        ShardIteratorType=AWS_KINESIS.AT_SEQUENCE_NUMBER,
        StartingSequenceNumber=args.checkpoint_sequence_number)[
            AWS_KINESIS.ShardIterator]

    # get the record that has the last successful state
    records = kinesis_conn.get_records(ShardIterator=shard_iterator, Limit=1)
    if records:
import settings  # noqa: E402

random.seed(args.random_seed)
STARTED_AT = str(int(time.time()))

validate_config()

# setup connections to AWS
if args.run_kinesis_lambda:
    kinesis_stream_arn = getattr(settings, args.kinesis_stream_arn)
    logging.info('Kinesis stream ARN: %s', kinesis_stream_arn)
    logging.info('Kinesis endpoint: %s', settings.ENDPOINTS.get(AWS.KINESIS))
    if get_arn_from_arn_string(kinesis_stream_arn).service != AWS.KINESIS:
        logging.fatal("%s is not a Kinesis ARN", kinesis_stream_arn)
        sys.exit(1)
    kinesis_conn = get_connection(kinesis_stream_arn, disable_chaos=True)
    kinesis_stream = get_arn_from_arn_string(
        kinesis_stream_arn).slash_resource()
    logging.info('Kinesis stream: %s', kinesis_stream)

if args.run_sqs_lambda:
    sqs_queue_arn = getattr(settings, args.sqs_queue_arn)
    logging.info('SQS queue ARN: %s', sqs_queue_arn)
    logging.info('SQS endpoint: %s', settings.ENDPOINTS.get(AWS.SQS))
    if get_arn_from_arn_string(sqs_queue_arn).service != AWS.SQS:
        logging.fatal("%s is not a SQS ARN", sqs_queue_arn)
        sys.exit(1)
    sqs_conn = get_connection(sqs_queue_arn, disable_chaos=True)
    sqs_queue = get_arn_from_arn_string(sqs_queue_arn).colon_resource()
    sqs_queue_url = _get_sqs_queue_url(sqs_queue_arn)
    logging.info('SQS queue: %s', sqs_queue)
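The _get_sqs_queue_url helper used above is not included in this fragment; a hedged sketch of what it plausibly does, built only from calls that appear elsewhere in this listing (get_queue_url plus the AWS_SQS.QueueUrl response key):

def _get_sqs_queue_url(queue_arn):
    # resolve an SQS queue ARN to its queue URL via the SQS API
    arn = get_arn_from_arn_string(queue_arn)
    conn = get_connection(queue_arn, disable_chaos=True)
    return conn.get_queue_url(QueueName=arn.colon_resource())[AWS_SQS.QueueUrl]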
Example #8
    def execute(self, context, obj):
        """
        Action that launches an ECS task.

        The API for using this class is as follows:

        {
           'context_var': 'context_value',              # normal context variable
           'task_details': {                            # dictionary of all the states that run images
              'state_name_1': {                         # first state name (as in fsm.yaml)
                                                        # cluster to run image for state_name_1
                'cluster_arn': 'arn:aws:ecs:region:1234567890:cluster/foobar',
                'container_image': 'host/corp/image:12345' # image for state_name_1
              },
              'state_name_2': {                         # second state name (as in fsm.yaml)
                'cluster_arn': 'arn:aws:ecs:eu-west-1:1234567890:cluster/foobar',
                'container_image': 'host/corp/image:12345',
                'runner_task_definition': 'my_runner',  # alternative docker image runner task name
                'runner_container_name': 'my_runner'    # alternative docker image runner container name
              }
            },
            'clone_aws_credentials': True               # flag to copy aws creds from local environment
                                                        # to the container overrides - makes for easier
                                                        # local testing. alternatively, just add permanent
                                                        # credentials to your runner task.
        }

        :param context: an aws_lambda_fsm.fsm.Context instance
        :param obj: a dict
        :return: a string event, or None
        """

        # construct a version of the context that can be base64 encoded
        # and stuffed into an environment variable for the container program.
        # all the container program needs to do is extract this data, add
        # an event, and send the message on to sqs/kinesis/etc. since this is an
        # ENTRY action, we inspect the current transition for the state we
        # will be in AFTER this code executes.
        ctx = Context.from_payload_dict(context.to_payload_dict())
        ctx.current_state = context.current_transition.target
        ctx.steps += 1
        fsm_context = base64.b64encode(
            json.dumps(ctx.to_payload_dict(),
                       **json_dumps_additional_kwargs()))

        # now finally launch the ECS task using all the data from above
        # as well as tasks etc. specified when the state machine was run.
        state_to_task_details_map = context[TASK_DETAILS_KEY]
        task_details = state_to_task_details_map[
            context.current_transition.target.name]

        # this is the image the user wants to run
        cluster_arn = task_details[CLUSTER_ARN_KEY]
        container_image = task_details[CONTAINER_IMAGE_KEY]

        # this is the task that will run that image
        task_definition = task_details.get(RUNNER_TASK_DEFINITION_KEY,
                                           DEFAULT_RUNNER_TASK_NAME)
        container_name = task_details.get(RUNNER_CONTAINER_NAME_KEY,
                                          DEFAULT_RUNNER_CONTAINER_NAME)

        # setup the environment for the ECS task. this first set of variables
        # is used by the docker container runner image.
        environment = {
            ENVIRONMENT.FSM_CONTEXT: fsm_context,
            ENVIRONMENT.FSM_DOCKER_IMAGE: container_image
        }
        # this second set of variables is used by the actual docker image that
        # does the real work (pdf processing etc.)
        for name, value in task_details.get(ENVIRONMENT_KEY, {}).items():
            environment[name] = value

        # store the environment and record the guid.
        guid, _ = store_environment(context, environment)

        # stuff the guid and a couple of stream settings into the task
        # overrides. the guid allows the FSM_CONTEXT to be loaded from
        # storage, and the FSM_PRIMARY_STREAM_SOURCE allows the call
        # to send_next_event_for_dispatch to succeed.
        env = [{
            AWS_ECS.CONTAINER_OVERRIDES.ENVIRONMENT.NAME:
            ENVIRONMENT.FSM_ENVIRONMENT_GUID_KEY,
            AWS_ECS.CONTAINER_OVERRIDES.ENVIRONMENT.VALUE: guid
        }, {
            AWS_ECS.CONTAINER_OVERRIDES.ENVIRONMENT.NAME:
            ENVIRONMENT.FSM_PRIMARY_STREAM_SOURCE,
            AWS_ECS.CONTAINER_OVERRIDES.ENVIRONMENT.VALUE:
            get_primary_stream_source() or ''
        }, {
            AWS_ECS.CONTAINER_OVERRIDES.ENVIRONMENT.NAME:
            ENVIRONMENT.FSM_SECONDARY_STREAM_SOURCE,
            AWS_ECS.CONTAINER_OVERRIDES.ENVIRONMENT.VALUE:
            get_secondary_stream_source() or ''
        }]

        # this is for local testing
        if context.get(CLONE_AWS_CREDENTIALS_KEY):
            _testing(env)

        # get an ECS connection and start a task.
        conn = get_connection(cluster_arn)

        # run the task
        conn.run_task(cluster=cluster_arn,
                      taskDefinition=task_definition,
                      overrides={
                          AWS_ECS.CONTAINER_OVERRIDES.KEY: [{
                              AWS_ECS.CONTAINER_OVERRIDES.CONTAINER_NAME:
                              container_name,
                              AWS_ECS.CONTAINER_OVERRIDES.ENVIRONMENT.KEY:
                              env
                          }]
                      })

        # entry actions do not return events
        return None
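The docstring above spells out the context layout this action expects. A hedged sketch of such a context follows (the machine and state names, cluster ARN, and image are placeholders; the keys themselves come straight from the docstring):

context = {
    'task_details': {
        'work_state': {   # must match a state name in the machine's fsm.yaml
            'cluster_arn': 'arn:aws:ecs:us-east-1:123456789012:cluster/example',
            'container_image': 'registry.example.com/corp/worker:latest'
        }
    },
    'clone_aws_credentials': True     # local testing only, per the docstring
}

# the context would then be handed to the framework, e.g. via the
# start_state_machines call shown in an earlier fragment of this listing:
# start_state_machines('example_machine', [context])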
Example #9
parser = argparse.ArgumentParser(description='Creates AWS SNS topics.')
parser.add_argument('--sns_topic_arn', default='PRIMARY_STREAM_SOURCE')
parser.add_argument('--log_level', default='INFO')
parser.add_argument('--boto_log_level', default='INFO')
args = parser.parse_args()

logging.basicConfig(
    format='[%(levelname)s] %(asctime)-15s %(message)s',
    level=int(args.log_level) if args.log_level.isdigit() else args.log_level,
    datefmt='%Y-%m-%d %H:%M:%S')

logging.getLogger('boto3').setLevel(args.boto_log_level)
logging.getLogger('botocore').setLevel(args.boto_log_level)

validate_config()

# setup connections to AWS
sns_topic_arn = getattr(settings, args.sns_topic_arn)
logging.info('SNS topic ARN: %s', sns_topic_arn)
logging.info('SNS endpoint: %s', settings.ENDPOINTS.get(AWS.SNS))
if get_arn_from_arn_string(sns_topic_arn).service != AWS.SNS:
    logging.fatal("%s is not an SNS ARN", sns_topic_arn)
    sys.exit(1)
sns_conn = get_connection(sns_topic_arn, disable_chaos=True)
sns_topic = get_arn_from_arn_string(sns_topic_arn).resource
logging.info('SNS topic: %s', sns_topic)

# configure the topic
response = sns_conn.create_topic(Name=sns_topic)
logging.info(response)
Example #10
logging.basicConfig(
    format='[%(levelname)s] %(asctime)-15s %(message)s',
    level=int(args.log_level) if args.log_level.isdigit() else args.log_level,
    datefmt='%Y-%m-%d %H:%M:%S')

logging.getLogger('boto3').setLevel(args.boto_log_level)
logging.getLogger('botocore').setLevel(args.boto_log_level)

validate_config()

# setup connections to AWS
sqs_arn_string = getattr(settings, args.sqs_queue_arn)
sqs_arn = get_arn_from_arn_string(sqs_arn_string)
if sqs_arn.service != AWS.SQS:
    logging.fatal("%s is not an SQS ARN", sqs_arn_string)
    sys.exit(1)
sqs_conn = get_connection(sqs_arn_string, disable_chaos=True)
response = sqs_conn.get_queue_url(QueueName=sqs_arn.colon_resource())
sqs_queue_url = response[AWS_SQS.QueueUrl]

logging.info('SQS ARN: %s', sqs_arn_string)
logging.info('SQS endpoint: %s', settings.ENDPOINTS.get(AWS.SQS))
logging.info('SQS queue: %s', sqs_arn.resource)
logging.info('SQS queue url: %s', sqs_queue_url)

dest_arn_string = getattr(settings, args.dest_arn)
dest_arn = get_arn_from_arn_string(dest_arn_string)
if dest_arn.service not in ALLOWED_DEST_SERVICES:
    logging.fatal("%s is not a %s ARN", dest_arn_string,
                  '/'.join(map(str.upper, ALLOWED_DEST_SERVICES)))
    sys.exit(1)
dest_conn = get_connection(dest_arn_string, disable_chaos=True)
Example #11
args = parser.parse_args()

logging.basicConfig(
    format='[%(levelname)s] %(asctime)-15s %(message)s',
    level=int(args.log_level) if args.log_level.isdigit() else args.log_level,
    datefmt='%Y-%m-%d %H:%M:%S'
)

logging.getLogger('boto3').setLevel(args.boto_log_level)
logging.getLogger('botocore').setLevel(args.boto_log_level)

validate_config()

# setup connections to AWS
kinesis_stream_arn = getattr(settings, args.kinesis_stream_arn)
logging.info('Kinesis stream ARN: %s', kinesis_stream_arn)
logging.info('Kinesis endpoint: %s', settings.ENDPOINTS.get(AWS.KINESIS))
if get_arn_from_arn_string(kinesis_stream_arn).service != AWS.KINESIS:
    logging.fatal("%s is not a Kinesis ARN", kinesis_stream_arn)
    sys.exit(1)
kinesis_conn = get_connection(kinesis_stream_arn, disable_chaos=True)
kinesis_stream = get_arn_from_arn_string(kinesis_stream_arn).slash_resource()
logging.info('Kinesis stream: %s', kinesis_stream)

# configure the stream
response = kinesis_conn.create_stream(
    StreamName=kinesis_stream,
    ShardCount=args.kinesis_num_shards
)
logging.info(response)
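create_stream returns while the stream is still CREATING. A hedged follow-up for scripts that need the stream right away, assuming kinesis_conn is the usual boto3 Kinesis client returned by get_connection:

# block until the stream reaches ACTIVE status (built-in boto3 waiter)
kinesis_conn.get_waiter('stream_exists').wait(StreamName=kinesis_stream)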
Example #12
parser.add_argument('--boto_log_level', default='INFO')
args = parser.parse_args()

logging.basicConfig(
    format='[%(levelname)s] %(asctime)-15s %(message)s',
    level=int(args.log_level) if args.log_level.isdigit() else args.log_level,
    datefmt='%Y-%m-%d %H:%M:%S'
)

logging.getLogger('boto3').setLevel(args.boto_log_level)
logging.getLogger('botocore').setLevel(args.boto_log_level)

validate_config()

# setup connections to AWS
sqs_queue_arn = getattr(settings, args.sqs_queue_arn)
logging.info('SQS queue ARN: %s', sqs_queue_arn)
logging.info('SQS endpoint: %s', settings.ENDPOINTS.get(AWS.SQS))
if get_arn_from_arn_string(sqs_queue_arn).service != AWS.SQS:
    logging.fatal("%s is not an SQS ARN", sqs_queue_arn)
    sys.exit(1)
sqs_conn = get_connection(sqs_queue_arn, disable_chaos=True)
sqs_queue = get_arn_from_arn_string(sqs_queue_arn).resource
logging.info('SQS queue: %s', sqs_queue)

# configure the queue
response = sqs_conn.create_queue(
    QueueName=sqs_queue
)
logging.info(response)
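create_queue responds with the queue's URL. A short hedged follow-up that logs it, assuming AWS_SQS.QueueUrl maps to the boto3 response key 'QueueUrl' as it does for get_queue_url elsewhere in this listing:

sqs_queue_url = response[AWS_SQS.QueueUrl]
logging.info('SQS queue url: %s', sqs_queue_url)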