Example #1
0
    def test_request_callback_signed_header(self, resource_group, location, storage_account, storage_account_key):
        """Create a queue with a custom metadata header and check it round-trips.

        The ``x-ms-meta-hello`` header passed to ``create_queue`` is expected to
        show up as ``{'hello': 'world'}`` in the queue properties' metadata.
        The queue is always deleted afterwards.
        """
        # Arrange
        service = QueueServiceClient(self.account_url(storage_account, "queue"), credential=storage_account_key)
        name = self.get_resource_name('cont')
        headers = {'x-ms-meta-hello': 'world'}

        # Act
        # Create the queue before entering the try block: if creation itself
        # fails there is presumably no queue to delete, and a failing
        # delete_queue in `finally` would hide the original error.
        queue = service.create_queue(name, headers=headers)
        try:
            # Assert
            metadata = queue.get_queue_properties().metadata
            self.assertEqual(metadata, {'hello': 'world'})
        finally:
            service.delete_queue(name)
Example #2
0
class AzureFunctionAppBackend:
    """
    A wrap-up around Azure Function Apps backend.

    Each runtime is deployed as an Azure Function App (Linux, consumption
    plan) through the ``az`` and ``func`` command line tools. It is driven
    through two Azure Storage queues: payloads are sent to an input queue
    and results are read back from an output queue.
    """

    def __init__(self, config, storage_config):
        """
        :param config: Azure Functions configuration. Must provide the keys
            'resource_group', 'storage_account', 'storage_account_key',
            'location' and 'functions_version'.
        :param storage_config: not used by this constructor.
        """
        logger.debug("Creating Azure Functions client")
        self.name = 'azure_fa'
        self.azure_config = config
        self.resource_group = self.azure_config['resource_group']
        self.storage_account = self.azure_config['storage_account']
        self.account_key = self.azure_config['storage_account_key']
        self.location = self.azure_config['location']
        self.functions_version = self.azure_config['functions_version']

        self.queue_service_url = 'https://{}.queue.core.windows.net'.format(self.storage_account)
        self.queue_service = QueueServiceClient(account_url=self.queue_service_url,
                                                credential=self.account_key)

        msg = COMPUTE_CLI_MSG.format('Azure Functions')
        logger.info("{} - Location: {}".format(msg, self.location))

    def _format_action_name(self, runtime_name, runtime_memory=None):
        """
        Return the Function App name for ``runtime_name``: '/' and ':' are
        replaced by '--'. ``runtime_memory`` is accepted but ignored.
        """
        return runtime_name.replace('/', '--').replace(':', '--')

    def _format_queue_name(self, action_name, q_type):
        """
        Return the storage queue name for ``action_name`` and ``q_type``.

        Every '--' is collapsed to '-' (presumably because Azure queue names
        may not contain consecutive hyphens), then '-<q_type>' is appended.
        """
        queue_base = action_name.replace('--', '-')
        return queue_base + '-' + q_type

    def _get_default_runtime_image_name(self):
        """
        Return the default runtime name.

        The format is
        '<storage account>-<RUNTIME_NAME>-v<python version without dots>-<revision>'.
        The revision is 'latest' for dev versions of Lithops, otherwise the
        Lithops version without dots.
        """
        py_version = version_str(sys.version_info).replace('.', '')
        revision = 'latest' if 'dev' in __version__ else __version__.replace('.', '')
        runtime_name = '{}-{}-v{}-{}'.format(self.storage_account, az_config.RUNTIME_NAME,
                                             py_version, revision)
        return runtime_name

    def create_runtime(self, docker_image_name, memory=None, timeout=az_config.RUNTIME_TIMEOUT):
        """
        Creates a new runtime into Azure Function Apps
        from the provided Linux image for consumption plan

        Only the default runtime ('default' or its full name) is built here.
        Any other runtime must already exist.

        :return: the runtime metadata reported by the deployed function.
        """
        default_runtime_img_name = self._get_default_runtime_image_name()
        if docker_image_name in ['default', default_runtime_img_name]:
            # We only build the default image. rest of images must already exist
            # in the docker registry.
            docker_image_name = default_runtime_img_name
            self._build_default_runtime(default_runtime_img_name)

        logger.info('Creating new Lithops runtime for Azure Function Apps')
        self._create_function(docker_image_name, memory, timeout)
        metadata = self._generate_runtime_meta(docker_image_name, memory)

        return metadata

    def _build_default_runtime(self, default_runtime_img_name):
        """
        Builds the default runtime

        The runtime is prepared locally as a Function App project with
        ``build_runtime``. A Docker-based build is available separately
        through ``build_runtime_docker``.
        """
        return self.build_runtime(default_runtime_img_name)

    def build_runtime(self, runtime_name, requirements_file=None):
        """
        Prepare the Function App project for ``runtime_name``.

        The project is laid out under ``az_config.BUILD_DIR``; any previous
        build directory is removed first. The method writes requirements.txt,
        host.json and the function folder (function.json with the queue
        bindings, plus the entry point). It then installs the requirements
        and copies the local lithops package into the modules directory.

        ``requirements_file`` is currently ignored. Note that this method
        changes the process working directory to the build directory.
        """
        try:
            shutil.rmtree(az_config.BUILD_DIR)
        except Exception:
            # Best effort: the build directory may not exist yet.
            pass

        action_name = self._format_action_name(runtime_name)

        build_dir = os.path.join(az_config.BUILD_DIR, action_name)
        os.makedirs(build_dir, exist_ok=True)

        logger.info('Building default runtime in {}'.format(build_dir))

        action_dir = os.path.join(build_dir, az_config.ACTION_DIR)
        os.makedirs(action_dir, exist_ok=True)

        req_file = os.path.join(build_dir, 'requirements.txt')
        with open(req_file, 'w') as reqf:
            reqf.write(az_config.REQUIREMENTS_FILE)

        host_file = os.path.join(build_dir, 'host.json')
        with open(host_file, 'w') as hstf:
            hstf.write(az_config.HOST_FILE)

        fn_file = os.path.join(action_dir, 'function.json')
        with open(fn_file, 'w') as fnf:
            # Binding 0 is bound to this runtime's input queue and binding 1
            # to its output queue.
            in_q_name = self._format_queue_name(action_name, az_config.IN_QUEUE)
            az_config.BINDINGS['bindings'][0]['queueName'] = in_q_name
            out_q_name = self._format_queue_name(action_name, az_config.OUT_QUEUE)
            az_config.BINDINGS['bindings'][1]['queueName'] = out_q_name
            fnf.write(json.dumps(az_config.BINDINGS))

        entry_point = os.path.join(os.path.dirname(__file__), 'entry_point.py')
        main_file = os.path.join(action_dir, '__init__.py')
        shutil.copy(entry_point, main_file)

        mod_dir = os.path.join(build_dir, az_config.ACTION_MODULES_DIR)
        os.chdir(build_dir)
        cmd = 'pip3 install -U -t {} -r requirements.txt'.format(mod_dir)
        if logger.getEffectiveLevel() != logging.DEBUG:
            cmd = cmd + " >{} 2>&1".format(os.devnull)
        os.system(cmd)
        # Vendor the locally installed lithops package next to the dependencies.
        lithops_location = os.path.dirname(os.path.abspath(lithops.__file__))
        shutil.copytree(lithops_location, os.path.join(mod_dir, 'lithops'))

    def build_runtime_docker(self, docker_image_name, dockerfile):
        """
        Builds a new runtime from a Docker file and pushes it to the Docker hub

        :param docker_image_name: tag given to the built image.
        :param dockerfile: path of the Dockerfile to use. If falsy, the
            Dockerfile in the current directory is used.
        :raises Exception: if ``docker build`` or ``docker push`` fails.
        """
        logger.debug('Building new docker image from Dockerfile')
        logger.debug('Docker image name: {}'.format(docker_image_name))

        entry_point = os.path.join(os.path.dirname(__file__), 'entry_point.py')
        create_handler_zip(az_config.FH_ZIP_LOCATION, entry_point, '__init__.py')

        if dockerfile:
            cmd = '{} build -t {} -f {} .'.format(az_config.DOCKER_PATH,
                                                  docker_image_name,
                                                  dockerfile)
        else:
            cmd = '{} build -t {} .'.format(az_config.DOCKER_PATH, docker_image_name)

        if logger.getEffectiveLevel() != logging.DEBUG:
            cmd = cmd + " >{} 2>&1".format(os.devnull)

        logger.info('Building default runtime')
        res = os.system(cmd)
        if res != 0:
            raise Exception('There was an error building the runtime')

        cmd = '{} push {}'.format(az_config.DOCKER_PATH, docker_image_name)
        if logger.getEffectiveLevel() != logging.DEBUG:
            cmd = cmd + " >{} 2>&1".format(os.devnull)
        res = os.system(cmd)
        if res != 0:
            raise Exception('There was an error pushing the runtime to the container registry')
        logger.debug('Done!')

    def _create_function(self, docker_image_name, memory=None,
                         timeout=az_config.RUNTIME_TIMEOUT):
        """
        Create and publish an Azure Function App

        The input and output queues are created, or emptied if creating them
        fails (e.g. because they already exist). The Function App is then
        created with the ``az`` CLI, and the directory prepared by
        ``build_runtime`` is published with ``func``.
        ``timeout`` is currently not used.

        :raises Exception: if ``az functionapp create`` exits with a non-zero status.
        """
        action_name = self._format_action_name(docker_image_name, memory)
        # Computed outside the try blocks because the handlers use them.
        in_q_name = self._format_queue_name(action_name, az_config.IN_QUEUE)
        out_q_name = self._format_queue_name(action_name, az_config.OUT_QUEUE)

        try:
            self.queue_service.create_queue(in_q_name)
        except Exception:
            # Presumably the queue already exists: reuse it without stale messages.
            in_queue = self.queue_service.get_queue_client(in_q_name)
            in_queue.clear_messages()
        try:
            self.queue_service.create_queue(out_q_name)
        except Exception:
            out_queue = self.queue_service.get_queue_client(out_q_name)
            out_queue.clear_messages()

        logger.debug('Creating function app')
        logger.debug('Function name: {}'.format(action_name))
        python_version = version_str(sys.version_info)
        cmd = ('az functionapp create --name {} --storage-account {} '
               '--resource-group {} --os-type Linux  --runtime python '
               '--runtime-version {} --functions-version {} --consumption-plan-location {}'
               .format(action_name, self.storage_account, self.resource_group,
                       python_version, self.functions_version, self.location))
        if logger.getEffectiveLevel() != logging.DEBUG:
            cmd = cmd + " >{} 2>&1".format(os.devnull)
        res = os.system(cmd)
        if res != 0:
            raise Exception('There was an error creating the function in Azure. cmd: {}'.format(cmd))

        logger.debug('Publishing function app')
        build_dir = os.path.join(az_config.BUILD_DIR, action_name)
        os.chdir(build_dir)
        # Retry the publish every 5 seconds until it succeeds.
        # NOTE(review): there is no retry limit, so a persistent failure loops forever.
        res = 1
        while res != 0:
            time.sleep(5)
            cmd = 'func azure functionapp publish {} --python --no-build'.format(action_name)
            if logger.getEffectiveLevel() != logging.DEBUG:
                cmd = cmd + " >{} 2>&1".format(os.devnull)
            res = os.system(cmd)

        time.sleep(10)

    def delete_runtime(self, runtime_name, memory):
        """
        Deletes a runtime

        Deletes the Function App and then, best effort, its input and output
        queues. Failures are not reported.
        """
        action_name = self._format_action_name(runtime_name, memory)

        logger.debug('Deleting function app: {}'.format(action_name))
        cmd = ('az functionapp delete --name {} --resource-group {}'
               .format(action_name, self.resource_group))
        if logger.getEffectiveLevel() != logging.DEBUG:
            cmd = cmd + " >{} 2>&1".format(os.devnull)
        # Best-effort delete: the exit status is not checked.
        os.system(cmd)

        try:
            in_q_name = self._format_queue_name(action_name, az_config.IN_QUEUE)
            self.queue_service.delete_queue(in_q_name)
        except Exception:
            pass
        try:
            out_q_name = self._format_queue_name(action_name, az_config.OUT_QUEUE)
            self.queue_service.delete_queue(out_q_name)
        except Exception:
            pass

    def invoke(self, docker_image_name, memory=None, payload=None, return_result=False):
        """
        Invoke function

        Sends ``payload``, encoded with ``dict_to_b64str``, to the runtime's
        input queue.

        :param payload: dict sent to the function. Defaults to an empty dict.
        :param return_result: when True, poll the output queue once per second
            until a message arrives, clear the output queue and return the
            decoded message.
        :return: the decoded result when ``return_result`` is True, otherwise
            the id of the enqueued message.
        """
        if payload is None:
            payload = {}

        action_name = self._format_action_name(docker_image_name)
        in_q_name = self._format_queue_name(action_name, az_config.IN_QUEUE)
        in_queue = self.queue_service.get_queue_client(in_q_name)
        msg = in_queue.send_message(dict_to_b64str(payload))
        activation_id = msg.id

        if return_result:
            out_q_name = self._format_queue_name(action_name, az_config.OUT_QUEUE)
            out_queue = self.queue_service.get_queue_client(out_q_name)
            msg = []
            while not msg:
                time.sleep(1)
                msg = out_queue.receive_message()
            out_queue.clear_messages()
            return b64str_to_dict(msg.content)

        return activation_id

    def get_runtime_key(self, docker_image_name, runtime_memory):
        """
        Method that creates and returns the runtime key.
        Runtime keys are used to uniquely identify runtimes within the storage,
        in order to know which runtimes are installed and which not.
        """
        action_name = self._format_action_name(docker_image_name, runtime_memory)
        runtime_key = os.path.join(self.name, action_name)

        return runtime_key

    def clean(self):
        """Not implemented yet: does nothing."""
        # TODO
        pass

    def _generate_runtime_meta(self, docker_image_name, memory):
        """
        Extract installed Python modules from Azure runtime

        :return: the metadata dict returned by the function (must contain
            'preinstalls').
        :raises Exception: if the invocation fails or the answer lacks
            'preinstalls'.
        """
        logger.info("Extracting Python modules from: {}".format(docker_image_name))
        payload = {'log_level': logger.getEffectiveLevel(), 'get_preinstalls': True}

        try:
            runtime_meta = self.invoke(docker_image_name, memory=memory,
                                       payload=payload, return_result=True)
        except Exception as e:
            # Keep the original failure as the cause for easier debugging.
            raise Exception("Unable to invoke 'extract-preinstalls' action") from e

        if not runtime_meta or 'preinstalls' not in runtime_meta:
            raise Exception(runtime_meta)

        logger.debug("Extracted metadata succesfully")
        return runtime_meta
