Ejemplo n.º 1
0
def _do_start_ssl_proxy_with_client_auth(port: int, target: str,
                                         client_cert_key: Tuple[str, str]):
    """Start an SSL proxy on ``port`` that relays every request to ``target`` with a client cert.

    :param port: local port for the proxy server
    :param target: upstream host or URL; ``https://`` is prepended when it contains no ``://``,
        and trailing slashes are stripped
    :param client_cert_key: (certificate, key) pair; each entry is used as a file path if such a
        file exists, otherwise it is treated as content and written to a new temporary file
    :returns: whatever ``start_proxy_server`` returns (the proxy thread)
    """
    scheme = '' if '://' in target else 'https://'
    base_url = f"{scheme}{target.rstrip('/')}"

    def _as_file(path_or_content):
        # requests wants file paths for ``cert``, so raw content is materialized into a temp file
        # (TODO: check whether/how we can pass cert strings to requests.request(..) directly)
        if os.path.exists(path_or_content):
            return path_or_content
        tmp_path = new_tmp_file()
        save_file(tmp_path, path_or_content)
        return tmp_path

    cert_path = _as_file(client_cert_key[0])
    key_path = _as_file(client_cert_key[1])
    cert_params = (cert_path, key_path)

    class Listener(ProxyListener):
        def forward_request(self, method, path, data, headers):
            # verify=False: the upstream server certificate is not verified
            return requests.request(method=method, url=f"{base_url}{path}", data=data,
                                    headers=headers, cert=cert_params, verify=False)

    return start_proxy_server(port, update_listener=Listener(), use_ssl=True)
Ejemplo n.º 2
0
def download_and_extract(archive_url,
                         target_dir,
                         retries=0,
                         sleep=3,
                         tmp_archive=None):
    """Download the archive at ``archive_url`` (with retries) and extract it into ``target_dir``.

    :param archive_url: URL of a ``.zip``, ``.gz`` or ``.bz2`` archive
    :param target_dir: directory to extract into (created via ``mkdir`` first)
    :param retries: number of additional download attempts after a failed one
    :param sleep: seconds to wait between download attempts
    :param tmp_archive: optional local path for the downloaded archive; an existing non-empty
        file at this path is reused instead of downloading again
    :raises Exception: if the archive format is unsupported, or re-raises the error of the last
        failed download attempt
    """
    from urllib.parse import urlparse

    mkdir(target_dir)

    # Derive the archive type from the explicit local path if one was given, otherwise from the
    # URL *path*. A freshly created temp file presumably carries no archive extension, and a
    # query string (e.g. "?token=...") must not end up in the extension.
    _, ext = os.path.splitext(tmp_archive or urlparse(archive_url).path)
    if ext not in ('.zip', '.gz', '.bz2'):
        # fail fast, before spending time on a download we cannot extract
        raise Exception('Unsupported archive format: %s' % ext)

    tmp_archive = tmp_archive or new_tmp_file()
    # an empty file is a leftover placeholder (e.g. from a failed download) - download again
    if not os.path.exists(tmp_archive) or os.path.getsize(tmp_archive) <= 0:
        # create temporary placeholder file, to avoid duplicate parallel downloads
        save_file(tmp_archive, '')
        for attempt in range(retries + 1):
            try:
                download(archive_url, tmp_archive)
                break
            except Exception:
                if attempt >= retries:
                    # out of attempts: surface the real download error instead of failing
                    # later with a confusing extraction error on the empty placeholder
                    raise
                time.sleep(sleep)

    if ext == '.zip':
        unzip(tmp_archive, target_dir)
    else:
        untar(tmp_archive, target_dir)
