Example #1
class AzureJobStore(AbstractJobStore):
    """
    A job store that uses Azure's blob store for file storage and
    Table Service to store job info with strong consistency."""

    @classmethod
    def loadOrCreateJobStore(cls, locator, config=None, **kwargs):
        account, namePrefix = locator.split(':', 1)
        if '--' in namePrefix:
            raise ValueError("Invalid name prefix '%s'. Name prefixes may not contain "
                             "%s." % (namePrefix, cls.nameSeparator))

        if not cls.containerNameRe.match(namePrefix):
            raise ValueError("Invalid name prefix '%s'. Name prefixes must contain only digits, "
                             "hyphens or lower-case letters and must not start or end in a "
                             "hyphen." % namePrefix)

        # Reserve 13 for separator and suffix
        if len(namePrefix) > cls.maxContainerNameLen - cls.maxNameLen - len(cls.nameSeparator):
            raise ValueError(("Invalid name prefix '%s'. Name prefixes may not be longer than 50 "
                              "characters." % namePrefix))

        if '--' in namePrefix:
            raise ValueError("Invalid name prefix '%s'. Name prefixes may not contain "
                             "%s." % (namePrefix, cls.nameSeparator))

        return cls(account, namePrefix, config=config, **kwargs)

    # Dots in container names should be avoided because container names are used in HTTPS bucket
    # URLs where they may interfere with the certificate common name. We use the alphanumeric
    # separator defined by nameSeparator below instead.
    #
    containerNameRe = re.compile(r'^[a-z0-9](-?[a-z0-9]+)+[a-z0-9]$')

    # See https://msdn.microsoft.com/en-us/library/azure/dd135715.aspx
    #
    minContainerNameLen = 3
    maxContainerNameLen = 63
    maxNameLen = 10
    nameSeparator = 'xx'  # Table names must be alphanumeric

    # Do not invoke the constructor, use the factory method above.

    def __init__(self, accountName, namePrefix, config=None,
                 jobChunkSize=maxAzureTablePropertySize):
        self.jobChunkSize = jobChunkSize
        self.keyPath = None

        self.account_key = _fetchAzureAccountKey(accountName)
        self.accountName = accountName
        # Table names have strict requirements in Azure
        self.namePrefix = self._sanitizeTableName(namePrefix)
        logger.debug("Creating job store with name prefix '%s'" % self.namePrefix)

        # These are the main API entrypoints.
        self.tableService = TableService(account_key=self.account_key, account_name=accountName)
        self.blobService = BlobService(account_key=self.account_key, account_name=accountName)

        exists = self._jobStoreExists()
        self._checkJobStoreCreation(config is not None, exists, accountName + ":" + self.namePrefix)

        # Serialized jobs table
        self.jobItems = self._getOrCreateTable(self.qualify('jobs'))
        # Job<->file mapping table
        self.jobFileIDs = self._getOrCreateTable(self.qualify('jobFileIDs'))

        # Container for all shared and unshared files
        self.files = self._getOrCreateBlobContainer(self.qualify('files'))

        # Stats and logging strings
        self.statsFiles = self._getOrCreateBlobContainer(self.qualify('statsfiles'))
        # File IDs that contain stats and logging strings
        self.statsFileIDs = self._getOrCreateTable(self.qualify('statsFileIDs'))

        super(AzureJobStore, self).__init__(config=config)

        if self.config.cseKey is not None:
            self.keyPath = self.config.cseKey

    # Length of a jobID - used to test if a stats file has been read already or not
    jobIDLength = len(str(uuid.uuid4()))

    def qualify(self, name):
        return self.namePrefix + self.nameSeparator + name

    def jobs(self):

        # How many jobs have we done?
        total_processed = 0

        for jobEntity in self.jobItems.query_entities_auto():
            # Process the items in the page
            yield AzureJob.fromEntity(jobEntity)
            total_processed += 1

            if total_processed % 1000 == 0:
                # Produce some feedback for the user, because this can take
                # a long time on, for example, Azure
                logger.info("Processed %d total jobs" % total_processed)

        logger.info("Processed %d total jobs" % total_processed)

    def create(self, command, memory, cores, disk, preemptable, predecessorNumber=0):
        jobStoreID = self._newJobID()
        job = AzureJob(jobStoreID=jobStoreID, command=command,
                       memory=memory, cores=cores, disk=disk, preemptable=preemptable,
                       remainingRetryCount=self._defaultTryCount(), logJobStoreFileID=None,
                       predecessorNumber=predecessorNumber)
        entity = job.toItem(chunkSize=self.jobChunkSize)
        entity['RowKey'] = jobStoreID
        self.jobItems.insert_entity(entity=entity)
        return job

    def exists(self, jobStoreID):
        if self.jobItems.get_entity(row_key=jobStoreID) is None:
            return False
        return True

    def load(self, jobStoreID):
        jobEntity = self.jobItems.get_entity(row_key=jobStoreID)
        if jobEntity is None:
            raise NoSuchJobException(jobStoreID)
        return AzureJob.fromEntity(jobEntity)

    def update(self, job):
        self.jobItems.update_entity(row_key=job.jobStoreID,
                                    entity=job.toItem(chunkSize=self.jobChunkSize))

    def delete(self, jobStoreID):
        try:
            self.jobItems.delete_entity(row_key=jobStoreID)
        except AzureMissingResourceHttpError:
            # Job deletion is idempotent, and this job has been deleted already
            return
        filterString = "PartitionKey eq '%s'" % jobStoreID
        for fileEntity in self.jobFileIDs.query_entities(filter=filterString):
            jobStoreFileID = fileEntity.RowKey
            self.deleteFile(jobStoreFileID)

    def deleteJobStore(self):
        self.jobItems.delete_table()
        self.jobFileIDs.delete_table()
        self.files.delete_container()
        self.statsFiles.delete_container()
        self.statsFileIDs.delete_table()

    def _jobStoreExists(self):
        """
        Checks if job store exists by querying the existence of the statsFileIDs table. Note that
        this is the last component that is deleted in deleteJobStore.
        """
        for attempt in retry_azure():
            with attempt:
                try:
                    table = self.tableService.query_tables(table_name=self.qualify('statsFileIDs'))
                    return table is not None
                except AzureMissingResourceHttpError as e:
                    if e.status_code == 404:
                        return False
                    else:
                        raise

    def getEnv(self):
        return dict(AZURE_ACCOUNT_KEY=self.account_key)

    @classmethod
    def _readFromUrl(cls, url, writable):
        blobService, containerName, blobName = cls._extractBlobInfoFromUrl(url)
        blobService.get_blob_to_file(containerName, blobName, writable)

    @classmethod
    def _writeToUrl(cls, readable, url):
        blobService, containerName, blobName = cls._extractBlobInfoFromUrl(url)
        blobService.put_block_blob_from_file(containerName, blobName, readable)
        blobService.get_blob(containerName, blobName)

    @staticmethod
    def _extractBlobInfoFromUrl(url):
        """
        :return: (blobService, containerName, blobName)
        """

        def invalidUrl():
            raise RuntimeError("The URL '%s' is invalid" % url.geturl())

        netloc = url.netloc.split('@')
        if len(netloc) != 2:
            invalidUrl()

        accountEnd = netloc[1].find('.blob.core.windows.net')
        if accountEnd == -1:
            invalidUrl()

        containerName, accountName = netloc[0], netloc[1][0:accountEnd]
        blobName = url.path[1:]  # urlparse always includes a leading '/'
        blobService = BlobService(account_key=_fetchAzureAccountKey(accountName),
                                  account_name=accountName)
        return blobService, containerName, blobName
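
    # For illustration (hypothetical names): a URL like
    # wasb://mycontainer@myaccount.blob.core.windows.net/path/to/blob
    # is split by the method above into containerName='mycontainer',
    # accountName='myaccount' and blobName='path/to/blob'.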

    @classmethod
    def _supportsUrl(cls, url, export=False):
        return url.scheme.lower() == 'wasb' or url.scheme.lower() == 'wasbs'

    def writeFile(self, localFilePath, jobStoreID=None):
        jobStoreFileID = self._newFileID()
        self.updateFile(jobStoreFileID, localFilePath)
        self._associateFileWithJob(jobStoreFileID, jobStoreID)
        return jobStoreFileID

    def updateFile(self, jobStoreFileID, localFilePath):
        with open(localFilePath) as read_fd:
            with self._uploadStream(jobStoreFileID, self.files) as write_fd:
                while True:
                    buf = read_fd.read(self._maxAzureBlockBytes)
                    write_fd.write(buf)
                    if len(buf) == 0:
                        break

    def readFile(self, jobStoreFileID, localFilePath):
        try:
            with self._downloadStream(jobStoreFileID, self.files) as read_fd:
                with open(localFilePath, 'w') as write_fd:
                    while True:
                        buf = read_fd.read(self._maxAzureBlockBytes)
                        write_fd.write(buf)
                        if not buf: break
        except AzureMissingResourceHttpError:
            raise NoSuchFileException(jobStoreFileID)

    def deleteFile(self, jobStoreFileID):
        try:
            self.files.delete_blob(blob_name=jobStoreFileID)
            self._dissociateFileFromJob(jobStoreFileID)
        except AzureMissingResourceHttpError:
            pass

    def fileExists(self, jobStoreFileID):
        # As Azure doesn't have a blob_exists method (at least in the
        # python API) we just try to download the metadata, and hope
        # the metadata is small so the call will be fast.
        try:
            self.files.get_blob_metadata(blob_name=jobStoreFileID)
            return True
        except AzureMissingResourceHttpError:
            return False

    @contextmanager
    def writeFileStream(self, jobStoreID=None):
        # TODO: this (and all stream methods) should probably use the
        # Append Blob type, but that is not currently supported by the
        # Azure Python API.
        jobStoreFileID = self._newFileID()
        with self._uploadStream(jobStoreFileID, self.files) as fd:
            yield fd, jobStoreFileID
        self._associateFileWithJob(jobStoreFileID, jobStoreID)

    @contextmanager
    def updateFileStream(self, jobStoreFileID):
        with self._uploadStream(jobStoreFileID, self.files, checkForModification=True) as fd:
            yield fd

    def getEmptyFileStoreID(self, jobStoreID=None):
        jobStoreFileID = self._newFileID()
        self.files.put_blob(blob_name=jobStoreFileID, blob='',
                            x_ms_blob_type='BlockBlob')
        self._associateFileWithJob(jobStoreFileID, jobStoreID)
        return jobStoreFileID

    @contextmanager
    def readFileStream(self, jobStoreFileID):
        if not self.fileExists(jobStoreFileID):
            raise NoSuchFileException(jobStoreFileID)
        with self._downloadStream(jobStoreFileID, self.files) as fd:
            yield fd

    @contextmanager
    def writeSharedFileStream(self, sharedFileName, isProtected=None):
        assert self._validateSharedFileName(sharedFileName)
        sharedFileID = self._newFileID(sharedFileName)
        with self._uploadStream(sharedFileID, self.files, encrypted=isProtected) as fd:
            yield fd

    @contextmanager
    def readSharedFileStream(self, sharedFileName):
        assert self._validateSharedFileName(sharedFileName)
        sharedFileID = self._newFileID(sharedFileName)
        if not self.fileExists(sharedFileID):
            raise NoSuchFileException(sharedFileID)
        with self._downloadStream(sharedFileID, self.files) as fd:
            yield fd

    def writeStatsAndLogging(self, statsAndLoggingString):
        # TODO: would be a great use case for the append blobs, once implemented in the Azure SDK
        jobStoreFileID = self._newFileID()
        encrypted = self.keyPath is not None
        if encrypted:
            statsAndLoggingString = encryption.encrypt(statsAndLoggingString, self.keyPath)
        self.statsFiles.put_block_blob_from_text(blob_name=jobStoreFileID,
                                                 text=statsAndLoggingString,
                                                 x_ms_meta_name_values=dict(
                                                     encrypted=str(encrypted)))
        self.statsFileIDs.insert_entity(entity={'RowKey': jobStoreFileID})

    def readStatsAndLogging(self, callback, readAll=False):
        suffix = '_old'
        numStatsFiles = 0
        for entity in self.statsFileIDs.query_entities():
            jobStoreFileID = entity.RowKey
            hasBeenRead = len(jobStoreFileID) > self.jobIDLength
            if not hasBeenRead:
                with self._downloadStream(jobStoreFileID, self.statsFiles) as fd:
                    callback(fd)
                # Mark this entity as read by appending the suffix
                self.statsFileIDs.insert_entity(entity={'RowKey': jobStoreFileID + suffix})
                self.statsFileIDs.delete_entity(row_key=jobStoreFileID)
                numStatsFiles += 1
            elif readAll:
                # Strip the suffix to get the original ID
                jobStoreFileID = jobStoreFileID[:-len(suffix)]
                with self._downloadStream(jobStoreFileID, self.statsFiles) as fd:
                    callback(fd)
                numStatsFiles += 1
        return numStatsFiles
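
    # The '_old' suffix makes the RowKey longer than jobIDLength, which is how
    # hasBeenRead is detected on later passes; readAll strips the suffix to
    # re-read entries that were already marked.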

    _azureTimeFormat = "%Y-%m-%dT%H:%M:%SZ"

    def getPublicUrl(self, jobStoreFileID):
        try:
            self.files.get_blob_properties(blob_name=jobStoreFileID)
        except AzureMissingResourceHttpError:
            raise NoSuchFileException(jobStoreFileID)
        # Compensate for a little bit of clock skew
        startTimeStr = (datetime.utcnow() - timedelta(minutes=5)).strftime(self._azureTimeFormat)
        endTime = datetime.utcnow() + self.publicUrlExpiration
        endTimeStr = endTime.strftime(self._azureTimeFormat)
        sap = SharedAccessPolicy(AccessPolicy(startTimeStr, endTimeStr,
                                              BlobSharedAccessPermissions.READ))
        sas_token = self.files.generate_shared_access_signature(blob_name=jobStoreFileID,
                                                                shared_access_policy=sap)
        return self.files.make_blob_url(blob_name=jobStoreFileID) + '?' + sas_token

    def getSharedPublicUrl(self, sharedFileName):
        jobStoreFileID = self._newFileID(sharedFileName)
        return self.getPublicUrl(jobStoreFileID)

    def _newJobID(self):
        # raw UUIDs don't work for Azure property names because the '-' character is disallowed.
        return str(uuid.uuid4()).replace('-', '_')

    # A dummy job ID under which all shared files are stored.
    sharedFileJobID = uuid.UUID('891f7db6-e4d9-4221-a58e-ab6cc4395f94')

    def _newFileID(self, sharedFileName=None):
        if sharedFileName is None:
            ret = str(uuid.uuid4())
        else:
            ret = str(uuid.uuid5(self.sharedFileJobID, str(sharedFileName)))
        return ret.replace('-', '_')
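
    # Note: uuid5 is deterministic, so a given shared file name always maps to
    # the same file ID under the sharedFileJobID namespace; this is how
    # readSharedFileStream and getSharedPublicUrl locate shared files by name.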

    def _associateFileWithJob(self, jobStoreFileID, jobStoreID=None):
        if jobStoreID is not None:
            self.jobFileIDs.insert_entity(entity={'PartitionKey': jobStoreID,
                                                  'RowKey': jobStoreFileID})

    def _dissociateFileFromJob(self, jobStoreFileID):
        entities = self.jobFileIDs.query_entities(filter="RowKey eq '%s'" % jobStoreFileID)
        if entities:
            assert len(entities) == 1
            jobStoreID = entities[0].PartitionKey
            self.jobFileIDs.delete_entity(partition_key=jobStoreID, row_key=jobStoreFileID)

    def _getOrCreateTable(self, tableName):
        # This will not fail if the table already exists.
        for attempt in retry_azure():
            with attempt:
                self.tableService.create_table(tableName)
        return AzureTable(self.tableService, tableName)

    def _getOrCreateBlobContainer(self, containerName):
        for attempt in retry_azure():
            with attempt:
                self.blobService.create_container(containerName)
        return AzureBlobContainer(self.blobService, containerName)

    def _sanitizeTableName(self, tableName):
        """
        Azure table names must start with a letter and be alphanumeric.

        This will never cause a collision if uuids are used, but
        otherwise may not be safe.
        """
        return 'a' + filter(lambda x: x.isalnum(), tableName)

    # Maximum bytes that can be in any block of an Azure block blob
    # https://github.com/Azure/azure-storage-python/blob/4c7666e05a9556c10154508335738ee44d7cb104/azure/storage/blob/blobservice.py#L106
    _maxAzureBlockBytes = 4 * 1024 * 1024
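
    # For example, an unencrypted 100 MiB file is uploaded as 25 put_block calls
    # of 4 MiB each and then committed with a single put_block_list
    # (see _uploadStream below).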

    @contextmanager
    def _uploadStream(self, jobStoreFileID, container, checkForModification=False, encrypted=None):
        """
        :param encrypted: True to enforce encryption (will raise exception unless key is set),
        False to prevent encryption or None to encrypt if key is set.
        """
        if checkForModification:
            try:
                expectedVersion = container.get_blob_properties(blob_name=jobStoreFileID)['etag']
            except AzureMissingResourceHttpError:
                expectedVersion = None

        if encrypted is None:
            encrypted = self.keyPath is not None
        elif encrypted:
            if self.keyPath is None:
                raise RuntimeError('Encryption requested but no key was provided')

        maxBlockSize = self._maxAzureBlockBytes
        if encrypted:
            # There is a small overhead for encrypted data.
            maxBlockSize -= encryption.overhead
        readable_fh, writable_fh = os.pipe()
        with os.fdopen(readable_fh, 'r') as readable:
            with os.fdopen(writable_fh, 'w') as writable:
                def reader():
                    blockIDs = []
                    try:
                        while True:
                            buf = readable.read(maxBlockSize)
                            if len(buf) == 0:
                                # We're safe to break here even if we never read anything, since
                                # putting an empty block list creates an empty blob.
                                break
                            if encrypted:
                                buf = encryption.encrypt(buf, self.keyPath)
                            blockID = self._newFileID()
                            container.put_block(blob_name=jobStoreFileID,
                                                block=buf,
                                                blockid=blockID)
                            blockIDs.append(blockID)
                    except:
                        # This is guaranteed to delete any uncommitted
                        # blocks.
                        container.delete_blob(blob_name=jobStoreFileID)
                        raise

                    if checkForModification and expectedVersion is not None:
                        # Acquire a (60-second) write lock,
                        leaseID = container.lease_blob(blob_name=jobStoreFileID,
                                                       x_ms_lease_action='acquire')['x-ms-lease-id']
                        # check for modification,
                        blobProperties = container.get_blob_properties(blob_name=jobStoreFileID)
                        if blobProperties['etag'] != expectedVersion:
                            container.lease_blob(blob_name=jobStoreFileID,
                                                 x_ms_lease_action='release',
                                                 x_ms_lease_id=leaseID)
                            raise ConcurrentFileModificationException(jobStoreFileID)
                        # commit the file,
                        container.put_block_list(blob_name=jobStoreFileID,
                                                 block_list=blockIDs,
                                                 x_ms_lease_id=leaseID,
                                                 x_ms_meta_name_values=dict(
                                                     encrypted=str(encrypted)))
                        # then release the lock.
                        container.lease_blob(blob_name=jobStoreFileID,
                                             x_ms_lease_action='release',
                                             x_ms_lease_id=leaseID)
                    else:
                        # No need to check for modification, just blindly write over whatever
                        # was there.
                        container.put_block_list(blob_name=jobStoreFileID,
                                                 block_list=blockIDs,
                                                 x_ms_meta_name_values=dict(
                                                     encrypted=str(encrypted)))

                thread = ExceptionalThread(target=reader)
                thread.start()
                yield writable
            # The writable is now closed. This will send EOF to the readable and cause that
            # thread to finish.
            thread.join()
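
    # Concurrency note: when checkForModification is set, the blob's etag is
    # captured before the upload and re-checked under a short lease before the
    # block list is committed; a mismatch raises
    # ConcurrentFileModificationException instead of overwriting the blob.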

    @contextmanager
    def _downloadStream(self, jobStoreFileID, container):
        # The reason this is not in the writer is so we catch non-existent blobs early

        blobProps = container.get_blob_properties(blob_name=jobStoreFileID)

        encrypted = strict_bool(blobProps['x-ms-meta-encrypted'])
        if encrypted and self.keyPath is None:
            raise AssertionError('Content is encrypted but no key was provided.')

        readable_fh, writable_fh = os.pipe()
        with os.fdopen(readable_fh, 'r') as readable:
            with os.fdopen(writable_fh, 'w') as writable:
                def writer():
                    try:
                        chunkStartPos = 0
                        fileSize = int(blobProps['Content-Length'])
                        while chunkStartPos < fileSize:
                            chunkEndPos = chunkStartPos + self._maxAzureBlockBytes - 1
                            buf = container.get_blob(blob_name=jobStoreFileID,
                                                     x_ms_range="bytes=%d-%d" % (chunkStartPos,
                                                                                 chunkEndPos))
                            if encrypted:
                                buf = encryption.decrypt(buf, self.keyPath)
                            writable.write(buf)
                            chunkStartPos = chunkEndPos + 1
                    finally:
                        # Ensure readers aren't left blocking if this thread crashes.
                        # This close() will send EOF to the reading end and ultimately cause the
                        # yield to return. It also makes the implicit .close() done by the enclosing
                        # "with" context redundant, but that should be ok since .close() on file
                        # objects is idempotent.
                        writable.close()

                thread = ExceptionalThread(target=writer)
                thread.start()
                yield readable
                thread.join()
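
A minimal usage sketch for the variant above. All names are hypothetical, and 'config' stands for a Toil configuration object supplied by the caller, which this variant appears to treat as a request to create the store:

# Hypothetical account 'myaccount' and name prefix 'my-workflow'; the locator
# format is '<accountName>:<namePrefix>'. 'config' is a placeholder for a Toil
# Config object provided by the caller.
store = AzureJobStore.loadOrCreateJobStore('myaccount:my-workflow', config=config)
with store.writeFileStream() as (writable, fileID):
    writable.write('some data')            # streamed into the 'files' container
with store.readFileStream(fileID) as readable:
    assert readable.read() == 'some data'
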
Example #2
class AzureJobStore(AbstractJobStore):
    """
    A job store that uses Azure's blob store for file storage and Table Service to store job info
    with strong consistency.
    """

    # Dots in container names should be avoided because container names are used in HTTPS bucket
    # URLs where they may interfere with the certificate common name. We use the alphanumeric
    # separator defined by nameSeparator below instead.
    #
    containerNameRe = re.compile(r'^[a-z0-9](-?[a-z0-9]+)+[a-z0-9]$')

    # See https://msdn.microsoft.com/en-us/library/azure/dd135715.aspx
    #
    minContainerNameLen = 3
    maxContainerNameLen = 63
    maxNameLen = 10
    nameSeparator = 'xx'  # Table names must be alphanumeric
    # Length of a jobID - used to test if a stats file has been read already or not
    jobIDLength = len(str(uuid.uuid4()))

    def __init__(self, locator, jobChunkSize=maxAzureTablePropertySize):
        super(AzureJobStore, self).__init__()
        accountName, namePrefix = locator.split(':', 1)
        if '--' in namePrefix:
            raise ValueError(
                "Invalid name prefix '%s'. Name prefixes may not contain %s." %
                (namePrefix, self.nameSeparator))
        if not self.containerNameRe.match(namePrefix):
            raise ValueError(
                "Invalid name prefix '%s'. Name prefixes must contain only digits, "
                "hyphens or lower-case letters and must not start or end in a "
                "hyphen." % namePrefix)
        # Reserve 13 for separator and suffix
        if len(namePrefix) > self.maxContainerNameLen - self.maxNameLen - len(
                self.nameSeparator):
            raise ValueError((
                "Invalid name prefix '%s'. Name prefixes may not be longer than 50 "
                "characters." % namePrefix))
        if '--' in namePrefix:
            raise ValueError(
                "Invalid name prefix '%s'. Name prefixes may not contain "
                "%s." % (namePrefix, self.nameSeparator))
        self.locator = locator
        self.jobChunkSize = jobChunkSize
        self.accountKey = _fetchAzureAccountKey(accountName)
        self.accountName = accountName
        # Table names have strict requirements in Azure
        self.namePrefix = self._sanitizeTableName(namePrefix)
        # These are the main API entry points.
        self.tableService = TableService(account_key=self.accountKey,
                                         account_name=accountName)
        self.blobService = BlobService(account_key=self.accountKey,
                                       account_name=accountName)
        # Serialized jobs table
        self.jobItems = None
        # Job<->file mapping table
        self.jobFileIDs = None
        # Container for all shared and unshared files
        self.files = None
        # Stats and logging strings
        self.statsFiles = None
        # File IDs that contain stats and logging strings
        self.statsFileIDs = None

    @property
    def keyPath(self):
        return self.config.cseKey

    def initialize(self, config):
        if self._jobStoreExists():
            raise JobStoreExistsException(self.locator)
        logger.debug("Creating job store at '%s'" % self.locator)
        self._bind(create=True)
        super(AzureJobStore, self).initialize(config)

    def resume(self):
        if not self._jobStoreExists():
            raise NoSuchJobStoreException(self.locator)
        logger.debug("Using existing job store at '%s'" % self.locator)
        self._bind(create=False)
        super(AzureJobStore, self).resume()

    def destroy(self):
        for name in 'jobItems', 'jobFileIDs', 'files', 'statsFiles', 'statsFileIDs':
            resource = getattr(self, name)
            if resource is not None:
                if isinstance(resource, AzureTable):
                    resource.delete_table()
                elif isinstance(resource, AzureBlobContainer):
                    resource.delete_container()
                else:
                    assert False
                setattr(self, name, None)

    def _jobStoreExists(self):
        """
        Checks if job store exists by querying the existence of the statsFileIDs table. Note that
        this is the last component that is deleted in :meth:`.destroy`.
        """
        for attempt in retry_azure():
            with attempt:
                try:
                    table = self.tableService.query_tables(
                        table_name=self._qualify('statsFileIDs'))
                except AzureMissingResourceHttpError as e:
                    if e.status_code == 404:
                        return False
                    else:
                        raise
                else:
                    return table is not None

    def _bind(self, create=False):
        table = self._bindTable
        container = self._bindContainer
        for name, binder in (('jobItems', table), ('jobFileIDs', table),
                             ('files', container), ('statsFiles', container),
                             ('statsFileIDs', table)):
            if getattr(self, name) is None:
                setattr(self, name, binder(self._qualify(name), create=create))

    def _qualify(self, name):
        return self.namePrefix + self.nameSeparator + name.lower()

    def jobs(self):

        # How many jobs have we done?
        total_processed = 0

        for jobEntity in self.jobItems.query_entities_auto():
            # Process the items in the page
            yield AzureJob.fromEntity(jobEntity)
            total_processed += 1

            if total_processed % 1000 == 0:
                # Produce some feedback for the user, because this can take
                # a long time on, for example, Azure
                logger.info("Processed %d total jobs" % total_processed)

        logger.info("Processed %d total jobs" % total_processed)

    def create(self,
               command,
               memory,
               cores,
               disk,
               preemptable,
               predecessorNumber=0):
        jobStoreID = self._newJobID()
        job = AzureJob(jobStoreID=jobStoreID,
                       command=command,
                       memory=memory,
                       cores=cores,
                       disk=disk,
                       preemptable=preemptable,
                       remainingRetryCount=self._defaultTryCount(),
                       logJobStoreFileID=None,
                       predecessorNumber=predecessorNumber)
        entity = job.toItem(chunkSize=self.jobChunkSize)
        entity['RowKey'] = jobStoreID
        self.jobItems.insert_entity(entity=entity)
        return job

    def exists(self, jobStoreID):
        if self.jobItems.get_entity(row_key=jobStoreID) is None:
            return False
        return True

    def load(self, jobStoreID):
        jobEntity = self.jobItems.get_entity(row_key=jobStoreID)
        if jobEntity is None:
            raise NoSuchJobException(jobStoreID)
        return AzureJob.fromEntity(jobEntity)

    def update(self, job):
        self.jobItems.update_entity(
            row_key=job.jobStoreID,
            entity=job.toItem(chunkSize=self.jobChunkSize))

    def delete(self, jobStoreID):
        try:
            self.jobItems.delete_entity(row_key=jobStoreID)
        except AzureMissingResourceHttpError:
            # Job deletion is idempotent, and this job has been deleted already
            return
        filterString = "PartitionKey eq '%s'" % jobStoreID
        for fileEntity in self.jobFileIDs.query_entities(filter=filterString):
            jobStoreFileID = fileEntity.RowKey
            self.deleteFile(jobStoreFileID)

    def getEnv(self):
        return dict(AZURE_ACCOUNT_KEY=self.accountKey)

    class BlobInfo(namedtuple('BlobInfo', ('account', 'container', 'name'))):
        @property
        @memoize
        def service(self):
            return BlobService(account_name=self.account,
                               account_key=_fetchAzureAccountKey(self.account))

    @classmethod
    def _readFromUrl(cls, url, writable):
        blob = cls._parseWasbUrl(url)
        blob.service.get_blob_to_file(container_name=blob.container,
                                      blob_name=blob.name,
                                      stream=writable)

    @classmethod
    def _writeToUrl(cls, readable, url):
        blob = cls._parseWasbUrl(url)
        blob.service.put_block_blob_from_file(container_name=blob.container,
                                              blob_name=blob.name,
                                              stream=readable)

    @classmethod
    def _parseWasbUrl(cls, url):
        """
        :param urlparse.ParseResult url: x
        :rtype: AzureJobStore.BlobInfo
        """
        assert url.scheme in ('wasb', 'wasbs')
        try:
            container, account = url.netloc.split('@')
        except ValueError:
            raise InvalidImportExportUrlException(url)
        suffix = '.blob.core.windows.net'
        if account.endswith(suffix):
            account = account[:-len(suffix)]
        else:
            raise InvalidImportExportUrlException(url)
        assert url.path[0] == '/'
        return cls.BlobInfo(account=account,
                            container=container,
                            name=url.path[1:])

    @classmethod
    def _supportsUrl(cls, url, export=False):
        return url.scheme.lower() in ('wasb', 'wasbs')

    def writeFile(self, localFilePath, jobStoreID=None):
        jobStoreFileID = self._newFileID()
        self.updateFile(jobStoreFileID, localFilePath)
        self._associateFileWithJob(jobStoreFileID, jobStoreID)
        return jobStoreFileID

    def updateFile(self, jobStoreFileID, localFilePath):
        with open(localFilePath) as read_fd:
            with self._uploadStream(jobStoreFileID, self.files) as write_fd:
                while True:
                    buf = read_fd.read(self._maxAzureBlockBytes)
                    write_fd.write(buf)
                    if len(buf) == 0:
                        break

    def readFile(self, jobStoreFileID, localFilePath):
        try:
            with self._downloadStream(jobStoreFileID, self.files) as read_fd:
                with open(localFilePath, 'w') as write_fd:
                    while True:
                        buf = read_fd.read(self._maxAzureBlockBytes)
                        write_fd.write(buf)
                        if not buf:
                            break
        except AzureMissingResourceHttpError:
            raise NoSuchFileException(jobStoreFileID)

    def deleteFile(self, jobStoreFileID):
        try:
            self.files.delete_blob(blob_name=jobStoreFileID)
            self._dissociateFileFromJob(jobStoreFileID)
        except AzureMissingResourceHttpError:
            pass

    def fileExists(self, jobStoreFileID):
        # As Azure doesn't have a blob_exists method (at least in the
        # python API) we just try to download the metadata, and hope
        # the metadata is small so the call will be fast.
        try:
            self.files.get_blob_metadata(blob_name=jobStoreFileID)
            return True
        except AzureMissingResourceHttpError:
            return False

    @contextmanager
    def writeFileStream(self, jobStoreID=None):
        # TODO: this (and all stream methods) should probably use the
        # Append Blob type, but that is not currently supported by the
        # Azure Python API.
        jobStoreFileID = self._newFileID()
        with self._uploadStream(jobStoreFileID, self.files) as fd:
            yield fd, jobStoreFileID
        self._associateFileWithJob(jobStoreFileID, jobStoreID)

    @contextmanager
    def updateFileStream(self, jobStoreFileID):
        with self._uploadStream(jobStoreFileID,
                                self.files,
                                checkForModification=True) as fd:
            yield fd

    def getEmptyFileStoreID(self, jobStoreID=None):
        jobStoreFileID = self._newFileID()
        self.files.put_blob(blob_name=jobStoreFileID,
                            blob='',
                            x_ms_blob_type='BlockBlob')
        self._associateFileWithJob(jobStoreFileID, jobStoreID)
        return jobStoreFileID

    @contextmanager
    def readFileStream(self, jobStoreFileID):
        if not self.fileExists(jobStoreFileID):
            raise NoSuchFileException(jobStoreFileID)
        with self._downloadStream(jobStoreFileID, self.files) as fd:
            yield fd

    @contextmanager
    def writeSharedFileStream(self, sharedFileName, isProtected=None):
        assert self._validateSharedFileName(sharedFileName)
        sharedFileID = self._newFileID(sharedFileName)
        with self._uploadStream(sharedFileID,
                                self.files,
                                encrypted=isProtected) as fd:
            yield fd

    @contextmanager
    def readSharedFileStream(self, sharedFileName):
        assert self._validateSharedFileName(sharedFileName)
        sharedFileID = self._newFileID(sharedFileName)
        if not self.fileExists(sharedFileID):
            raise NoSuchFileException(sharedFileID)
        with self._downloadStream(sharedFileID, self.files) as fd:
            yield fd

    def writeStatsAndLogging(self, statsAndLoggingString):
        # TODO: would be a great use case for the append blobs, once implemented in the Azure SDK
        jobStoreFileID = self._newFileID()
        encrypted = self.keyPath is not None
        if encrypted:
            statsAndLoggingString = encryption.encrypt(statsAndLoggingString,
                                                       self.keyPath)
        self.statsFiles.put_block_blob_from_text(
            blob_name=jobStoreFileID,
            text=statsAndLoggingString,
            x_ms_meta_name_values=dict(encrypted=str(encrypted)))
        self.statsFileIDs.insert_entity(entity={'RowKey': jobStoreFileID})

    def readStatsAndLogging(self, callback, readAll=False):
        suffix = '_old'
        numStatsFiles = 0
        for entity in self.statsFileIDs.query_entities():
            jobStoreFileID = entity.RowKey
            hasBeenRead = len(jobStoreFileID) > self.jobIDLength
            if not hasBeenRead:
                with self._downloadStream(jobStoreFileID,
                                          self.statsFiles) as fd:
                    callback(fd)
                # Mark this entity as read by appending the suffix
                self.statsFileIDs.insert_entity(
                    entity={'RowKey': jobStoreFileID + suffix})
                self.statsFileIDs.delete_entity(row_key=jobStoreFileID)
                numStatsFiles += 1
            elif readAll:
                # Strip the suffix to get the original ID
                jobStoreFileID = jobStoreFileID[:-len(suffix)]
                with self._downloadStream(jobStoreFileID,
                                          self.statsFiles) as fd:
                    callback(fd)
                numStatsFiles += 1
        return numStatsFiles

    _azureTimeFormat = "%Y-%m-%dT%H:%M:%SZ"

    def getPublicUrl(self, jobStoreFileID):
        try:
            self.files.get_blob_properties(blob_name=jobStoreFileID)
        except AzureMissingResourceHttpError:
            raise NoSuchFileException(jobStoreFileID)
        # Compensate for a little bit of clock skew
        startTimeStr = (datetime.utcnow() - timedelta(minutes=5)).strftime(
            self._azureTimeFormat)
        endTime = datetime.utcnow() + self.publicUrlExpiration
        endTimeStr = endTime.strftime(self._azureTimeFormat)
        sap = SharedAccessPolicy(
            AccessPolicy(startTimeStr, endTimeStr,
                         BlobSharedAccessPermissions.READ))
        sas_token = self.files.generate_shared_access_signature(
            blob_name=jobStoreFileID, shared_access_policy=sap)
        return self.files.make_blob_url(
            blob_name=jobStoreFileID) + '?' + sas_token

    def getSharedPublicUrl(self, sharedFileName):
        jobStoreFileID = self._newFileID(sharedFileName)
        return self.getPublicUrl(jobStoreFileID)

    def _newJobID(self):
        # raw UUIDs don't work for Azure property names because the '-' character is disallowed.
        return str(uuid.uuid4()).replace('-', '_')

    # A dummy job ID under which all shared files are stored.
    sharedFileJobID = uuid.UUID('891f7db6-e4d9-4221-a58e-ab6cc4395f94')

    def _newFileID(self, sharedFileName=None):
        if sharedFileName is None:
            ret = str(uuid.uuid4())
        else:
            ret = str(uuid.uuid5(self.sharedFileJobID, str(sharedFileName)))
        return ret.replace('-', '_')

    def _associateFileWithJob(self, jobStoreFileID, jobStoreID=None):
        if jobStoreID is not None:
            self.jobFileIDs.insert_entity(entity={
                'PartitionKey': jobStoreID,
                'RowKey': jobStoreFileID
            })

    def _dissociateFileFromJob(self, jobStoreFileID):
        entities = self.jobFileIDs.query_entities(filter="RowKey eq '%s'" %
                                                  jobStoreFileID)
        if entities:
            assert len(entities) == 1
            jobStoreID = entities[0].PartitionKey
            self.jobFileIDs.delete_entity(partition_key=jobStoreID,
                                          row_key=jobStoreFileID)

    def _bindTable(self, tableName, create=False):
        for attempt in retry_azure():
            with attempt:
                try:
                    tables = self.tableService.query_tables(
                        table_name=tableName)
                except AzureMissingResourceHttpError as e:
                    if e.status_code != 404:
                        raise
                else:
                    if tables:
                        assert tables[0].name == tableName
                        return AzureTable(self.tableService, tableName)
                if create:
                    self.tableService.create_table(tableName)
                    return AzureTable(self.tableService, tableName)
                else:
                    return None

    def _bindContainer(self, containerName, create=False):
        for attempt in retry_azure():
            with attempt:
                try:
                    self.blobService.get_container_properties(containerName)
                except AzureMissingResourceHttpError as e:
                    if e.status_code == 404:
                        if create:
                            self.blobService.create_container(containerName)
                        else:
                            return None
                    else:
                        raise
        return AzureBlobContainer(self.blobService, containerName)

    def _sanitizeTableName(self, tableName):
        """
        Azure table names must start with a letter and be alphanumeric.

        This will never cause a collision if uuids are used, but
        otherwise may not be safe.
        """
        return 'a' + filter(lambda x: x.isalnum(), tableName)

    # Maximum bytes that can be in any block of an Azure block blob
    # https://github.com/Azure/azure-storage-python/blob/4c7666e05a9556c10154508335738ee44d7cb104/azure/storage/blob/blobservice.py#L106
    _maxAzureBlockBytes = 4 * 1024 * 1024

    @contextmanager
    def _uploadStream(self,
                      jobStoreFileID,
                      container,
                      checkForModification=False,
                      encrypted=None):
        """
        :param encrypted: True to enforce encryption (will raise exception unless key is set),
        False to prevent encryption or None to encrypt if key is set.
        """
        if checkForModification:
            try:
                expectedVersion = container.get_blob_properties(
                    blob_name=jobStoreFileID)['etag']
            except AzureMissingResourceHttpError:
                expectedVersion = None

        if encrypted is None:
            encrypted = self.keyPath is not None
        elif encrypted:
            if self.keyPath is None:
                raise RuntimeError(
                    'Encryption requested but no key was provided')

        maxBlockSize = self._maxAzureBlockBytes
        if encrypted:
            # There is a small overhead for encrypted data.
            maxBlockSize -= encryption.overhead

        store = self

        class UploadPipe(WritablePipe):
            def readFrom(self, readable):
                blockIDs = []
                try:
                    while True:
                        buf = readable.read(maxBlockSize)
                        if len(buf) == 0:
                            # We're safe to break here even if we never read anything, since
                            # putting an empty block list creates an empty blob.
                            break
                        if encrypted:
                            buf = encryption.encrypt(buf, store.keyPath)
                        blockID = store._newFileID()
                        container.put_block(blob_name=jobStoreFileID,
                                            block=buf,
                                            blockid=blockID)
                        blockIDs.append(blockID)
                except:
                    with panic(log=logger):
                        # This is guaranteed to delete any uncommitted blocks.
                        container.delete_blob(blob_name=jobStoreFileID)

                if checkForModification and expectedVersion is not None:
                    # Acquire a (60-second) write lock,
                    leaseID = container.lease_blob(
                        blob_name=jobStoreFileID,
                        x_ms_lease_action='acquire')['x-ms-lease-id']
                    # check for modification,
                    blobProperties = container.get_blob_properties(
                        blob_name=jobStoreFileID)
                    if blobProperties['etag'] != expectedVersion:
                        container.lease_blob(blob_name=jobStoreFileID,
                                             x_ms_lease_action='release',
                                             x_ms_lease_id=leaseID)
                        raise ConcurrentFileModificationException(
                            jobStoreFileID)
                    # commit the file,
                    container.put_block_list(
                        blob_name=jobStoreFileID,
                        block_list=blockIDs,
                        x_ms_lease_id=leaseID,
                        x_ms_meta_name_values=dict(encrypted=str(encrypted)))
                    # then release the lock.
                    container.lease_blob(blob_name=jobStoreFileID,
                                         x_ms_lease_action='release',
                                         x_ms_lease_id=leaseID)
                else:
                    # No need to check for modification, just blindly write over whatever
                    # was there.
                    container.put_block_list(
                        blob_name=jobStoreFileID,
                        block_list=blockIDs,
                        x_ms_meta_name_values=dict(encrypted=str(encrypted)))

        with UploadPipe() as writable:
            yield writable

    @contextmanager
    def _downloadStream(self, jobStoreFileID, container):
        # The reason this is not in the writer is so we catch non-existent blobs early

        blobProps = container.get_blob_properties(blob_name=jobStoreFileID)

        encrypted = strict_bool(blobProps['x-ms-meta-encrypted'])
        if encrypted and self.keyPath is None:
            raise AssertionError(
                'Content is encrypted but no key was provided.')

        outer_self = self

        class DownloadPipe(ReadablePipe):
            def writeTo(self, writable):
                chunkStart = 0
                fileSize = int(blobProps['Content-Length'])
                while chunkStart < fileSize:
                    chunkEnd = chunkStart + outer_self._maxAzureBlockBytes - 1
                    buf = container.get_blob(blob_name=jobStoreFileID,
                                             x_ms_range="bytes=%d-%d" %
                                             (chunkStart, chunkEnd))
                    if encrypted:
                        buf = encryption.decrypt(buf, outer_self.keyPath)
                    writable.write(buf)
                    chunkStart = chunkEnd + 1

        with DownloadPipe() as readable:
            yield readable
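
The upload and download helpers in this variant delegate the pipe-and-thread plumbing to WritablePipe and ReadablePipe. A simplified sketch of the writable side, inferred only from how it is used above (Toil's real implementation adds more error handling around the background thread); ReadablePipe is the mirror image, running writeTo() in the background and yielding the readable end:

# Sketch under assumptions: a minimal WritablePipe whose context manager yields
# the writable end of a pipe while readFrom() consumes the readable end in a
# background thread.
import os
import threading

class WritablePipe(object):

    def readFrom(self, readable):
        # Subclasses override this to consume everything written by the caller.
        raise NotImplementedError

    def __enter__(self):
        read_fd, write_fd = os.pipe()
        self._readable = os.fdopen(read_fd, 'r')
        self._writable = os.fdopen(write_fd, 'w')
        # readFrom() runs in the background and sees whatever the caller writes.
        self._thread = threading.Thread(target=self.readFrom, args=(self._readable,))
        self._thread.start()
        return self._writable

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._writable.close()   # sends EOF to readFrom()
        self._thread.join()
        self._readable.close()
        return False
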
Example #3
class AzureJobStore(AbstractJobStore):
    """
    A job store that uses Azure's blob store for file storage and Table Service to store job info
    with strong consistency.
    """

    # Dots in container names should be avoided because container names are used in HTTPS bucket
    # URLs where they may interfere with the certificate common name. We use the alphanumeric
    # separator defined by nameSeparator below instead.
    #
    containerNameRe = re.compile(r'^[a-z0-9](-?[a-z0-9]+)+[a-z0-9]$')

    # See https://msdn.microsoft.com/en-us/library/azure/dd135715.aspx
    #
    minContainerNameLen = 3
    maxContainerNameLen = 63
    maxNameLen = 10
    nameSeparator = 'xx'  # Table names must be alphanumeric
    # Length of a jobID - used to test if a stats file has been read already or not
    jobIDLength = len(str(uuid.uuid4()))

    def __init__(self, locator, jobChunkSize=maxAzureTablePropertySize):
        super(AzureJobStore, self).__init__()
        accountName, namePrefix = locator.split(':', 1)
        if '--' in namePrefix:
            raise ValueError("Invalid name prefix '%s'. Name prefixes may not contain %s."
                             % (namePrefix, self.nameSeparator))
        if not self.containerNameRe.match(namePrefix):
            raise ValueError("Invalid name prefix '%s'. Name prefixes must contain only digits, "
                             "hyphens or lower-case letters and must not start or end in a "
                             "hyphen." % namePrefix)
        # Reserve 13 for separator and suffix
        if len(namePrefix) > self.maxContainerNameLen - self.maxNameLen - len(self.nameSeparator):
            raise ValueError(("Invalid name prefix '%s'. Name prefixes may not be longer than 50 "
                              "characters." % namePrefix))
        if '--' in namePrefix:
            raise ValueError("Invalid name prefix '%s'. Name prefixes may not contain "
                             "%s." % (namePrefix, self.nameSeparator))
        self.locator = locator
        self.jobChunkSize = jobChunkSize
        self.accountKey = _fetchAzureAccountKey(accountName)
        self.accountName = accountName
        # Table names have strict requirements in Azure
        self.namePrefix = self._sanitizeTableName(namePrefix)
        # These are the main API entry points.
        self.tableService = TableService(account_key=self.accountKey, account_name=accountName)
        self.blobService = BlobService(account_key=self.accountKey, account_name=accountName)
        # Serialized jobs table
        self.jobItems = None
        # Job<->file mapping table
        self.jobFileIDs = None
        # Container for all shared and unshared files
        self.files = None
        # Stats and logging strings
        self.statsFiles = None
        # File IDs that contain stats and logging strings
        self.statsFileIDs = None

    @property
    def keyPath(self):
        return self.config.cseKey

    def initialize(self, config):
        if self._jobStoreExists():
            raise JobStoreExistsException(self.locator)
        logger.debug("Creating job store at '%s'" % self.locator)
        self._bind(create=True)
        super(AzureJobStore, self).initialize(config)

    def resume(self):
        if not self._jobStoreExists():
            raise NoSuchJobStoreException(self.locator)
        logger.debug("Using existing job store at '%s'" % self.locator)
        self._bind(create=False)
        super(AzureJobStore, self).resume()

    def destroy(self):
        self._bind()
        for name in 'jobItems', 'jobFileIDs', 'files', 'statsFiles', 'statsFileIDs':
            resource = getattr(self, name)
            if resource is not None:
                if isinstance(resource, AzureTable):
                    resource.delete_table()
                elif isinstance(resource, AzureBlobContainer):
                    resource.delete_container()
                else:
                    assert False
                setattr(self, name, None)

    def _jobStoreExists(self):
        """
        Checks if job store exists by querying the existence of the statsFileIDs table. Note that
        this is the last component that is deleted in :meth:`.destroy`.
        """
        for attempt in retry_azure():
            with attempt:
                try:
                    table = self.tableService.query_tables(table_name=self._qualify('statsFileIDs'))
                except AzureMissingResourceHttpError as e:
                    if e.status_code == 404:
                        return False
                    else:
                        raise
                else:
                    return table is not None

    def _bind(self, create=False):
        table = self._bindTable
        container = self._bindContainer
        for name, binder in (('jobItems', table),
                             ('jobFileIDs', table),
                             ('files', container),
                             ('statsFiles', container),
                             ('statsFileIDs', table)):
            if getattr(self, name) is None:
                setattr(self, name, binder(self._qualify(name), create=create))

    def _qualify(self, name):
        return self.namePrefix + self.nameSeparator + name.lower()

    def jobs(self):

        # How many jobs have we done?
        total_processed = 0

        for jobEntity in self.jobItems.query_entities_auto():
            # Process the items in the page
            yield AzureJob.fromEntity(jobEntity)
            total_processed += 1

            if total_processed % 1000 == 0:
                # Produce some feedback for the user, because this can take
                # a long time on, for example, Azure
                logger.debug("Processed %d total jobs" % total_processed)

        logger.debug("Processed %d total jobs" % total_processed)

    def create(self, jobNode):
        jobStoreID = self._newJobID()
        job = AzureJob.fromJobNode(jobNode, jobStoreID, self._defaultTryCount())
        entity = job.toItem(chunkSize=self.jobChunkSize)
        entity['RowKey'] = EntityProperty('Edm.String', jobStoreID)
        self.jobItems.insert_entity(entity=entity)
        return job

    def exists(self, jobStoreID):
        if self.jobItems.get_entity(row_key=bytes(jobStoreID)) is None:
            return False
        return True

    def load(self, jobStoreID):
        jobEntity = self.jobItems.get_entity(row_key=bytes(jobStoreID))
        if jobEntity is None:
            raise NoSuchJobException(jobStoreID)
        return AzureJob.fromEntity(jobEntity)

    def update(self, job):
        self.jobItems.update_entity(row_key=bytes(job.jobStoreID),
                                    entity=job.toItem(chunkSize=self.jobChunkSize))

    def delete(self, jobStoreID):
        try:
            self.jobItems.delete_entity(row_key=bytes(jobStoreID))
        except AzureMissingResourceHttpError:
            # Job deletion is idempotent, and this job has been deleted already
            return
        filterString = "PartitionKey eq '%s'" % jobStoreID
        for fileEntity in self.jobFileIDs.query_entities(filter=filterString):
            jobStoreFileID = fileEntity.RowKey
            self.deleteFile(jobStoreFileID)

    def getEnv(self):
        return dict(AZURE_ACCOUNT_KEY=self.accountKey)

    class BlobInfo(namedtuple('BlobInfo', ('account', 'container', 'name'))):
        @property
        @memoize
        def service(self):
            return BlobService(account_name=self.account,
                               account_key=_fetchAzureAccountKey(self.account))

    @classmethod
    def getSize(cls, url):
        blob = cls._parseWasbUrl(url)
        blobProps = blob.service.get_blob_properties(blob.container, blob.name)
        return int(blobProps['content-length'])

    @classmethod
    def _readFromUrl(cls, url, writable):
        blob = cls._parseWasbUrl(url)
        for attempt in retry_azure():
            with attempt:
                blob.service.get_blob_to_file(container_name=blob.container,
                                              blob_name=blob.name,
                                              stream=writable)

    @classmethod
    def _writeToUrl(cls, readable, url):
        blob = cls._parseWasbUrl(url)
        blob.service.put_block_blob_from_file(container_name=blob.container,
                                              blob_name=blob.name,
                                              stream=readable)

    @classmethod
    def _parseWasbUrl(cls, url):
        """
        :param urlparse.ParseResult url: x
        :rtype: AzureJobStore.BlobInfo
        """
        assert url.scheme in ('wasb', 'wasbs')
        try:
            container, account = url.netloc.split('@')
        except ValueError:
            raise InvalidImportExportUrlException(url)
        suffix = '.blob.core.windows.net'
        if account.endswith(suffix):
            account = account[:-len(suffix)]
        else:
            raise InvalidImportExportUrlException(url)
        assert url.path[0] == '/'
        return cls.BlobInfo(account=account, container=container, name=url.path[1:])

    @classmethod
    def _supportsUrl(cls, url, export=False):
        return url.scheme.lower() in ('wasb', 'wasbs')

    def writeFile(self, localFilePath, jobStoreID=None):
        jobStoreFileID = self._newFileID()
        self.updateFile(jobStoreFileID, localFilePath)
        self._associateFileWithJob(jobStoreFileID, jobStoreID)
        return jobStoreFileID

    def updateFile(self, jobStoreFileID, localFilePath):
        with open(localFilePath) as read_fd:
            with self._uploadStream(jobStoreFileID, self.files) as write_fd:
                while True:
                    buf = read_fd.read(self._maxAzureBlockBytes)
                    write_fd.write(buf)
                    if len(buf) == 0:
                        break

    def readFile(self, jobStoreFileID, localFilePath):
        try:
            with self._downloadStream(jobStoreFileID, self.files) as read_fd:
                with open(localFilePath, 'w') as write_fd:
                    while True:
                        buf = read_fd.read(self._maxAzureBlockBytes)
                        write_fd.write(buf)
                        if not buf:
                            break
        except AzureMissingResourceHttpError:
            raise NoSuchFileException(jobStoreFileID)

    def deleteFile(self, jobStoreFileID):
        try:
            self.files.delete_blob(blob_name=bytes(jobStoreFileID))
            self._dissociateFileFromJob(jobStoreFileID)
        except AzureMissingResourceHttpError:
            pass

    def fileExists(self, jobStoreFileID):
        # As Azure doesn't have a blob_exists method (at least in the Python
        # API), we just fetch the blob metadata instead; the metadata is small,
        # so the call should be fast.
        try:
            self.files.get_blob_metadata(blob_name=bytes(jobStoreFileID))
            return True
        except AzureMissingResourceHttpError:
            return False

    @contextmanager
    def writeFileStream(self, jobStoreID=None):
        # TODO: this (and all stream methods) should probably use the
        # Append Blob type, but that is not currently supported by the
        # Azure Python API.
        jobStoreFileID = self._newFileID()
        with self._uploadStream(jobStoreFileID, self.files) as fd:
            yield fd, jobStoreFileID
        self._associateFileWithJob(jobStoreFileID, jobStoreID)

    @contextmanager
    def updateFileStream(self, jobStoreFileID):
        with self._uploadStream(jobStoreFileID, self.files, checkForModification=True) as fd:
            yield fd

    def getEmptyFileStoreID(self, jobStoreID=None):
        jobStoreFileID = self._newFileID()
        with self._uploadStream(jobStoreFileID, self.files) as _:
            pass
        self._associateFileWithJob(jobStoreFileID, jobStoreID)
        return jobStoreFileID

    @contextmanager
    def readFileStream(self, jobStoreFileID):
        if not self.fileExists(jobStoreFileID):
            raise NoSuchFileException(jobStoreFileID)
        with self._downloadStream(jobStoreFileID, self.files) as fd:
            yield fd

    @contextmanager
    def writeSharedFileStream(self, sharedFileName, isProtected=None):
        assert self._validateSharedFileName(sharedFileName)
        sharedFileID = self._newFileID(sharedFileName)
        with self._uploadStream(sharedFileID, self.files, encrypted=isProtected) as fd:
            yield fd

    @contextmanager
    def readSharedFileStream(self, sharedFileName):
        assert self._validateSharedFileName(sharedFileName)
        sharedFileID = self._newFileID(sharedFileName)
        if not self.fileExists(sharedFileID):
            raise NoSuchFileException(sharedFileID)
        with self._downloadStream(sharedFileID, self.files) as fd:
            yield fd

    def writeStatsAndLogging(self, statsAndLoggingString):
        # TODO: would be a great use case for the append blobs, once implemented in the Azure SDK
        jobStoreFileID = self._newFileID()
        encrypted = self.keyPath is not None
        if encrypted:
            statsAndLoggingString = encryption.encrypt(statsAndLoggingString, self.keyPath)
        self.statsFiles.put_block_blob_from_text(blob_name=bytes(jobStoreFileID),
                                                 text=statsAndLoggingString,
                                                 x_ms_meta_name_values=dict(
                                                     encrypted=str(encrypted)))
        self.statsFileIDs.insert_entity(entity={'RowKey': jobStoreFileID})

    def readStatsAndLogging(self, callback, readAll=False):
        suffix = '_old'
        numStatsFiles = 0
        for attempt in retry_azure():
            with attempt:
                for entity in self.statsFileIDs.query_entities():
                    jobStoreFileID = entity.RowKey
                    hasBeenRead = len(jobStoreFileID) > self.jobIDLength
                    if not hasBeenRead:
                        with self._downloadStream(jobStoreFileID, self.statsFiles) as fd:
                            callback(fd)
                        # Mark this entity as read by appending the suffix
                        self.statsFileIDs.insert_entity(entity={'RowKey': jobStoreFileID + suffix})
                        self.statsFileIDs.delete_entity(row_key=bytes(jobStoreFileID))
                        numStatsFiles += 1
                    elif readAll:
                        # Strip the suffix to get the original ID
                        jobStoreFileID = jobStoreFileID[:-len(suffix)]
                        with self._downloadStream(jobStoreFileID, self.statsFiles) as fd:
                            callback(fd)
                        numStatsFiles += 1
        return numStatsFiles

    _azureTimeFormat = "%Y-%m-%dT%H:%M:%SZ"

    def getPublicUrl(self, jobStoreFileID):
        try:
            self.files.get_blob_properties(blob_name=bytes(jobStoreFileID))
        except AzureMissingResourceHttpError:
            raise NoSuchFileException(jobStoreFileID)
        # Compensate for a little bit of clock skew
        startTimeStr = (datetime.utcnow() - timedelta(minutes=5)).strftime(self._azureTimeFormat)
        endTime = datetime.utcnow() + self.publicUrlExpiration
        endTimeStr = endTime.strftime(self._azureTimeFormat)
        sap = SharedAccessPolicy(AccessPolicy(startTimeStr, endTimeStr,
                                              BlobSharedAccessPermissions.READ))
        sas_token = self.files.generate_shared_access_signature(blob_name=bytes(jobStoreFileID),
                                                                shared_access_policy=sap)
        return self.files.make_blob_url(blob_name=bytes(jobStoreFileID)) + '?' + sas_token

    def getSharedPublicUrl(self, sharedFileName):
        jobStoreFileID = self._newFileID(sharedFileName)
        return self.getPublicUrl(jobStoreFileID)

    def _newJobID(self):
        # raw UUIDs don't work for Azure property names because the '-' character is disallowed.
        return str(uuid.uuid4()).replace('-', '_')

    # A dummy job ID under which all shared files are stored.
    sharedFileJobID = uuid.UUID('891f7db6-e4d9-4221-a58e-ab6cc4395f94')

    def _newFileID(self, sharedFileName=None):
        if sharedFileName is None:
            ret = bytes(uuid.uuid4())
        else:
            ret = bytes(uuid.uuid5(self.sharedFileJobID, bytes(sharedFileName)))
        return ret.replace('-', '_')

    def _associateFileWithJob(self, jobStoreFileID, jobStoreID=None):
        if jobStoreID is not None:
            self.jobFileIDs.insert_entity(entity={'PartitionKey': EntityProperty('Edm.String', jobStoreID),
                                                  'RowKey': EntityProperty('Edm.String', jobStoreFileID)})

    def _dissociateFileFromJob(self, jobStoreFileID):
        entities = self.jobFileIDs.query_entities(filter="RowKey eq '%s'" % jobStoreFileID)
        if entities:
            assert len(entities) == 1
            jobStoreID = entities[0].PartitionKey
            self.jobFileIDs.delete_entity(partition_key=bytes(jobStoreID), row_key=bytes(jobStoreFileID))

    def _bindTable(self, tableName, create=False):
        for attempt in retry_azure():
            with attempt:
                try:
                    tables = self.tableService.query_tables(table_name=tableName)
                except AzureMissingResourceHttpError as e:
                    if e.status_code != 404:
                        raise
                else:
                    if tables:
                        assert tables[0].name == tableName
                        return AzureTable(self.tableService, tableName)
                if create:
                    self.tableService.create_table(tableName)
                    return AzureTable(self.tableService, tableName)
                else:
                    return None

    def _bindContainer(self, containerName, create=False):
        for attempt in retry_azure():
            with attempt:
                try:
                    self.blobService.get_container_properties(containerName)
                except AzureMissingResourceHttpError as e:
                    if e.status_code == 404:
                        if create:
                            self.blobService.create_container(containerName)
                        else:
                            return None
                    else:
                        raise
        return AzureBlobContainer(self.blobService, containerName)

    def _sanitizeTableName(self, tableName):
        """
        Azure table names must start with a letter and be alphanumeric.

        This will never cause a collision if uuids are used, but
        otherwise may not be safe.
        """
        return 'a' + ''.join([x for x in tableName if x.isalnum()])

    # Maximum bytes that can be in any block of an Azure block blob
    # https://github.com/Azure/azure-storage-python/blob/4c7666e05a9556c10154508335738ee44d7cb104/azure/storage/blob/blobservice.py#L106
    _maxAzureBlockBytes = 4 * 1024 * 1024

    @contextmanager
    def _uploadStream(self, jobStoreFileID, container, checkForModification=False, encrypted=None):
        """
        :param encrypted: True to enforce encryption (will raise exception unless key is set),
        False to prevent encryption or None to encrypt if key is set.
        """
        if checkForModification:
            try:
                expectedVersion = container.get_blob_properties(blob_name=bytes(jobStoreFileID))['etag']
            except AzureMissingResourceHttpError:
                expectedVersion = None

        if encrypted is None:
            encrypted = self.keyPath is not None
        elif encrypted:
            if self.keyPath is None:
                raise RuntimeError('Encryption requested but no key was provided')

        maxBlockSize = self._maxAzureBlockBytes
        if encrypted:
            # There is a small overhead for encrypted data.
            maxBlockSize -= encryption.overhead

        store = self

        class UploadPipe(WritablePipe):

            def readFrom(self, readable):
                blockIDs = []
                try:
                    while True:
                        buf = readable.read(maxBlockSize)
                        if len(buf) == 0:
                            # We're safe to break here even if we never read anything, since
                            # putting an empty block list creates an empty blob.
                            break
                        if encrypted:
                            buf = encryption.encrypt(buf, store.keyPath)
                        blockID = store._newFileID()
                        container.put_block(blob_name=bytes(jobStoreFileID),
                                            block=buf,
                                            blockid=blockID)
                        blockIDs.append(blockID)
                except:
                    with panic(log=logger):
                        # This is guaranteed to delete any uncommitted blocks.
                        container.delete_blob(blob_name=bytes(jobStoreFileID))

                if checkForModification and expectedVersion is not None:
                    # Acquire a (60-second) write lock,
                    leaseID = container.lease_blob(blob_name=bytes(jobStoreFileID),
                                                   x_ms_lease_action='acquire')['x-ms-lease-id']
                    # check for modification,
                    blobProperties = container.get_blob_properties(blob_name=bytes(jobStoreFileID))
                    if blobProperties['etag'] != expectedVersion:
                        container.lease_blob(blob_name=bytes(jobStoreFileID),
                                             x_ms_lease_action='release',
                                             x_ms_lease_id=leaseID)
                        raise ConcurrentFileModificationException(jobStoreFileID)
                    # commit the file,
                    container.put_block_list(blob_name=bytes(jobStoreFileID),
                                             block_list=blockIDs,
                                             x_ms_lease_id=leaseID,
                                             x_ms_meta_name_values=dict(
                                                 encrypted=str(encrypted)))
                    # then release the lock.
                    container.lease_blob(blob_name=bytes(jobStoreFileID),
                                         x_ms_lease_action='release',
                                         x_ms_lease_id=leaseID)
                else:
                    # No need to check for modification, just blindly write over whatever
                    # was there.
                    container.put_block_list(blob_name=bytes(jobStoreFileID),
                                             block_list=blockIDs,
                                             x_ms_meta_name_values=dict(encrypted=str(encrypted)))

        with UploadPipe() as writable:
            yield writable

    @contextmanager
    def _downloadStream(self, jobStoreFileID, container):
        # The reason this is not in the writer is so we catch non-existent blobs early

        blobProps = container.get_blob_properties(blob_name=bytes(jobStoreFileID))

        encrypted = strict_bool(blobProps['x-ms-meta-encrypted'])
        if encrypted and self.keyPath is None:
            raise AssertionError('Content is encrypted but no key was provided.')

        outer_self = self

        class DownloadPipe(ReadablePipe):
            def writeTo(self, writable):
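                # Stream the blob back in ranged GET requests of at most
                # _maxAzureBlockBytes bytes, decrypting each chunk when the
                # blob was stored encrypted.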
                chunkStart = 0
                fileSize = int(blobProps['Content-Length'])
                while chunkStart < fileSize:
                    chunkEnd = chunkStart + outer_self._maxAzureBlockBytes - 1
                    buf = container.get_blob(blob_name=bytes(jobStoreFileID),
                                             x_ms_range="bytes=%d-%d" % (chunkStart, chunkEnd))
                    if encrypted:
                        buf = encryption.decrypt(buf, outer_self.keyPath)
                    writable.write(buf)
                    chunkStart = chunkEnd + 1

        with DownloadPipe() as readable:
            yield readable
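
For orientation, a minimal sketch (placeholder account, container and blob names) of the wasb/wasbs URL layout that _parseWasbUrl above decomposes into an account, a container and a blob name; it mirrors the parsing logic rather than calling the job store itself.

# Sketch only: wasb[s]://<container>@<account>.blob.core.windows.net/<blob name>
from urlparse import urlparse  # Python 2; urllib.parse on Python 3

url = urlparse('wasbs://mycontainer@myaccount.blob.core.windows.net/path/to/blob.bin')
container, host = url.netloc.split('@')
account = host[:-len('.blob.core.windows.net')]
name = url.path[1:]
assert (account, container, name) == ('myaccount', 'mycontainer', 'path/to/blob.bin')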
Example #4
class TableStorageHandlerTest(_TestCase):
    def _divide_key(self, key):
        divided = []
        hostname = gethostname()
        if key.find(hostname) >= 0:
            preceding, hostname, remaining = key.rpartition(hostname)
            preceding = preceding[:-1] if preceding.endswith(
                '-') else preceding
            divided.extend(preceding.split('-'))
            divided.append(hostname)
            remaining = remaining[1:] if remaining.startswith(
                '-') else remaining
            divided.extend(remaining.split('-'))
        else:
            divided.extend(key.split('-'))
        return iter(divided)

    def _get_formatter_name(self, handler_name, formatter_type):
        name = _get_handler_config_value(handler_name, formatter_type)
        if name:
            if name.startswith('cfg://formatters.'):
                name = name.split('.')[1]
        return name

    def _get_partition_key_formatter_name(self, handler_name):
        return self._get_formatter_name(handler_name,
                                        'partition_key_formatter')

    def _get_row_key_formatter_name(self, handler_name):
        return self._get_formatter_name(handler_name, 'row_key_formatter')

    def setUp(self):
        self.service = TableService(ACCOUNT_NAME, ACCOUNT_KEY)
        # ensure that there's no entity in the table before each test
        tables = set()
        for cfg in LOGGING['handlers'].values():
            if 'table' in cfg:
                tables.add(cfg['table'])
        for table in self.service.query_tables():
            if table.name in tables:
                for entity in self.service.query_entities(table.name):
                    self.service.delete_entity(table.name, entity.PartitionKey,
                                               entity.RowKey)

    def test_logging(self):
        # get the logger for the test
        logger_name = 'table'
        logger = logging.getLogger(logger_name)
        handler_name = _get_handler_name(logger_name)

        # perform logging
        log_text = 'logging test'
        logging_started = datetime.now()
        logger.info(log_text)
        logging_finished = datetime.now()

        # confirm that the entity has correct log text
        table = _get_handler_config_value(handler_name, 'table')
        entities = iter(self.service.query_entities(table))
        entity = next(entities)
        self.assertEqual(entity.message, 'INFO %s' % log_text)

        # confirm that the entity has the default partition key
        fmt = '%Y%m%d%H%M'
        try:
            self.assertEqual(entity.PartitionKey,
                             logging_started.strftime(fmt))
        except AssertionError:
            if logging_started == logging_finished:
                raise
            self.assertEqual(entity.PartitionKey,
                             logging_finished.strftime(fmt))

        # confirm that the entity has the default row key
        divided = self._divide_key(entity.RowKey)
        timestamp = next(divided)
        fmt = '%Y%m%d%H%M%S'
        self.assertGreaterEqual(timestamp[:-3], logging_started.strftime(fmt))
        self.assertLessEqual(timestamp[:-3], logging_finished.strftime(fmt))
        self.assertRegex(timestamp[-3:], '^[0-9]{3}$')
        self.assertEqual(next(divided), gethostname())
        self.assertEqual(int(next(divided)), os.getpid())
        self.assertEqual(next(divided), '00')
        with self.assertRaises(StopIteration):
            next(divided)

        # confirm that there's no more entity in the table
        with self.assertRaises(StopIteration):
            next(entities)

    @unittest.skipIf(_EMULATED,
                     "Azure Storage Emulator doesn't support batch operation.")
    def test_batch(self):
        # get the logger for the test
        logger_name = 'batch'
        logger = logging.getLogger(logger_name)
        handler_name = _get_handler_name(logger_name)

        # perform logging and execute the first batch
        batch_size = _get_handler_config_value(handler_name, 'batch_size')
        log_text = 'batch logging test'
        for i in range(batch_size + int(batch_size / 2)):
            logger.info('%s#%02d' % (log_text, i))

        # confirm that only batch_size entities are committed at this point
        table = _get_handler_config_value(handler_name, 'table')
        entities = list(iter(self.service.query_entities(table)))
        self.assertEqual(len(entities), batch_size)
        rowno_found = set()
        seq_found = set()
        for entity in entities:
            # partition key
            self.assertEqual(entity.PartitionKey, 'batch-%s' % gethostname())
            # row key
            rowno = entity.RowKey.split('-')[-1]
            self.assertLess(int(rowno), batch_size)
            self.assertNotIn(rowno, rowno_found)
            rowno_found.add(rowno)
            # message
            message, seq = entity.message.split('#')
            self.assertEqual(message, 'INFO %s' % log_text)
            self.assertLess(int(seq), batch_size)
            self.assertNotIn(seq, seq_found)
            seq_found.add(seq)

        # remove currently created entities before the next batch
        for entity in entities:
            self.service.delete_entity(table, entity.PartitionKey,
                                       entity.RowKey)

        # perform logging again and execute the next batch
        for j in range(i + 1, int(batch_size / 2) + i + 1):
            logger.info('%s#%02d' % (log_text, j))

        # confirm that the remaining entities are committed in the next batch
        entities = list(iter(self.service.query_entities(table)))
        self.assertEqual(len(entities), batch_size)
        rowno_found.clear()
        for entity in entities:
            # partition key
            self.assertEqual(entity.PartitionKey, 'batch-%s' % gethostname())
            # row key
            rowno = entity.RowKey.split('-')[-1]
            self.assertLess(int(rowno), batch_size)
            self.assertNotIn(rowno, rowno_found)
            rowno_found.add(rowno)
            # message
            message, seq = entity.message.split('#')
            self.assertEqual(message, 'INFO %s' % log_text)
            self.assertGreaterEqual(int(seq), batch_size)
            self.assertLess(int(seq), batch_size * 2)
            self.assertNotIn(seq, seq_found)
            seq_found.add(seq)

    def test_extra_properties(self):
        # get the logger for the test
        logger_name = 'extra_properties'
        logger = logging.getLogger(logger_name)
        handler_name = _get_handler_name(logger_name)

        # perform logging
        log_text = 'extra properties test'
        logger.info(log_text)

        # confirm that the entity has correct log text
        table = _get_handler_config_value(handler_name, 'table')
        entities = iter(self.service.query_entities(table))
        entity = next(entities)
        self.assertEqual(entity.message, 'INFO %s' % log_text)

        # confirm that the extra properties have correct values
        entity = next(iter(self.service.query_entities(table)))
        self.assertEqual(entity.hostname, gethostname())
        self.assertEqual(entity.levelname, 'INFO')
        self.assertEqual(int(entity.levelno), logging.INFO)
        self.assertEqual(entity.module,
                         os.path.basename(__file__).rpartition('.')[0])
        self.assertEqual(entity.name, logger_name)
        self.assertEqual(int(entity.process), os.getpid())
        self.assertEqual(int(entity.thread), current_thread().ident)

        # confirm that there's no more entity in the table
        with self.assertRaises(StopIteration):
            next(entities)

    def test_custom_key_formatters(self):
        # get the logger for the test
        logger_name = 'custom_keys'
        logger = logging.getLogger(logger_name)
        handler_name = _get_handler_name(logger_name)

        # perform logging
        log_text = 'custom key formatters test'
        logging_started = datetime.now()
        logger.info(log_text)
        logging_finished = datetime.now()

        # confirm that the entity has the correct log text
        table = _get_handler_config_value(handler_name, 'table')
        entities = iter(self.service.query_entities(table))
        entity = next(entities)
        self.assertEqual(entity.message, 'INFO %s' % log_text)

        # confirm that the entity has a custom partition key
        divided = self._divide_key(entity.PartitionKey)
        self.assertEqual(next(divided), 'mycustompartitionkey')
        self.assertEqual(next(divided), gethostname())
        formatter_name = self._get_partition_key_formatter_name(handler_name)
        fmt = _get_formatter_config_value(formatter_name, 'datefmt')
        asctime = next(divided)
        try:
            self.assertEqual(asctime, logging_started.strftime(fmt))
        except AssertionError:
            if logging_started == logging_finished:
                raise
            self.assertEqual(asctime, logging_finished.strftime(fmt))
        with self.assertRaises(StopIteration):
            next(divided)

        # confirm that the entity has a custom row key
        divided = self._divide_key(entity.RowKey)
        self.assertEqual(next(divided), 'mycustomrowkey')
        self.assertEqual(next(divided), gethostname())
        formatter_name = self._get_row_key_formatter_name(handler_name)
        fmt = _get_formatter_config_value(formatter_name, 'datefmt')
        asctime = next(divided)
        try:
            self.assertEqual(asctime, logging_started.strftime(fmt))
        except AssertionError:
            if logging_started == logging_finished:
                raise
            self.assertEqual(asctime, logging_finished.strftime(fmt))
        with self.assertRaises(StopIteration):
            next(divided)

        # confirm that there's no more entity in the table
        with self.assertRaises(StopIteration):
            next(entities)
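
The tests above read their handler settings ('table', 'batch_size', 'partition_key_formatter', 'row_key_formatter') from a LOGGING dict that is not shown in this excerpt. Purely as an illustration, a minimal config could look like the sketch below; the handler class path and the account values are assumptions, not taken from the source.

# Hypothetical sketch of the LOGGING dict consumed by the tests above. Only the
# option names ('table', 'batch_size', ...) come from the tests; the handler
# class path and credentials are assumptions/placeholders.
LOGGING = {
    'version': 1,
    'formatters': {
        'verbose': {'format': '%(levelname)s %(message)s'},
    },
    'handlers': {
        'table': {
            'class': 'azure_storage_logging.handlers.TableStorageHandler',  # assumed
            'account_name': 'myaccount',  # placeholder
            'account_key': 'mykey',       # placeholder
            'table': 'logs',
            'formatter': 'verbose',
        },
        'batch': {
            'class': 'azure_storage_logging.handlers.TableStorageHandler',  # assumed
            'account_name': 'myaccount',
            'account_key': 'mykey',
            'table': 'batchlogs',
            'batch_size': 10,
            'formatter': 'verbose',
        },
        # The 'extra_properties' and 'custom_keys' handlers exercised by the
        # other tests would follow the same pattern.
    },
    'loggers': {
        'table': {'handlers': ['table'], 'level': 'INFO'},
        'batch': {'handlers': ['batch'], 'level': 'INFO'},
    },
}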
Example #6
class az(object):
    def __init__(self,
                 default_table_name=DEFAULT_TABLE,
                 partitionKey='default'):
        self.TABLE_STORAGE_KEY = os.getenv('AZURE_STORAGE_KEY')
        self.STORAGE_NAME = os.getenv('STORAGE_NAME')
        self.default_table_name = default_table_name
        self.default_partition = partitionKey
        if self.TABLE_STORAGE_KEY is None:
            from tokens import TABLE_STORAGE_ACCESS_KEY, STORAGE_ACCOUNT_NAME
            self.TABLE_STORAGE_KEY = TABLE_STORAGE_ACCESS_KEY
            self.STORAGE_NAME = STORAGE_ACCOUNT_NAME
        self.table_service = TableService(account_name=self.STORAGE_NAME,
                                          account_key=self.TABLE_STORAGE_KEY)
        #create_table_if_does_not_exists(self.default_table_name)

    def insert_or_replace_entity_to_azure(self,
                                          rowKey,
                                          entry,
                                          t_name=DEFAULT_TABLE):
        '''
        Takes a row key and an entry dict (latA/longA/latB/longB/color) and
        uploads it to Azure Table Storage via the table service.
        '''
        segment = Entity()
        segment.PartitionKey = self.default_partition
        segment.RowKey = str(rowKey).zfill(8)
        segment.latA = str(entry['latA'])
        segment.longA = str(entry['longA'])
        segment.latB = str(entry['latB'])
        segment.longB = str(entry['longB'])
        segment.colorKey = str(entry['color'])

        #print segment.colorKey

        if os.name == 'nt':
            self.table_service.insert_or_replace_entity(
                t_name, self.default_partition,
                str(rowKey).zfill(8), segment)
        else:
            self.table_service.insert_or_replace_entity(t_name, segment)

    def create_table(self, name):
        return self.table_service.create_table(name)

    def delete_table(self, name):
        return self.table_service.delete_table(name)

    def delete_entity_by_rowKey(self, rowKey, table_name=DEFAULT_TABLE):
        return self.table_service.delete_entity(table_name,
                                                self.default_partition, rowKey)

    def does_table_exist(self, table_name):
        if os.name == 'nt':
            for i in self.table_service.query_tables():
                if i.name == table_name:
                    return True
        else:
            for i in self.table_service.list_tables():
                if i.name == table_name:
                    return True
        return False

    def list_tables(self):
        if os.name == 'nt':
            for j in self.table_service.query_tables():
                print j.name
        else:
            for j in self.table_service.list_tables():
                print j.name

    def create_table_if_does_not_exist(self, table_name=DEFAULT_TABLE):
        if self.does_table_exist(table_name):
            return 'already exists'
        else:
            self.table_service.create_table(table_name)

    def create_entry(self, latA, lonA, latB, lonB, bumpiness):
        x = {
            'latA': latA,
            'longA': lonA,
            'latB': latB,
            'longB': lonB,
            'color': bumpiness
        }
        return x

    def create_random_entry(self):
        x = {
            'latA': random.uniform(37, 38),
            'longA': random.uniform(-122, -123),
            'latB': random.uniform(37, 38),
            'longB': random.uniform(-122, -123),
            'color': random.randint(0, 7)
        }
        return x

    def create_and_insert_or_replace_entity_azure(self,
                                                  latA,
                                                  lonA,
                                                  latB,
                                                  lonB,
                                                  bumpiness,
                                                  rowKey,
                                                  table_name=DEFAULT_TABLE):
        return self.insert_or_replace_entity_to_azure(
            rowKey, self.create_entry(latA, lonA, latB, lonB, bumpiness),
            table_name)
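
A short, hypothetical usage sketch for the az helper above; it assumes AZURE_STORAGE_KEY and STORAGE_NAME are set in the environment (or that a tokens module is importable) and that DEFAULT_TABLE is defined as in the original module.

# Hypothetical usage of the az helper above; credentials come from the
# AZURE_STORAGE_KEY / STORAGE_NAME environment variables or the tokens module.
store = az()
store.create_table_if_does_not_exist()

# Insert one randomly generated road segment under row key 0.
entry = store.create_random_entry()
store.insert_or_replace_entity_to_azure(0, entry)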