Example #3
0
class AzureFunctionAppBackend:
    """
    A wrap-up around Azure Function Apps backend.

    Runtimes are deployed as Azure Function Apps (Linux, consumption plan)
    through the ``az`` and ``func`` command line tools. Depending on
    ``invocation_type``, a function is triggered either through a pair of
    Azure Storage queues ('event') or through an HTTPS endpoint ('http').
    """
    def __init__(self, config, internal_storage):
        """
        :param config: Azure Functions configuration. Must provide the keys
            'invocation_type', 'resource_group', 'storage_account_name',
            'storage_account_key', 'location' and 'functions_version'.
        :param internal_storage: not used by this constructor.
        """
        logger.debug("Creating Azure Functions client")
        self.name = 'azure_fa'
        self.type = 'faas'
        self.azure_config = config
        self.invocation_type = self.azure_config['invocation_type']
        self.resource_group = self.azure_config['resource_group']
        self.storage_account_name = self.azure_config['storage_account_name']
        self.storage_account_key = self.azure_config['storage_account_key']
        self.location = self.azure_config['location']
        self.functions_version = self.azure_config['functions_version']

        self.queue_service_url = 'https://{}.queue.core.windows.net'.format(
            self.storage_account_name)
        self.queue_service = QueueServiceClient(
            account_url=self.queue_service_url,
            credential=self.storage_account_key)

        msg = COMPUTE_CLI_MSG.format('Azure Functions')
        logger.info("{} - Location: {}".format(msg, self.location))

    def _format_action_name(self, runtime_name, runtime_memory=None):
        """
        Return the Function App name for ``runtime_name``: '/' and ':' are
        replaced by '--'. ``runtime_memory`` is accepted but ignored.
        """
        return runtime_name.replace('/', '--').replace(':', '--')

    def _format_queue_name(self, action_name, q_type):
        """
        Return the storage queue name for ``action_name`` and ``q_type``.

        Every '--' is collapsed to '-' (presumably because Azure queue names
        may not contain consecutive hyphens), then '-<q_type>' is appended.
        """
        queue_base = action_name.replace('--', '-')
        return queue_base + '-' + q_type

    def _get_default_runtime_image_name(self):
        """
        Return the default runtime name.

        The format is '<storage account>-<RUNTIME_NAME>-v<python version
        without dots>-<revision>-<invocation type>'. The revision is 'latest'
        for dev versions of Lithops, otherwise the version without dots.
        """
        py_version = version_str(sys.version_info).replace('.', '')
        revision = 'latest' if 'dev' in __version__ else __version__.replace(
            '.', '')
        runtime_name = '{}-{}-v{}-{}-{}'.format(self.storage_account_name,
                                                az_config.RUNTIME_NAME,
                                                py_version, revision,
                                                self.invocation_type)
        return runtime_name

    def create_runtime(self, docker_image_name, memory, timeout):
        """
        Creates a new runtime into Azure Function Apps
        from the provided Linux image for consumption plan

        Only the default runtime ('default' or its full name) is built here.
        Any other runtime must already exist.

        :return: the runtime metadata reported by the deployed function.
        """
        default_runtime_img_name = self._get_default_runtime_image_name()
        if docker_image_name in ['default', default_runtime_img_name]:
            # We only build the default image. rest of images must already exist
            # in the docker registry.
            docker_image_name = default_runtime_img_name
            self._build_default_runtime(default_runtime_img_name)

        self._create_function(docker_image_name, memory, timeout)
        metadata = self._generate_runtime_meta(docker_image_name, memory)

        return metadata

    def _build_default_runtime(self, default_runtime_img_name):
        """
        Builds the default runtime

        The runtime is prepared locally as a Function App project with
        ``build_runtime``.
        """
        return self.build_runtime(default_runtime_img_name)

    def build_runtime(self,
                      runtime_name,
                      requirements_file=None,
                      extra_args=None):
        """
        Prepare the Function App project for ``runtime_name``.

        The project is laid out under ``az_config.BUILD_DIR``; any previous
        build directory is removed first. The method writes requirements.txt,
        host.json and the function folder (function.json with the bindings
        for the configured invocation type, plus the entry point).

        On Unix-like systems the requirements are installed into the modules
        directory together with the packaged handler. Otherwise lithops is
        added to requirements.txt instead, presumably to be installed by the
        remote build that ``_create_function`` uses on those systems.

        ``requirements_file`` and ``extra_args`` are currently ignored. On
        Unix-like systems this method changes the process working directory
        to the build directory.
        """
        try:
            shutil.rmtree(az_config.BUILD_DIR)
        except Exception:
            # Best effort: the build directory may not exist yet.
            pass

        action_name = self._format_action_name(runtime_name)

        build_dir = os.path.join(az_config.BUILD_DIR, action_name)
        os.makedirs(build_dir, exist_ok=True)

        logger.info('Building default runtime in {}'.format(build_dir))

        action_dir = os.path.join(build_dir, az_config.ACTION_DIR)
        os.makedirs(action_dir, exist_ok=True)

        req_file = os.path.join(build_dir, 'requirements.txt')
        with open(req_file, 'w') as reqf:
            reqf.write(az_config.REQUIREMENTS_FILE)
            if not is_unix_system():
                # Start a new line in case REQUIREMENTS_FILE does not end with
                # one; otherwise the two requirements would be concatenated.
                reqf.write('\n')
                if 'dev' in lithops.__version__:
                    reqf.write('git+https://github.com/lithops-cloud/lithops\n')
                else:
                    reqf.write('lithops=={}\n'.format(lithops.__version__))

        host_file = os.path.join(build_dir, 'host.json')
        with open(host_file, 'w') as hstf:
            hstf.write(az_config.HOST_FILE)

        fn_file = os.path.join(action_dir, 'function.json')
        if self.invocation_type == 'event':
            with open(fn_file, 'w') as fnf:
                # Binding 0 is bound to this runtime's input queue and
                # binding 1 to its output queue.
                in_q_name = self._format_queue_name(action_name,
                                                    az_config.IN_QUEUE)
                az_config.BINDINGS_QUEUE['bindings'][0]['queueName'] = in_q_name
                out_q_name = self._format_queue_name(action_name,
                                                     az_config.OUT_QUEUE)
                az_config.BINDINGS_QUEUE['bindings'][1]['queueName'] = out_q_name
                fnf.write(json.dumps(az_config.BINDINGS_QUEUE))

        elif self.invocation_type == 'http':
            with open(fn_file, 'w') as fnf:
                fnf.write(json.dumps(az_config.BINDINGS_HTTP))

        entry_point = os.path.join(os.path.dirname(__file__), 'entry_point.py')
        main_file = os.path.join(action_dir, '__init__.py')
        shutil.copy(entry_point, main_file)

        if is_unix_system():
            mod_dir = os.path.join(build_dir, az_config.ACTION_MODULES_DIR)
            os.chdir(build_dir)
            cmd = '{} -m pip install -U -t {} -r requirements.txt'.format(
                sys.executable, mod_dir)
            if logger.getEffectiveLevel() != logging.DEBUG:
                cmd = cmd + " >{} 2>&1".format(os.devnull)
            os.system(cmd)
            # Unpack the handler zip into the modules directory, dropping its
            # copy of the entry point (already placed in the function folder).
            create_handler_zip(az_config.FH_ZIP_LOCATION, entry_point,
                               '__init__.py')
            with zipfile.ZipFile(az_config.FH_ZIP_LOCATION) as archive:
                archive.extractall(path=mod_dir)
            os.remove(os.path.join(mod_dir, '__init__.py'))
            os.remove(az_config.FH_ZIP_LOCATION)

    def _create_function(self, docker_image_name, memory, timeout):
        """
        Create and publish an Azure Functions

        For the 'event' invocation type, the input and output queues are
        created, or emptied if creating them fails. The Function App is then
        created with the ``az`` CLI, and the directory prepared by
        ``build_runtime`` is published with ``func``. On non-Unix systems the
        publish is done without ``--no-build``. ``timeout`` is currently not
        used.

        :raises Exception: if ``az functionapp create`` exits with a non-zero status.
        """
        action_name = self._format_action_name(docker_image_name, memory)
        logger.info(
            'Creating new Lithops runtime for Azure Function: {}'.format(
                action_name))

        if self.invocation_type == 'event':
            # Computed outside the try blocks because the handlers use them.
            in_q_name = self._format_queue_name(action_name,
                                                az_config.IN_QUEUE)
            out_q_name = self._format_queue_name(action_name,
                                                 az_config.OUT_QUEUE)
            try:
                logger.debug('Creating queue {}'.format(in_q_name))
                self.queue_service.create_queue(in_q_name)
            except Exception:
                # Presumably the queue already exists: reuse it without stale messages.
                in_queue = self.queue_service.get_queue_client(in_q_name)
                in_queue.clear_messages()
            try:
                logger.debug('Creating queue {}'.format(out_q_name))
                self.queue_service.create_queue(out_q_name)
            except Exception:
                out_queue = self.queue_service.get_queue_client(out_q_name)
                out_queue.clear_messages()

        python_version = version_str(sys.version_info)
        cmd = (
            'az functionapp create --name {} --storage-account {} '
            '--resource-group {} --os-type Linux  --runtime python '
            '--runtime-version {} --functions-version {} --consumption-plan-location {}'
            .format(action_name, self.storage_account_name,
                    self.resource_group, python_version,
                    self.functions_version, self.location))
        if logger.getEffectiveLevel() != logging.DEBUG:
            cmd = cmd + " >{} 2>&1".format(os.devnull)
        res = os.system(cmd)
        if res != 0:
            raise Exception(
                'There was an error creating the function in Azure. cmd: {}'.
                format(cmd))

        logger.debug('Publishing function: {}'.format(action_name))
        build_dir = os.path.join(az_config.BUILD_DIR, action_name)
        os.chdir(build_dir)
        # Retry the publish every 5 seconds until it succeeds.
        # NOTE(review): there is no retry limit, so a persistent failure loops forever.
        res = 1
        while res != 0:
            time.sleep(5)
            if is_unix_system():
                cmd = 'func azure functionapp publish {} --python --no-build'.format(
                    action_name)
            else:
                cmd = 'func azure functionapp publish {} --python'.format(
                    action_name)
            if logger.getEffectiveLevel() != logging.DEBUG:
                cmd = cmd + " >{} 2>&1".format(os.devnull)
            res = os.system(cmd)

        time.sleep(10)

    def delete_runtime(self, runtime_name, memory):
        """
        Deletes a runtime

        Deletes the Function App and then, best effort, its input and output
        queues. Failures are not reported.
        """
        action_name = self._format_action_name(runtime_name, memory)
        logger.debug('Deleting function app: {}'.format(action_name))
        cmd = ('az functionapp delete --name {} --resource-group {}'.format(
            action_name, self.resource_group))
        if logger.getEffectiveLevel() != logging.DEBUG:
            cmd = cmd + " >{} 2>&1".format(os.devnull)
        os.system(cmd)

        try:
            in_q_name = self._format_queue_name(action_name,
                                                az_config.IN_QUEUE)
            self.queue_service.delete_queue(in_q_name)
        except Exception:
            pass
        try:
            out_q_name = self._format_queue_name(action_name,
                                                 az_config.OUT_QUEUE)
            self.queue_service.delete_queue(out_q_name)
        except Exception:
            pass

    def invoke(self,
               docker_image_name,
               memory=None,
               payload=None,
               return_result=False):
        """
        Invoke function

        :param payload: dict sent to the function. Defaults to an empty dict.
        :param return_result: wait for the function's answer and return it.
        :return:
            - 'event': the decoded output-queue message when ``return_result``
              is True, otherwise the id of the enqueued input message.
            - 'http': the decoded JSON response (GET) when ``return_result``
              is True; otherwise the POST response body as the activation id,
              or None when the service answers 429 (too many requests).
        :raises ValueError: if ``invocation_type`` is neither 'event' nor 'http'.
        """
        if payload is None:
            payload = {}

        action_name = self._format_action_name(docker_image_name, memory)
        if self.invocation_type == 'event':

            in_q_name = self._format_queue_name(action_name,
                                                az_config.IN_QUEUE)
            in_queue = self.queue_service.get_queue_client(in_q_name)
            msg = in_queue.send_message(dict_to_b64str(payload))
            activation_id = msg.id

            if return_result:
                out_q_name = self._format_queue_name(action_name,
                                                     az_config.OUT_QUEUE)
                out_queue = self.queue_service.get_queue_client(out_q_name)
                msg = []
                while not msg:
                    time.sleep(1)
                    msg = out_queue.receive_message()
                out_queue.clear_messages()
                return b64str_to_dict(msg.content)

        elif self.invocation_type == 'http':
            endpoint = "https://{}.azurewebsites.net".format(action_name)
            parsed_url = urlparse(endpoint)
            # NOTE(review): TLS certificate verification is disabled here; confirm
            # this is intentional, since ssl.create_default_context() would verify
            # the azurewebsites.net certificate.
            ctx = ssl._create_unverified_context()
            conn = http.client.HTTPSConnection(parsed_url.netloc, context=ctx)

            route = "/api/lithops_handler"
            body = json.dumps(payload, default=str)
            # try/finally guarantees the connection is closed even when the
            # request or the response decoding raises.
            try:
                if return_result:
                    conn.request("GET", route, body=body)
                    resp = conn.getresponse()
                    return json.loads(resp.read().decode("utf-8"))

                conn.request("POST", route, body=body)
                resp = conn.getresponse()
                if resp.status == 429:
                    # Throttled: back off briefly and report no activation.
                    time.sleep(0.2)
                    return None
                activation_id = resp.read().decode("utf-8")
            finally:
                conn.close()

        else:
            raise ValueError('Unsupported invocation_type: {}'.format(
                self.invocation_type))

        return activation_id

    def get_runtime_key(self, docker_image_name, runtime_memory):
        """
        Method that creates and returns the runtime key.
        Runtime keys are used to uniquely identify runtimes within the storage,
        in order to know which runtimes are installed and which not.
        """
        action_name = self._format_action_name(docker_image_name,
                                               runtime_memory)
        runtime_key = os.path.join(self.name, action_name)

        return runtime_key

    def clean(self):
        """
        Deletes all Lithops Azure Function Apps runtimes

        Every Function App returned by ``list_runtimes`` is deleted, even ones
        that were not created by Lithops.
        """
        logger.debug('Deleting all runtimes')

        runtimes = self.list_runtimes()

        for runtime_name, runtime_memory in runtimes:
            self.delete_runtime(runtime_name, runtime_memory)

    def _generate_runtime_meta(self, docker_image_name, memory):
        """
        Extract installed Python modules from Azure runtime

        :return: the metadata dict returned by the function (must contain
            'preinstalls').
        :raises Exception: if the invocation fails or the answer lacks
            'preinstalls'.
        """
        logger.info(
            "Extracting Python modules from: {}".format(docker_image_name))
        payload = {
            'log_level': logger.getEffectiveLevel(),
            'get_preinstalls': True
        }

        try:
            runtime_meta = self.invoke(docker_image_name,
                                       memory=memory,
                                       payload=payload,
                                       return_result=True)
        except Exception as e:
            # Keep the original failure as the cause for easier debugging.
            raise Exception("Unable to invoke 'extract-preinstalls' action") from e

        if not runtime_meta or 'preinstalls' not in runtime_meta:
            raise Exception(runtime_meta)

        logger.debug("Extracted metadata succesfully")
        return runtime_meta

    def list_runtimes(self, docker_image_name='all'):
        """
        List all the Azure Function Apps deployed.
        return: Array of tuples (function_name, memory)

        The function name is the app's default host name without the
        '.azurewebsites.net' suffix; memory is always ''. When
        ``docker_image_name`` is not 'all', only the matching name is kept.
        """
        logger.debug('Listing all functions deployed...')

        functions = []
        # Use the pipe as a context manager so it is closed after reading.
        with os.popen('az functionapp list --query "[].defaultHostName"') as proc:
            response = json.loads(proc.read())

        for host_name in response:
            function_name = host_name.replace('.azurewebsites.net', '')
            if docker_image_name == function_name or docker_image_name == 'all':
                functions.append((function_name, ''))

        logger.debug('Listed {} functions'.format(len(functions)))
        return functions