Ejemplo n.º 3
0
    def test_s3_event_notification_with_sqs(self):
        """The SQS notification for a put object reports the size of the stored object.

        Puts an object with a nested key into a versioned bucket wired to an SQS queue, then
        compares the record's ``s3.object.size`` with the size of the downloaded object.
        """
        key_by_path = 'aws/bucket=2020/test1.txt'

        queue_url, queue_attributes = self._create_test_queue()
        self._create_test_notification_bucket(queue_attributes)
        download_file = new_tmp_file()
        try:
            self.s3_client.put_bucket_versioning(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                                 VersioningConfiguration={'Status': 'Enabled'})

            body = 'Lorem ipsum dolor sit amet, ... ' * 30

            # put an object
            self.s3_client.put_object(Bucket=TEST_BUCKET_WITH_NOTIFICATION, Key=key_by_path, Body=body)

            self.assertEqual(self._get_test_queue_message_count(queue_url), '1')

            rs = self.sqs_client.receive_message(QueueUrl=queue_url)
            record = json.loads(to_str(rs['Messages'][0]['Body']))['Records'][0]

            self.s3_client.download_file(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                         Key=key_by_path, Filename=download_file)

            self.assertEqual(record['s3']['object']['size'], os.path.getsize(download_file))
        finally:
            # clean up even if an assertion above failed, so later tests start from a clean state;
            # the local file goes first so it is removed even if the service cleanup raises
            if os.path.exists(download_file):
                os.remove(download_file)
            # NOTE(review): AWS documents only 'Enabled' | 'Suspended' as valid versioning states;
            # 'Disabled' is kept as-is - confirm the target service accepts it
            self.s3_client.put_bucket_versioning(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                                 VersioningConfiguration={'Status': 'Disabled'})

            self.sqs_client.delete_queue(QueueUrl=queue_url)
            self._delete_bucket(TEST_BUCKET_WITH_NOTIFICATION, [key_by_path])
Ejemplo n.º 4
0
    def test_s3_upload_fileobj_with_large_file_notification(self):
        """A large (multipart) upload yields one ObjectCreated:CompleteMultipartUpload notification.

        Uploads a ~75MB file, which must exceed 64MB so the upload is split into parts (see the
        comment below). It then checks the SQS message count and event name, and verifies the
        size of the downloaded object.
        """
        queue_url, queue_attributes = self._create_test_queue()
        self._create_test_notification_bucket(queue_attributes)

        # has to be larger than 64MB to be broken up into a multipart upload
        file_size = 75000000
        large_file = self.generate_large_file(file_size)
        download_file = new_tmp_file()
        try:
            self.s3_client.upload_file(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                       Key=large_file.name, Filename=large_file.name)

            self.assertEqual(self._get_test_queue_message_count(queue_url), '1')

            # ensure that the first message's eventName is ObjectCreated:CompleteMultipartUpload
            messages = self.sqs_client.receive_message(QueueUrl=queue_url, AttributeNames=['All'])
            message = json.loads(messages['Messages'][0]['Body'])
            self.assertEqual(message['Records'][0]['eventName'], 'ObjectCreated:CompleteMultipartUpload')

            # download the file, check file size
            self.s3_client.download_file(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                         Key=large_file.name, Filename=download_file)
            self.assertEqual(os.path.getsize(download_file), file_size)
        finally:
            # clean up large local files first, so they are removed even if service cleanup fails
            large_file.close()
            rm_rf(large_file.name)
            rm_rf(download_file)
            # clean up the queue and bucket even when an assertion above failed
            self.sqs_client.delete_queue(QueueUrl=queue_url)
            # keys are passed as a list (a bare string is not a list of keys)
            self._delete_bucket(TEST_BUCKET_WITH_NOTIFICATION, [large_file.name])
Ejemplo n.º 5
0
def get_java_handler(zip_file_content, handler, main_file):
    """Build a callable that runs a Java Lambda handler from an uploaded ZIP or JAR.

    If the upload is not itself a JAR, it is opened as a ZIP which must contain exactly one
    ``*.jar`` entry; that JAR is written to a new temp file and used as the main file.

    :type zip_file_content: bytes
    :param zip_file_content: ZIP or JAR file bytes.
    :type handler: str
    :param handler: The lambda handler path.
    :type main_file: str
    :param main_file: Filepath to the uploaded ZIP or JAR file.

    :returns: function or flask.Response (a 400 ValidationError response when no valid JAR
        could be obtained)
    :raises Exception: if the ZIP does not contain exactly one ``*.jar`` entry. Payloads that are
        neither ZIP nor JAR make ``zipfile.ZipFile`` raise ``zipfile.BadZipFile``.
    """
    if not is_jar_archive(zip_file_content):
        with zipfile.ZipFile(BytesIO(zip_file_content)) as archive:
            jars = [info for info in archive.infolist() if info.filename.endswith('.jar')]
            if len(jars) != 1:
                raise Exception('Expected exactly one *.jar entry in zip file, found %s' % len(jars))
            jar_name = jars[0].filename
            zip_file_content = archive.read(jar_name)
            LOG.info('Found jar file %s with %s bytes in Lambda zip archive' %
                     (jar_name, len(zip_file_content)))
            # the extracted JAR replaces the uploaded archive as the executable main file
            main_file = new_tmp_file()
            save_file(main_file, zip_file_content)

    if not is_jar_archive(zip_file_content):
        return error_response(
            'Unable to extract Java Lambda handler - file is not a valid zip/jar files', 400, error_type='ValidationError')

    def execute(event, context):
        # the executor returns (result, log_output); only the result is passed back
        result, _ = lambda_executors.EXECUTOR_LOCAL.execute_java_lambda(
            event, context, handler=handler, main_file=main_file)
        return result

    return execute
Ejemplo n.º 6
0
def download_and_extract(archive_url,
                         target_dir,
                         retries=0,
                         sleep=3,
                         tmp_archive=None):
    """Download the archive at ``archive_url`` (with retries) and extract it into ``target_dir``.

    :param archive_url: URL of a ``.zip``, ``.gz`` or ``.bz2`` archive
    :param target_dir: directory to extract into (created via ``mkdir`` first)
    :param retries: number of additional download attempts after a failed one
    :param sleep: seconds to wait between download attempts
    :param tmp_archive: optional local path for the downloaded archive; an existing non-empty
        file at this path is reused instead of downloading again
    :raises Exception: if the archive format is unsupported, or re-raises the error of the last
        failed download attempt
    """
    from urllib.parse import urlparse

    mkdir(target_dir)

    # Derive the archive type from the explicit local path if one was given, otherwise from the
    # URL *path*, so that a query string (e.g. "?token=...") does not end up in the extension.
    _, ext = os.path.splitext(tmp_archive or urlparse(archive_url).path)
    if ext not in (".zip", ".gz", ".bz2"):
        # fail fast, before spending time on a download we cannot extract
        raise Exception("Unsupported archive format: %s" % ext)

    tmp_archive = tmp_archive or new_tmp_file()
    if not os.path.exists(tmp_archive) or os.path.getsize(tmp_archive) <= 0:
        # create temporary placeholder file, to avoid duplicate parallel downloads
        save_file(tmp_archive, "")
        for attempt in range(retries + 1):
            try:
                download(archive_url, tmp_archive)
                break
            except Exception:
                if attempt >= retries:
                    # out of attempts: surface the real download error instead of failing
                    # later with a confusing extraction error on the empty placeholder
                    raise
                time.sleep(sleep)

    if ext == ".zip":
        unzip(tmp_archive, target_dir)
    else:
        untar(tmp_archive, target_dir)