class StorageQueueEncodingTest(QueueTestCase):
    """Tests for queue message encode/decode policies (XML text, base64 text and bytes)."""

    def setUp(self):
        super(StorageQueueEncodingTest, self).setUp()

        queue_url = self._get_queue_url()
        credentials = self._get_shared_key_credential()
        self.qsc = QueueServiceClient(account_url=queue_url, credential=credentials)
        # Queue clients appended here are deleted in tearDown (live runs only).
        self.test_queues = []

    def tearDown(self):
        if not self.is_playback():
            for queue in self.test_queues:
                try:
                    self.qsc.delete_queue(queue.queue_name)
                except Exception:
                    # Best-effort cleanup: some tests never create their queue.
                    pass
        return super(StorageQueueEncodingTest, self).tearDown()

    # --Helpers-----------------------------------------------------------------
    def _get_queue_reference(self, prefix=TEST_QUEUE_PREFIX):
        """Return a client for a uniquely named queue and register it for cleanup."""
        queue_name = self.get_resource_name(prefix)
        queue = self.qsc.get_queue_client(queue_name)
        self.test_queues.append(queue)
        return queue

    def _create_queue(self, prefix=TEST_QUEUE_PREFIX):
        """Like _get_queue_reference, but also create the queue on the service."""
        queue = self._get_queue_reference(prefix)
        self._ensure_queue_exists(queue)
        return queue

    @staticmethod
    def _ensure_queue_exists(queue):
        """Create ``queue`` on the service, tolerating an already existing queue."""
        try:
            queue.create_queue()
        except ResourceExistsError:
            pass

    def _get_queue_client_with_policies(self, encode_policy, decode_policy):
        """Build a QueueClient using the given message policies and register it for cleanup."""
        queue_url = self._get_queue_url()
        credentials = self._get_shared_key_credential()
        queue = QueueClient(
            queue_url=queue_url,
            queue=self.get_resource_name(TEST_QUEUE_PREFIX),
            credential=credentials,
            message_encode_policy=encode_policy,
            message_decode_policy=decode_policy)
        self.test_queues.append(queue)
        return queue

    def _validate_encoding(self, queue, message):
        """Create ``queue`` if needed, enqueue ``message`` and check the received content matches."""
        # Arrange
        self._ensure_queue_exists(queue)

        # Action.
        queue.enqueue_message(message)

        # Asserts
        dequeued = next(queue.receive_messages())
        self.assertEqual(message, dequeued.content)

    # --------------------------------------------------------------------------

    @record
    def test_message_text_xml(self):
        # Arrange.
        message = u'<message1>'
        queue = self._get_queue_reference()

        # Asserts
        self._validate_encoding(queue, message)

    @record
    def test_message_text_xml_whitespace(self):
        # Arrange.
        message = u'  mess\t age1\n'
        queue = self._get_queue_reference()

        # Asserts
        self._validate_encoding(queue, message)

    @record
    def test_message_text_xml_invalid_chars(self):
        # Action.
        queue = self._get_queue_reference()
        message = u'\u0001'

        # Asserts
        with self.assertRaises(HttpResponseError):
            queue.enqueue_message(message)

    @record
    def test_message_text_base64(self):
        # Arrange.
        queue = self._get_queue_client_with_policies(
            TextBase64EncodePolicy(), TextBase64DecodePolicy())

        message = u'\u0001'

        # Asserts
        self._validate_encoding(queue, message)

    @record
    def test_message_bytes_base64(self):
        # Arrange.
        queue = self._get_queue_client_with_policies(
            BinaryBase64EncodePolicy(), BinaryBase64DecodePolicy())

        message = b'xyz'

        # Asserts
        self._validate_encoding(queue, message)

    @record
    def test_message_bytes_fails(self):
        # Arrange
        queue = self._get_queue_reference()

        # Action.
        with self.assertRaises(TypeError) as e:
            message = b'xyz'
            queue.enqueue_message(message)

        # Asserts
        self.assertTrue(str(e.exception).startswith('Message content must be text'))

    @record
    def test_message_text_fails(self):
        # Arrange
        queue = self._get_queue_client_with_policies(
            BinaryBase64EncodePolicy(), BinaryBase64DecodePolicy())

        # Action.
        with self.assertRaises(TypeError) as e:
            message = u'xyz'
            queue.enqueue_message(message)

        # Asserts
        self.assertTrue(str(e.exception).startswith('Message content must be bytes'))

    @record
    def test_message_base64_decode_fails(self):
        # Arrange: encode as plain XML text but decode as base64, so reading fails.
        queue = self._get_queue_client_with_policies(
            TextXMLEncodePolicy(), BinaryBase64DecodePolicy())
        self._ensure_queue_exists(queue)
        message = u'xyz'
        queue.enqueue_message(message)

        # Action.
        with self.assertRaises(DecodeError) as e:
            queue.peek_messages()

        # Asserts
        self.assertIn('Message content is not valid base 64', str(e.exception))