Ejemplo n.º 7
0
def run_cached(cmd, cache_duration_secs=None):
    """Run ``cmd`` through ``run(...)`` with result caching and AWS environment defaults.

    Missing AWS credentials/region are filled with placeholder values, and urllib3's
    "Unverified HTTPS request" warning is silenced via ``PYTHONWARNINGS``.

    :param cmd: command passed through to ``run``
    :param cache_duration_secs: cache duration for ``run``; ``AWS_CACHE_TIMEOUT`` when None
    :returns: whatever ``run`` returns
    :raises Exception: re-raises the error from ``run`` after logging it with the captured stderr
    """
    if cache_duration_secs is None:
        cache_duration_secs = AWS_CACHE_TIMEOUT
    env_vars = os.environ.copy()
    env_vars.update({
        'AWS_ACCESS_KEY_ID':
        os.environ.get('AWS_ACCESS_KEY_ID') or 'foobar',
        'AWS_SECRET_ACCESS_KEY':
        os.environ.get('AWS_SECRET_ACCESS_KEY') or 'foobar',
        'AWS_DEFAULT_REGION':
        os.environ.get('AWS_DEFAULT_REGION') or DEFAULT_REGION,
        'PYTHONWARNINGS':
        'ignore:Unverified HTTPS request'
    })
    tmp_file_path = new_tmp_file()
    try:
        error = None
        # stderr is captured in a temp file so it can be included in the warning below
        with open(tmp_file_path, 'w') as err_file:
            try:
                return run(cmd,
                           cache_duration_secs=cache_duration_secs,
                           env_vars=env_vars,
                           stderr=err_file)
            except Exception as e:
                error = e
        # read stderr only after the file has been closed (and therefore flushed)
        LOG.warning('Error running command: %s %s %s',
                    cmd, error, load_file(tmp_file_path))
        raise error
    finally:
        # the stderr capture is only needed for the warning above; previously one file was
        # left behind per call
        try:
            os.remove(tmp_file_path)
        except OSError:
            pass
Ejemplo n.º 8
0
def test_parse_config_file(input_type, sections):
    """Check parse_config_file() on generated config content, given as a string or as a file.

    :param input_type: "file" to write the config to a temp file and pass its path;
        any other value passes the config string directly
    :param sections: number of section copies to generate; 0 strips the first line of the
        template (presumably the section header) and generates one section-less copy
    """
    template = CONFIG_FILE_SECTION.lstrip()
    if sections == 0:
        template = template.partition("\n")[2]
    copies = (template.replace("{section}", str(index)) for index in range(max(sections, 1)))
    config_string = "\n".join(copies)

    if input_type == "file":
        config_input = new_tmp_file()  # deleted on shutdown
        save_file(config_input, config_string)
    else:
        config_input = config_string
    result = parse_config_file(config_input)

    expected = {
        "var1": "foo bar 123",
        "var2": "123.45",
        "var3": "Test string' <with% special { chars!",
    }
    if sections > 1:
        # several sections: one identical parsed dict per section
        assert sections == len(result)
        for parsed_section in result.values():
            assert expected == parsed_section
    else:
        # zero or one section: the parsed result is compared directly
        assert expected == result
Ejemplo n.º 9
0
def test_train_tensorflow():
    """Train a small TensorFlow MNIST estimator through a SageMaker session in local mode.

    Creates (or reuses) the IAM role ``r1``, converts the MNIST data into ``testdata`` if that
    directory is missing, uploads it, downloads the training script from ``TF_MNIST_URL``
    and fits the estimator.
    """
    sagemaker_client = aws_stack.connect_to_service('sagemaker')
    iam_client = aws_stack.connect_to_service('iam')
    sagemaker_session = sagemaker.Session(boto_session=aws_stack.Boto3Session(),
                                          sagemaker_client=sagemaker_client)

    # fall back to the existing role if creation fails (presumably because it already exists)
    try:
        role = iam_client.create_role(RoleName='r1', AssumeRolePolicyDocument='{}')
    except Exception:
        role = iam_client.get_role(RoleName='r1')
    role_arn = role['Role']['Arn']

    data_dir = 'testdata'
    if not os.path.exists(data_dir):
        data_sets = input_data.read_data_sets(data_dir, dtype=tf.uint8, reshape=False,
                                              validation_size=5000)
        for split_name in ('train', 'validation', 'test'):
            convert_to(getattr(data_sets, split_name), split_name, data_dir)

    inputs = sagemaker_session.upload_data(path=data_dir, key_prefix='data/mnist')

    entry_point = new_tmp_file()
    download(TF_MNIST_URL, entry_point)
    estimator = TensorFlow(entry_point=entry_point, role=role_arn, framework_version='1.12.0',
                           training_steps=10, evaluation_steps=10,
                           sagemaker_session=sagemaker_session,
                           train_instance_count=1, train_instance_type='local')
    estimator.fit(inputs, logs=False)
Ejemplo n.º 10
0
    def test_s3_upload_fileobj_with_large_file_notification(self):
        """A large (multipart) upload yields one ObjectCreated:CompleteMultipartUpload notification.

        Wires a bucket to a test queue for ``s3:ObjectCreated:*`` events and uploads a ~75MB
        file, which must exceed 64MB so the upload is split into parts (see the comment below).
        It then checks the queue message count and event name, and verifies the downloaded size.
        """
        # create test queue
        queue_url = self.sqs_client.create_queue(
            QueueName=TEST_QUEUE_FOR_BUCKET_WITH_NOTIFICATION)['QueueUrl']
        queue_attributes = self.sqs_client.get_queue_attributes(
            QueueUrl=queue_url, AttributeNames=['QueueArn'])

        # create test bucket
        self.s3_client.create_bucket(Bucket=TEST_BUCKET_WITH_NOTIFICATION)
        self.s3_client.put_bucket_notification_configuration(
            Bucket=TEST_BUCKET_WITH_NOTIFICATION,
            NotificationConfiguration={
                'QueueConfigurations': [{
                    'QueueArn':
                    queue_attributes['Attributes']['QueueArn'],
                    'Events': ['s3:ObjectCreated:*']
                }]
            })

        # has to be larger than 64MB to be broken up into a multipart upload
        file_size = 75000000
        large_file = self.generate_large_file(file_size)
        download_file = new_tmp_file()
        try:
            self.s3_client.upload_file(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                       Key=large_file.name,
                                       Filename=large_file.name)
            queue_attributes = self.sqs_client.get_queue_attributes(
                QueueUrl=queue_url,
                AttributeNames=['ApproximateNumberOfMessages'])
            message_count = queue_attributes['Attributes'][
                'ApproximateNumberOfMessages']
            # the ApproximateNumberOfMessages attribute is a string
            self.assertEqual(message_count, '1')

            # ensure that the first message's eventName is ObjectCreated:CompleteMultipartUpload
            messages = self.sqs_client.receive_message(QueueUrl=queue_url,
                                                       AttributeNames=['All'])
            message = json.loads(messages['Messages'][0]['Body'])
            self.assertEqual(message['Records'][0]['eventName'],
                             'ObjectCreated:CompleteMultipartUpload')

            # download the file, check file size
            self.s3_client.download_file(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                         Key=large_file.name,
                                         Filename=download_file)
            self.assertEqual(os.path.getsize(download_file), file_size)
        finally:
            # clean up large local files first, so they are removed even if service cleanup fails
            large_file.close()
            rm_rf(large_file.name)
            rm_rf(download_file)
            # clean up queue and bucket even when an assertion above failed, so later tests
            # do not find leftover resources
            self.sqs_client.delete_queue(QueueUrl=queue_url)
            self.s3_client.delete_object(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                         Key=large_file.name)
            self.s3_client.delete_bucket(Bucket=TEST_BUCKET_WITH_NOTIFICATION)