class StorageQueueEncryptionTest(QueueTestCase):
    """Tests for client-side encryption of queue messages.

    Queues are obtained through ``_get_queue_reference`` so that ``tearDown``
    can delete them afterwards (outside playback mode).
    """

    def setUp(self):
        super(StorageQueueEncryptionTest, self).setUp()

        queue_url = self._get_queue_url()
        credentials = self._get_shared_key_credential()
        self.qsc = QueueServiceClient(account_url=queue_url,
                                      credential=credentials)
        # Queue clients handed out by _get_queue_reference; deleted in tearDown.
        self.test_queues = []

    def tearDown(self):
        # Best-effort deletion of every queue this test referenced. Skipped in
        # playback mode (presumably because responses come from recordings
        # rather than a live service).
        if not self.is_playback():
            for queue in self.test_queues:
                try:
                    self.qsc.delete_queue(queue.queue_name)
                except Exception:  # pylint: disable=broad-except
                    # Cleanup must not mask the test's own outcome: the queue
                    # may never have been created or may already be gone.
                    # Narrower than a bare except so Ctrl-C still propagates.
                    pass
        return super(StorageQueueEncryptionTest, self).tearDown()

    # --Helpers-----------------------------------------------------------------
    def _get_queue_reference(self, prefix=TEST_QUEUE_PREFIX):
        """Return a client for a uniquely named queue without creating it.

        The client is recorded in ``self.test_queues`` so ``tearDown`` can
        delete the queue.
        """
        queue_name = self.get_resource_name(prefix)
        queue = self.qsc.get_queue_client(queue_name)
        self.test_queues.append(queue)
        return queue

    def _create_queue(self, prefix=TEST_QUEUE_PREFIX):
        """Return a client for a uniquely named queue, creating it on the service.

        A queue that already exists is tolerated and reused.
        """
        queue = self._get_queue_reference(prefix)
        try:
            queue.create_queue()
        except ResourceExistsError:
            pass
        return queue

    # --------------------------------------------------------------------------

    @record
    def test_get_messages_encrypted_kek(self):
        """A message encrypted with a KEK is received back as plain text."""
        # Arrange
        # The KEK is set on the service client; the test relies on queue
        # clients from get_queue_client picking it up.
        self.qsc.key_encryption_key = KeyWrapper('key1')
        queue = self._create_queue()
        queue.enqueue_message(u'encrypted_message_2')

        # Act
        li = next(queue.receive_messages())

        # Assert
        self.assertEqual(li.content, u'encrypted_message_2')

    @record
    def test_get_messages_encrypted_resolver(self):
        """Received messages can be decrypted via a key resolver alone."""
        # Arrange
        self.qsc.key_encryption_key = KeyWrapper('key1')
        queue = self._create_queue()
        queue.enqueue_message(u'encrypted_message_2')
        key_resolver = KeyResolver()
        key_resolver.put_key(self.qsc.key_encryption_key)
        queue.key_resolver_function = key_resolver.resolve_key
        queue.key_encryption_key = None  # Ensure that the resolver is used

        # Act
        li = next(queue.receive_messages())

        # Assert
        self.assertEqual(li.content, u'encrypted_message_2')

    @record
    def test_peek_messages_encrypted_kek(self):
        """A message encrypted with a KEK is peeked back as plain text."""
        # Arrange
        self.qsc.key_encryption_key = KeyWrapper('key1')
        queue = self._create_queue()
        queue.enqueue_message(u'encrypted_message_3')

        # Act
        li = queue.peek_messages()

        # Assert
        self.assertEqual(li[0].content, u'encrypted_message_3')

    @record
    def test_peek_messages_encrypted_resolver(self):
        """Peeked messages can be decrypted via a key resolver alone."""
        # Arrange
        self.qsc.key_encryption_key = KeyWrapper('key1')
        queue = self._create_queue()
        queue.enqueue_message(u'encrypted_message_4')
        key_resolver = KeyResolver()
        key_resolver.put_key(self.qsc.key_encryption_key)
        queue.key_resolver_function = key_resolver.resolve_key
        queue.key_encryption_key = None  # Ensure that the resolver is used

        # Act
        li = queue.peek_messages()

        # Assert
        self.assertEqual(li[0].content, u'encrypted_message_4')

    def test_peek_messages_encrypted_kek_RSA(self):
        """An RSA-wrapped message is peeked back as plain text (live runs only)."""
        # We can only generate random RSA keys, so this must be run live or
        # the playback test will fail due to a change in kek values.
        if TestMode.need_recording_file(self.test_mode):
            return

        # Arrange
        self.qsc.key_encryption_key = RSAKeyWrapper('key2')
        queue = self._create_queue()
        queue.enqueue_message(u'encrypted_message_3')

        # Act
        li = queue.peek_messages()

        # Assert
        self.assertEqual(li[0].content, u'encrypted_message_3')

    @record
    def test_update_encrypted_message(self):
        """Updated text content of an encrypted message is read back decrypted."""
        # TODO: Recording doesn't work
        if TestMode.need_recording_file(self.test_mode):
            return
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue.enqueue_message(u'Update Me')

        messages = queue.receive_messages()
        list_result1 = next(messages)
        list_result1.content = u'Updated'

        # Act
        queue.update_message(list_result1)
        list_result2 = next(messages)

        # Assert
        self.assertEqual(u'Updated', list_result2.content)

    @record
    def test_update_encrypted_binary_message(self):
        """Updated binary content of an encrypted message is read back decrypted."""
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue._config.message_encode_policy = BinaryBase64EncodePolicy()
        queue._config.message_decode_policy = BinaryBase64DecodePolicy()

        binary_message = self.get_random_bytes(100)
        queue.enqueue_message(binary_message)
        messages = queue.receive_messages()
        list_result1 = next(messages)

        # Act
        binary_message = self.get_random_bytes(100)
        list_result1.content = binary_message
        queue.update_message(list_result1)

        list_result2 = next(messages)

        # Assert
        self.assertEqual(binary_message, list_result2.content)

    @record
    def test_update_encrypted_raw_text_message(self):
        """Updated raw (un-encoded) text of an encrypted message round-trips."""
        # TODO: Recording doesn't work
        if TestMode.need_recording_file(self.test_mode):
            return
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue._config.message_encode_policy = NoEncodePolicy()
        queue._config.message_decode_policy = NoDecodePolicy()

        raw_text = u'Update Me'
        queue.enqueue_message(raw_text)
        messages = queue.receive_messages()
        list_result1 = next(messages)

        # Act
        raw_text = u'Updated'
        list_result1.content = raw_text
        queue.update_message(list_result1)

        list_result2 = next(messages)

        # Assert
        self.assertEqual(raw_text, list_result2.content)

    @record
    def test_update_encrypted_json_message(self):
        """Updated JSON text of an encrypted message round-trips."""
        # TODO: Recording doesn't work
        if TestMode.need_recording_file(self.test_mode):
            return
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue._config.message_encode_policy = NoEncodePolicy()
        queue._config.message_decode_policy = NoDecodePolicy()

        message_dict = {'val1': 1, 'val2': '2'}
        json_text = dumps(message_dict)
        queue.enqueue_message(json_text)
        messages = queue.receive_messages()
        list_result1 = next(messages)

        # Act
        message_dict['val1'] = 0
        message_dict['val2'] = 'updated'
        json_text = dumps(message_dict)
        list_result1.content = json_text
        queue.update_message(list_result1)

        list_result2 = next(messages)

        # Assert
        self.assertEqual(message_dict, loads(list_result2.content))

    @record
    def test_invalid_value_kek_wrap(self):
        """Enqueuing with a KEK whose wrap-side members are None raises AttributeError."""
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue.key_encryption_key.get_kid = None

        with self.assertRaises(AttributeError) as e:
            queue.enqueue_message(u'message')

        self.assertEqual(
            str(e.exception),
            _ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid'))

        # NOTE(review): this repeats the get_kid case above; it may have been
        # meant to null out get_key_wrap_algorithm instead — confirm.
        queue.key_encryption_key = KeyWrapper('key1')
        queue.key_encryption_key.get_kid = None
        with self.assertRaises(AttributeError):
            queue.enqueue_message(u'message')

        queue.key_encryption_key = KeyWrapper('key1')
        queue.key_encryption_key.wrap_key = None
        with self.assertRaises(AttributeError):
            queue.enqueue_message(u'message')

    @record
    def test_missing_attribute_kek_wrap(self):
        """Enqueuing with a KEK missing any wrap-side attribute raises AttributeError."""
        # Arrange
        queue = self._create_queue()

        valid_key = KeyWrapper('key1')

        # Act
        invalid_key_1 = lambda: None  # functions are objects, so this effectively creates an empty object
        invalid_key_1.get_key_wrap_algorithm = valid_key.get_key_wrap_algorithm
        invalid_key_1.get_kid = valid_key.get_kid
        # No attribute wrap_key
        queue.key_encryption_key = invalid_key_1
        with self.assertRaises(AttributeError):
            queue.enqueue_message(u'message')

        invalid_key_2 = lambda: None  # functions are objects, so this effectively creates an empty object
        invalid_key_2.wrap_key = valid_key.wrap_key
        invalid_key_2.get_kid = valid_key.get_kid
        # No attribute get_key_wrap_algorithm
        queue.key_encryption_key = invalid_key_2
        with self.assertRaises(AttributeError):
            queue.enqueue_message(u'message')

        invalid_key_3 = lambda: None  # functions are objects, so this effectively creates an empty object
        invalid_key_3.get_key_wrap_algorithm = valid_key.get_key_wrap_algorithm
        invalid_key_3.wrap_key = valid_key.wrap_key
        # No attribute get_kid
        queue.key_encryption_key = invalid_key_3
        with self.assertRaises(AttributeError):
            queue.enqueue_message(u'message')

    @record
    def test_invalid_value_kek_unwrap(self):
        """Peeking with a KEK whose unwrap-side members are None raises HttpResponseError."""
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue.enqueue_message(u'message')

        # Act
        queue.key_encryption_key.unwrap_key = None
        with self.assertRaises(HttpResponseError):
            queue.peek_messages()

        # Note: unwrap_key is still None here; this check is cumulative.
        queue.key_encryption_key.get_kid = None
        with self.assertRaises(HttpResponseError):
            queue.peek_messages()

    # NOTE(review): "unrwap" is a typo for "unwrap"; the name is kept because
    # renaming would change the test id (and presumably its recording file).
    @record
    def test_missing_attribute_kek_unrwap(self):
        """Peeking with a KEK missing an unwrap-side attribute raises HttpResponseError."""
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue.enqueue_message(u'message')

        # Act
        valid_key = KeyWrapper('key1')
        invalid_key_1 = lambda: None  # functions are objects, so this effectively creates an empty object
        invalid_key_1.unwrap_key = valid_key.unwrap_key
        # No attribute get_kid
        queue.key_encryption_key = invalid_key_1
        with self.assertRaises(HttpResponseError) as e:
            queue.peek_messages()

        self.assertEqual(str(e.exception), "Decryption failed.")

        invalid_key_2 = lambda: None  # functions are objects, so this effectively creates an empty object
        invalid_key_2.get_kid = valid_key.get_kid
        # No attribute unwrap_key
        queue.key_encryption_key = invalid_key_2
        with self.assertRaises(HttpResponseError):
            queue.peek_messages()

    @record
    def test_validate_encryption(self):
        """Decrypt a stored message by hand and check it matches the plaintext.

        The message is peeked without a KEK so the raw JSON envelope is
        returned. Its content key is then unwrapped with the KEK, and the
        payload is AES-CBC decrypted, PKCS7-unpadded and UTF-8 decoded.
        """
        # Arrange
        queue = self._create_queue()
        kek = KeyWrapper('key1')
        queue.key_encryption_key = kek
        queue.enqueue_message(u'message')

        # Act
        queue.key_encryption_key = None  # Message will not be decrypted
        li = queue.peek_messages()
        message = li[0].content
        message = loads(message)

        encryption_data = message['EncryptionData']

        wrapped_content_key = encryption_data['WrappedContentKey']
        wrapped_content_key = _WrappedContentKey(
            wrapped_content_key['Algorithm'],
            b64decode(
                wrapped_content_key['EncryptedKey'].encode(encoding='utf-8')),
            wrapped_content_key['KeyId'])

        encryption_agent = encryption_data['EncryptionAgent']
        encryption_agent = _EncryptionAgent(
            encryption_agent['EncryptionAlgorithm'],
            encryption_agent['Protocol'])

        encryption_data = _EncryptionData(
            b64decode(encryption_data['ContentEncryptionIV'].encode(
                encoding='utf-8')), encryption_agent, wrapped_content_key,
            {'EncryptionLibrary': VERSION})

        message = message['EncryptedMessageContents']
        content_encryption_key = kek.unwrap_key(
            encryption_data.wrapped_content_key.encrypted_key,
            encryption_data.wrapped_content_key.algorithm)

        # Create decryption cipher
        backend = backends.default_backend()
        algorithm = AES(content_encryption_key)
        mode = CBC(encryption_data.content_encryption_IV)
        cipher = Cipher(algorithm, mode, backend)

        # decode and decrypt data
        decrypted_data = _decode_base64_to_bytes(message)
        decryptor = cipher.decryptor()
        decrypted_data = (decryptor.update(decrypted_data) +
                          decryptor.finalize())

        # unpad data
        unpadder = PKCS7(128).unpadder()
        decrypted_data = (unpadder.update(decrypted_data) +
                          unpadder.finalize())

        decrypted_data = decrypted_data.decode(encoding='utf-8')

        # Assert
        self.assertEqual(decrypted_data, u'message')

    @record
    def test_put_with_strict_mode(self):
        """With require_encryption set, enqueuing without a KEK raises ValueError."""
        # Arrange
        queue = self._create_queue()
        kek = KeyWrapper('key1')
        queue.key_encryption_key = kek
        queue.require_encryption = True

        queue.enqueue_message(u'message')
        queue.key_encryption_key = None

        # Assert
        with self.assertRaises(ValueError) as e:
            queue.enqueue_message(u'message')

        self.assertEqual(str(e.exception),
                         "Encryption required but no key was provided.")

    @record
    def test_get_with_strict_mode(self):
        """With require_encryption set, receiving an unencrypted message raises ValueError."""
        # Arrange
        queue = self._create_queue()
        queue.enqueue_message(u'message')

        queue.require_encryption = True
        queue.key_encryption_key = KeyWrapper('key1')
        with self.assertRaises(ValueError) as e:
            next(queue.receive_messages())

        self.assertEqual(str(e.exception), 'Message was not encrypted.')

    @record
    def test_encryption_add_encrypted_64k_message(self):
        """A 64 KiB message is accepted in plain form but rejected once encrypted.

        The rejection is presumably because the encryption envelope pushes the
        message over the service's size limit.
        """
        # Arrange
        queue = self._create_queue()
        message = u'a' * 1024 * 64

        # Act
        queue.enqueue_message(message)

        # Assert
        queue.key_encryption_key = KeyWrapper('key1')
        with self.assertRaises(HttpResponseError):
            queue.enqueue_message(message)

    @record
    def test_encryption_nonmatching_kid(self):
        """Receiving with a KEK whose kid differs from the stored one fails decryption."""
        # Arrange
        queue = self._create_queue()
        queue.key_encryption_key = KeyWrapper('key1')
        queue.enqueue_message(u'message')

        # Act
        queue.key_encryption_key.kid = 'Invalid'

        # Assert
        with self.assertRaises(HttpResponseError) as e:
            next(queue.receive_messages())

        self.assertEqual(str(e.exception), "Decryption failed.")
# raw_input only exists on Python 2 (Python 3 renamed it to input); resolve
# whichever is available so the prompts below work on both.
try:
    _prompt = raw_input
except NameError:
    _prompt = input

metadata = queue_service.get_queue_metadata('pizzaqueue')

print(
    'If we look at the Queue again, we have one less message to show we have processed that order and a yummy pizza will be on it\'s way to the customer soon.'
)
print('Number of messages in the queue: ' +
      str(metadata.approximate_message_count))
_prompt('\nPress Enter to continue...')

###
# This was a quick demo to see Queues in action.
# Although the actual cost is minimal since we deleted all the messages from the Queue, it is good practice to clean up resources when you're done.
###
print(
    '\nThis is a basic example of how Azure Storage Queues behave.\nTo keep things tidy, let\'s clean up the Azure Storage resources we created.'
)
_prompt('Press Enter to continue...')

# delete_queue's result is treated as a success flag (truthy on success).
response = queue_service.delete_queue('pizzaqueue')
if response:
    print('Storage Queue: pizzaqueue deleted successfully.')
else:
    print('Error deleting Storage Queue')

# The deletion is considered accepted when the call answers HTTP 202.
response = azurerm.delete_resource_group(auth_token, subscription_id,
                                         resourcegroup_name)
if response.status_code == 202:
    print('Resource group: ' + resourcegroup_name + ' deleted successfully.')
else:
    print('Error deleting resource group.')