Ejemplo n.º 11
0
def _do_start_ssl_proxy_with_client_auth(port: int, target: PortOrUrl,
                                         client_cert_key: Tuple[str, str]):
    """Start an SSL proxy on ``port`` forwarding to ``target`` using a client certificate.

    :param port: local port for the proxy server
    :param target: upstream port or URL, passed through to ``_do_start_ssl_proxy_with_listener``
    :param client_cert_key: (certificate, key) pair; each entry is used as a file path if such a
        file exists, otherwise it is treated as content and written to a new temporary file
    :returns: whatever ``_do_start_ssl_proxy_with_listener`` returns
    """
    def _materialize(path_or_content):
        # requests wants file paths for ``cert``, so raw content is written to a temp file
        # (TODO: check whether/how we can pass cert strings to requests.request(..) directly)
        if os.path.exists(path_or_content):
            return path_or_content
        tmp_path = new_tmp_file()
        save_file(tmp_path, path_or_content)
        return tmp_path

    cert_file = _materialize(client_cert_key[0])
    key_file = _materialize(client_cert_key[1])

    # start proxy; the cert pair is forwarded to the outgoing requests calls
    return _do_start_ssl_proxy_with_listener(
        port, target, requests_kwargs={"cert": (cert_file, key_file)})
Ejemplo n.º 12
0
    def setUpClass(cls):
        """Deploy a test Lambda and an events rule that targets it every minute.

        Note: the scheduled Lambda is created here - assertions run in test_scheduled_lambda() below.
        """
        # deploy the Lambda handler from a temp file
        cls.scheduled_lambda_name = 'scheduled-%s' % short_uid()
        handler_file = new_tmp_file()
        save_file(handler_file, TEST_HANDLER)
        create_response = testutil.create_lambda_function(handler_file=handler_file,
                                                          func_name=cls.scheduled_lambda_name)
        lambda_arn = create_response['CreateFunctionResponse']['FunctionArn']

        # schedule it: rule with a rate expression plus the Lambda as target
        rule = 'rule-%s' % short_uid()
        events_client = aws_stack.connect_to_service('events')
        events_client.put_rule(Name=rule, ScheduleExpression='rate(1 minutes)')
        lambda_target = {'Id': 'target-%s' % short_uid(), 'Arn': lambda_arn}
        events_client.put_targets(Rule=rule, Targets=[lambda_target])
Ejemplo n.º 13
0
    def test_s3_event_notification_with_sqs(self):
        """The SQS notification for a put object reports the size of the stored object.

        Puts a multi-line text object with a nested key into a versioned bucket wired to an SQS
        queue, then compares the record's ``s3.object.size`` with the downloaded object's size.
        """
        key_by_path = 'aws/bucket=2020/test1.txt'

        queue_url, queue_attributes = self._create_test_queue()
        self._create_test_notification_bucket(queue_attributes)
        download_file = new_tmp_file()

        # flake8: noqa: W291
        body = """ Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Aenean commodo ligula eget dolor. 
        Aenean massa. Cum sociis natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. 
        Donec quam felis, ultricies nec, pellentesque eu, pretium quis, sem. Nulla consequat massa quis enim. 
        Donec pede justo, fringilla vel, aliquet nec, vulputate eget, arcu. In enim justo, rhoncus ut, imperdiet a,. 
        Nullam dictum felis eu pede mollis pretium. Integer tincidunt. Cras dapibus. Vivamus elementum semper nisi. 
        Aenean vulputate eleifend tellus. Aenean leo ligula, porttitor eu, consequat vitae, eleifend ac, enim. 
        Aliquam lorem ante, dapibus in, viverra quis, feugiat a, tellus. Phasellus viverra nulla ut metus varius laoreet. 
        Quisque rutrum. Aenean imperdiet. Etiam ultricies nisi vel augue. Curabitur ullamcorper ultricies nisi. Nam eget dui. 
        Etiam rhoncus. Maecenas tempus, tellus eget condimentum rhoncus, sem quam semper libero, sed ipsum. """

        try:
            self.s3_client.put_bucket_versioning(
                Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                VersioningConfiguration={'Status': 'Enabled'})

            # put an object
            self.s3_client.put_object(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                      Key=key_by_path,
                                      Body=body)

            self.assertEqual(self._get_test_queue_message_count(queue_url), '1')

            rs = self.sqs_client.receive_message(QueueUrl=queue_url)
            record = json.loads(to_str(rs['Messages'][0]['Body']))['Records'][0]

            self.s3_client.download_file(Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                                         Key=key_by_path,
                                         Filename=download_file)

            self.assertEqual(record['s3']['object']['size'],
                             os.path.getsize(download_file))
        finally:
            # clean up even if an assertion above failed, so later tests start from a clean state;
            # the local file goes first so it is removed even if the service cleanup raises
            if os.path.exists(download_file):
                os.remove(download_file)
            # NOTE(review): AWS documents only 'Enabled' | 'Suspended' as valid versioning states;
            # 'Disabled' is kept as-is - confirm the target service accepts it
            self.s3_client.put_bucket_versioning(
                Bucket=TEST_BUCKET_WITH_NOTIFICATION,
                VersioningConfiguration={'Status': 'Disabled'})

            self.sqs_client.delete_queue(QueueUrl=queue_url)
            self._delete_bucket(TEST_BUCKET_WITH_NOTIFICATION, [key_by_path])
Ejemplo n.º 14
0
def test_download_with_timeout():
    """download() fetches a response body and raises TimeoutError when the server is too slow."""
    class DownloadListener(ProxyListener):
        def forward_request(self, method, path, data, headers):
            # "/sleep" answers after 2s, i.e. later than the 1s timeout used below
            if path == "/sleep":
                time.sleep(2)
            return {}

    port = get_free_tcp_port()
    proxy = start_proxy_server(port, update_listener=DownloadListener())
    tmp_file = new_tmp_file()
    try:
        download(f"http://localhost:{port}/", tmp_file)
        assert load_file(tmp_file) == "{}"
        with pytest.raises(TimeoutError):
            download(f"http://localhost:{port}/sleep", tmp_file, timeout=1)
    finally:
        # clean up even when an assertion fails, so the proxy is not left running
        proxy.stop()
        rm_rf(tmp_file)
Ejemplo n.º 15
0
    def setUpClass(cls):
        """Deploy a test Lambda and an events rule that targets it every minute.

        Note: the scheduled Lambda is created here - assertions run in test_scheduled_lambda() below.
        """
        # deploy the Lambda handler from a temp file
        cls.scheduled_lambda_name = "scheduled-%s" % short_uid()
        handler_file = new_tmp_file()
        save_file(handler_file, TEST_HANDLER)
        create_response = testutil.create_lambda_function(
            handler_file=handler_file, func_name=cls.scheduled_lambda_name
        )
        lambda_arn = create_response["CreateFunctionResponse"]["FunctionArn"]

        # schedule it: rule with a rate expression plus the Lambda as target
        rule = "rule-%s" % short_uid()
        events_client = aws_stack.connect_to_service("events")
        events_client.put_rule(Name=rule, ScheduleExpression="rate(1 minutes)")
        lambda_target = {"Id": "target-%s" % short_uid(), "Arn": lambda_arn}
        events_client.put_targets(Rule=rule, Targets=[lambda_target])
Ejemplo n.º 16
0
def test_generate_ssl_cert():
    """generate_ssl_cert() output contains PEM certificate and key markers.

    Checks both the in-memory return value and the files written when ``target_file`` is given.
    """
    def _assert(cert, key):
        # assert that file markers are in place
        assert PEM_CERT_START in cert
        assert PEM_CERT_END in cert
        # newlines are flattened so the key regexes can match across lines
        assert re.match(PEM_KEY_START_REGEX, key.replace("\n", " "))
        assert re.match(fr".*{PEM_KEY_END_REGEX}", key.replace("\n", " "))

    # generate cert and get content directly (the single string carries both cert and key)
    cert = generate_ssl_cert()
    _assert(cert, cert)

    # generate cert to file and load content from there
    target_file, cert_file_name, key_file_name = generate_ssl_cert(
        target_file=new_tmp_file(), overwrite=True)
    try:
        _assert(load_file(cert_file_name), load_file(key_file_name))
    finally:
        # clean up all generated files (including the combined target file), even on failure
        for path in (target_file, cert_file_name, key_file_name):
            rm_rf(path)