Example #1
class UpgradeDatabaseCoreStep(object):
    """
    Base class for either schema or data upgrades on the database.

    Upgrades are driven by upgrade files in SQL syntax that we execute against
    the database to accomplish the upgrade.

    @ivar sqlStore: The store to operate on.

    @type sqlStore: L{txdav.idav.IDataStore}
    """
    log = Logger()

    def __init__(self, sqlStore, uid=None, gid=None, failIfUpgradeNeeded=False, checkExistingSchema=False):
        """
        Initialize the upgrade step.
        """
        self.sqlStore = sqlStore
        self.uid = uid
        self.gid = gid
        self.failIfUpgradeNeeded = failIfUpgradeNeeded
        self.checkExistingSchema = checkExistingSchema
        self.schemaLocation = getModule(__name__).filePath.parent().parent().sibling("sql_schema")
        self.pyLocation = getModule(__name__).filePath.parent()

        self.versionKey = None
        self.versionDescriptor = ""
        self.upgradeFilePrefix = ""
        self.upgradeFileSuffix = ""
        self.defaultKeyValue = None

    def stepWithResult(self, result):
        """
        Execute this upgrade step by running the database upgrade.
        """
        return self.databaseUpgrade()

    @inlineCallbacks
    def databaseUpgrade(self):
        """
        Do a database schema upgrade.
        """
        self.log.warn("Beginning database {vers} check.", vers=self.versionDescriptor)

        # Retrieve information from schema and database
        dialect, required_version, actual_version = yield self.getVersions()

        if required_version == actual_version:
            self.log.warn("{vers} version check complete: no upgrade needed.", vers=self.versionDescriptor.capitalize())
            if self.checkExistingSchema:
                if dialect == "postgres-dialect":
                    expected_schema = self.schemaLocation.child("current.sql")
                    schema_name = "public"
                else:
                    expected_schema = self.schemaLocation.child("current-oracle-dialect.sql")
                    schema_name = config.DatabaseConnection.user
                yield self.sqlStore.checkSchema(schemaFromPath(expected_schema), schema_name)
        elif required_version < actual_version:
            msg = "Actual %s version %s is more recent than the expected version %s. The service cannot be started" % (
                self.versionDescriptor, actual_version, required_version,
            )
            self.log.error(msg)
            raise RuntimeError(msg)
        elif self.failIfUpgradeNeeded:
            if self.checkExistingSchema:
                expected_schema = self.schemaLocation.child("old").child(dialect).child("v{}.sql".format(actual_version))
                if dialect == "postgres-dialect":
                    schema_name = "public"
                else:
                    schema_name = config.DatabaseConnection.user
                yield self.sqlStore.checkSchema(schemaFromPath(expected_schema), schema_name)
            raise NotAllowedToUpgrade()
        else:
            self.sqlStore.setUpgrading(True)
            yield self.upgradeVersion(actual_version, required_version, dialect)
            self.sqlStore.setUpgrading(False)

        self.log.warn("Database {vers} check complete.", vers=self.versionDescriptor)

        returnValue(None)

    @inlineCallbacks
    def getVersions(self):
        """
        Extract the expected version from the database schema and get the actual version in the current
        database, along with the DB dialect.
        """

        # Retrieve the version number from the schema file
        current_schema = self.schemaLocation.child("current.sql").getContent()
        found = re.search(r"insert into CALENDARSERVER values \('%s', '(\d+)'\);" % (self.versionKey,), current_schema)
        if found is None:
            msg = "Schema is missing required database key %s insert statement: %s" % (self.versionKey, current_schema,)
            self.log.error(msg)
            raise RuntimeError(msg)
        else:
            required_version = int(found.group(1))
            self.log.warn("Required database key {key}: {vers}.", key=self.versionKey, vers=required_version)

        # Get the schema version in the current database
        sqlTxn = self.sqlStore.newTransaction(label="UpgradeDatabaseCoreStep.getVersions")
        dialect = sqlTxn.dbtype.dialect
        try:
            actual_version = yield sqlTxn.calendarserverValue(self.versionKey)
            actual_version = int(actual_version)
            yield sqlTxn.commit()
        except (RuntimeError, ValueError):
            f = Failure()
            self.log.error("Database key {key} cannot be determined.", key=self.versionKey)
            yield sqlTxn.abort()
            if self.defaultKeyValue is None:
                f.raiseException()
            else:
                actual_version = self.defaultKeyValue

        self.log.warn("Actual database key {key}: {vers}.", key=self.versionKey, vers=actual_version)

        returnValue((dialect, required_version, actual_version,))

    @inlineCallbacks
    def upgradeVersion(self, fromVersion, toVersion, dialect):
        """
        Update the database from one version to another (the current one). Do this by
        looking for upgrade_from_X_to_Y.sql files that cover the full range of upgrades.
        """

        self.log.warn("Starting {vers} upgrade from version {fr} to {to}.", vers=self.versionDescriptor, fr=fromVersion, to=toVersion)

        # Scan for all possible upgrade files - returned sorted
        files = self.scanForUpgradeFiles(dialect)

        # Determine upgrade sequence and run each upgrade
        upgrades = self.determineUpgradeSequence(fromVersion, toVersion, files, dialect)

        # Use one transaction for the entire set of upgrades
        try:
            for fp in upgrades:
                yield self.applyUpgrade(fp)
        except RuntimeError:
            self.log.error("Database {vers} upgrade failed using: {path}", vers=self.versionDescriptor, path=fp.basename())
            raise

        self.log.warn("{vers} upgraded from version {fr} to {to}.", vers=self.versionDescriptor.capitalize(), fr=fromVersion, to=toVersion)

    def getPathToUpgrades(self, dialect):
        """
        Return the path where appropriate upgrade files can be found.
        """
        raise NotImplementedError

    def scanForUpgradeFiles(self, dialect):
        """
        Scan for upgrade files with the required name pattern.
        """

        fp = self.getPathToUpgrades(dialect)
        upgrades = []
        regex = re.compile(r"%supgrade_from_(\d+)_to_(\d+)%s" % (self.upgradeFilePrefix, self.upgradeFileSuffix,))
        for child in fp.globChildren("%supgrade_*%s" % (self.upgradeFilePrefix, self.upgradeFileSuffix,)):
            matched = regex.match(child.basename())
            if matched is not None:
                fromV = int(matched.group(1))
                toV = int(matched.group(2))
                upgrades.append((fromV, toV, child))

        upgrades.sort(key=lambda x: (x[0], x[1]))
        return upgrades

    def determineUpgradeSequence(self, fromVersion, toVersion, files, dialect):
        """
        Determine the upgrade_from_X_to_Y(.sql|.py) files that cover the full range of upgrades.
        Note that X and Y may not be consecutive, e.g., we might have an upgrade from 3 to 4,
        4 to 5, and 3 to 5 - the latter because it is more efficient to jump over the intermediate
        step. As a result we always try to pick the upgrade file that gives the biggest
        jump from one version to another at each step.
        """

        # Now find the path from the old version to the current one
        filesByFromVersion = {}
        for fromV, toV, fp in files:
            if fromV not in filesByFromVersion or filesByFromVersion[fromV][1] < toV:
                filesByFromVersion[fromV] = fromV, toV, fp

        upgrades = []
        nextVersion = fromVersion
        while nextVersion != toVersion:
            if nextVersion not in filesByFromVersion:
                msg = "Missing upgrade file from version %d with dialect %s" % (nextVersion, dialect,)
                self.log.error(msg)
                raise RuntimeError(msg)
            else:
                upgrades.append(filesByFromVersion[nextVersion][2])
                nextVersion = filesByFromVersion[nextVersion][1]

        return upgrades

    def applyUpgrade(self, fp):
        """
        Apply the supplied upgrade to the database. Always return a L{Deferred}.
        """
        raise NotImplementedError
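
The greedy path selection in determineUpgradeSequence above can be exercised in
isolation. A minimal standalone sketch, with plain strings standing in for the
FilePath objects and hypothetical file names:

# Sketch: keep the biggest jump available from each starting version,
# then walk from the old version to the target.
files = [
    (3, 4, "upgrade_from_3_to_4.sql"),
    (4, 5, "upgrade_from_4_to_5.sql"),
    (3, 5, "upgrade_from_3_to_5.sql"),  # bigger jump, preferred from 3
    (5, 6, "upgrade_from_5_to_6.sql"),
]

best = {}
for fromV, toV, path in files:
    if fromV not in best or best[fromV][0] < toV:
        best[fromV] = (toV, path)

version, target, sequence = 3, 6, []
while version != target:
    if version not in best:
        raise RuntimeError("Missing upgrade file from version %d" % (version,))
    toV, path = best[version]
    sequence.append(path)
    version = toV

print(sequence)  # ['upgrade_from_3_to_5.sql', 'upgrade_from_5_to_6.sql']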
Example #2
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##
from __future__ import print_function

from calendarserver.tools.cmdline import utilityMain, WorkerService
from argparse import ArgumentParser
from twext.python.log import Logger
from twisted.internet.defer import inlineCallbacks
from twext.who.idirectory import RecordType
import time

log = Logger()


class DisplayAPNSubscriptions(WorkerService):

    users = []

    def doWork(self):
        rootResource = self.rootResource()
        directory = rootResource.getDirectory()
        return displayAPNSubscriptions(self.store, directory, rootResource,
                                       self.users)


def main():
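    # The original body is truncated here. What follows is a hedged
    # reconstruction of a typical calendarserver tool entry point: parse
    # arguments, stash the requested users on the service class, and hand
    # off to utilityMain. The argument names below are assumptions.
    parser = ArgumentParser(description="Display APN subscriptions")
    parser.add_argument("-f", "--config", dest="configFileName",
                        help="path to the caldavd.plist configuration file")
    parser.add_argument("users", nargs="+",
                        help="short names of the users to display")
    args = parser.parse_args()

    DisplayAPNSubscriptions.users = args.users

    utilityMain(args.configFileName, DisplayAPNSubscriptions)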
Example #3
class DataStoreTransaction(object):
    """
    In-memory implementation of a data store transaction.
    """
    log = Logger()

    def __init__(self, dataStore, name):
        """
        Initialize a transaction; do not call this directly, instead call
        L{CalendarStore.newTransaction}.

        @param dataStore: The store that created this transaction.

        @type dataStore: L{CalendarStore}
        """
        self._dataStore = dataStore
        self._termination = None
        self._operations = []
        self._postCommitOperations = []
        self._postAbortOperations = []
        self._tracker = _CommitTracker(name)

    def store(self):
        return self._dataStore

    def addOperation(self, operation, name):
        self._operations.append(operation)
        self._tracker.info.append(name)

    def _terminate(self, mode):
        """
        Check to see if this transaction has already been terminated somehow,
        either via committing or aborting, and if not, note that it has been
        terminated.

        @param mode: The manner of the termination of this transaction.

        @type mode: C{str}

        @raise AlreadyFinishedError: This transaction has already been
            terminated.
        """
        if self._termination is not None:
            raise AlreadyFinishedError("already %s" % (self._termination, ))
        self._termination = mode
        self._tracker.done = True

    def abort(self):
        self._terminate("aborted")

        for operation in self._postAbortOperations:
            operation()

    def commit(self):
        self._terminate("committed")

        self.committed = True
        undos = []

        for operation in self._operations:
            try:
                undo = operation()
                if undo is not None:
                    undos.append(undo)
            except:
                self.log.debug("Undoing DataStoreTransaction")
                for undo in undos:
                    try:
                        undo()
                    except:
                        self.log.error("Cannot undo DataStoreTransaction")
                raise

        for operation in self._postCommitOperations:
            operation()

    def postCommit(self, operation):
        self._postCommitOperations.append(operation)

    def postAbort(self, operation):
        self._postAbortOperations.append(operation)
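
The commit/undo contract above can be demonstrated without the store: each
operation may hand back an undo callable, and a failure part-way through runs
the collected undos in order before re-raising. A self-contained sketch of the
same semantics (names here are illustrative):

def run_operations(operations):
    # Mirror of DataStoreTransaction.commit: collect undos, roll back on error.
    undos = []
    for operation in operations:
        try:
            undo = operation()
            if undo is not None:
                undos.append(undo)
        except:  # bare except mirrors the original
            for undo in undos:
                undo()
            raise

state = []

def add_item():
    state.append("a")
    return lambda: state.remove("a")  # the undo for this operation

def fail():
    raise RuntimeError("boom")

try:
    run_operations([add_item, fail])
except RuntimeError:
    pass
assert state == []  # add_item was rolled back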
Example #4
class NotificationsDatabase(AbstractSQLDatabase):
    log = Logger()

    db_basename = db_prefix + "notifications"
    schema_version = "1"
    db_type = "notifications"

    def __init__(self, resource):
        """
        @param resource: the L{CalDAVResource} resource for
            the notifications collection.
        """
        self.resource = resource
        db_filename = os.path.join(self.resource.fp.path,
                                   NotificationsDatabase.db_basename)
        super(NotificationsDatabase, self).__init__(db_filename,
                                                    True,
                                                    autocommit=True)

        self.resource._txn.postCommit(self._db_close)
        self.resource._txn.postAbort(self._db_close)

    def allRecords(self):

        records = self._db_execute("select * from NOTIFICATIONS")
        return [
            self._makeRecord(row)
            for row in (records if records is not None else ())
        ]

    def recordForUID(self, uid):

        row = self._db_execute("select * from NOTIFICATIONS where UID = :1",
                               uid)
        return self._makeRecord(row[0]) if row else None

    def addOrUpdateRecord(self, record):

        self._db_execute(
            """insert or replace into NOTIFICATIONS (UID, NAME, TYPE)
            values (:1, :2, :3)
            """,
            record.uid,
            record.name,
            json.dumps(record.notificationtype),
        )

        self._db_execute(
            """
            insert or replace into REVISIONS (NAME, REVISION, DELETED)
            values (:1, :2, :3)
            """,
            record.name,
            self.bumpRevision(fast=True),
            'N',
        )

    def removeRecordForUID(self, uid):

        record = self.recordForUID(uid)
        self.removeRecordForName(record.name)

    def removeRecordForName(self, rname):

        self._db_execute("delete from NOTIFICATIONS where NAME = :1", rname)
        self._db_execute(
            """
            update REVISIONS SET REVISION = :1, DELETED = :2
            where NAME = :3
            """, self.bumpRevision(fast=True), 'Y', rname)

    def whatchanged(self, revision):

        results = [
            (name.encode("utf-8"), deleted)
            for name, deleted in self._db_execute(
                "select NAME, DELETED from REVISIONS where REVISION > :1",
                revision)
        ]
        results.sort(key=lambda x: x[1])

        changed = []
        deleted = []
        for name, wasdeleted in results:
            if name:
                if wasdeleted == 'Y':
                    if revision:
                        deleted.append(name)
                else:
                    changed.append(name)
            else:
                raise SyncTokenValidException

        return changed, deleted,

    def lastRevision(self):
        return self._db_value_for_sql("select REVISION from REVISION_SEQUENCE")

    def bumpRevision(self, fast=False):
        self._db_execute(
            """
            update REVISION_SEQUENCE set REVISION = REVISION + 1
            """, )
        self._db_commit()
        return self._db_value_for_sql(
            """
            select REVISION from REVISION_SEQUENCE
            """, )

    def _db_version(self):
        """
        @return: the schema version assigned to this index.
        """
        return NotificationsDatabase.schema_version

    def _db_type(self):
        """
        @return: the collection type assigned to this index.
        """
        return NotificationsDatabase.db_type

    def _db_init_data_tables(self, q):
        """
        Initialise the underlying database tables.
        @param q:           a database cursor to use.
        """
        #
        # NOTIFICATIONS table is the primary table
        #   UID: UID for this notification
        #   NAME: child resource name
        #   TYPE: type of notification
        #
        q.execute("""
            create table NOTIFICATIONS (
                UID            text unique,
                NAME           text unique,
                TYPE           text
            )
            """)

        q.execute("""
            create index UID on NOTIFICATIONS (UID)
            """)

        #
        # REVISIONS table tracks changes
        #   NAME: Last URI component (eg. <uid>.ics, RESOURCE primary key)
        #   REVISION: revision number
        #   DELETED: Y if revision deleted, N if added or changed
        #
        q.execute("""
            create table REVISION_SEQUENCE (
                REVISION        integer
            )
            """)
        q.execute("""
            insert into REVISION_SEQUENCE (REVISION) values (0)
            """)
        q.execute("""
            create table REVISIONS (
                NAME            text unique,
                REVISION        integer,
                DELETED         text(1)
            )
            """)
        q.execute("""
            create index REVISION on REVISIONS (REVISION)
            """)

    def _db_upgrade_data_tables(self, q, old_version):
        """
        Upgrade the data from an older version of the DB.
        """

        # Nothing to do as we have not changed the schema
        pass

    def _makeRecord(self, row):

        return NotificationRecord(
            *[str(item) if isinstance(item, unicode) else item for item in row]
        )
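
The REVISIONS bookkeeping above supports a simple sync protocol: a client
presents the last revision it saw and receives the names changed or deleted
since then, with deletions suppressed on an initial sync (revision 0). An
in-memory sketch of the same idea, without the SQL:

revision_seq = [0]
revisions = {}  # name -> (revision, deleted flag)

def bump():
    revision_seq[0] += 1
    return revision_seq[0]

def record_change(name, deleted=False):
    revisions[name] = (bump(), deleted)

def whatchanged(since):
    changed, deleted = [], []
    for name, (rev, was_deleted) in revisions.items():
        if rev > since:
            if was_deleted:
                if since:  # initial sync omits deletions
                    deleted.append(name)
            else:
                changed.append(name)
    return changed, deleted

record_change("a.xml")
token = revision_seq[0]
record_change("b.xml")
record_change("a.xml", deleted=True)
print(whatchanged(token))  # (['b.xml'], ['a.xml'])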
Example #5
class Memcacher(CachePoolUserMixIn):
    log = Logger()

    MEMCACHE_KEY_LIMIT = 250  # the memcached key length limit
    NAMESPACE_MAX_LENGTH = 32  # max size of namespace we will allow
    HASH_LENGTH = 32  # length of hash we will generate
    TRUNCATED_KEY_LENGTH = MEMCACHE_KEY_LIMIT - NAMESPACE_MAX_LENGTH - HASH_LENGTH - 2  # 2 accounts for delimiters
    MEMCACHE_VALUE_LIMIT = 1024 * 1024  # the memcached default value length limit

    # Translation table: all ctrls (0x00 - 0x1F) and space and 0x7F mapped to _
    keyNormalizeTranslateTable = string.maketrans(
        "".join([chr(i) for i in range(33)]) + chr(0x7F), "_" * 33 + "_")

    allowTestCache = False
    memoryCacheInstance = {
        True: None,
        False: None,
    }

    class memoryCacher():
        """
        A class implementing the memcache client API we care about but
        using a dict to store the results in memory. This can be used
        for caching on a single instance server, and for tests, where
        memcached may not be running.
        """
        def __init__(self, pickle=False):
            self._cache = {}  # (value, expireTime, check-and-set identifier)
            self._clock = 0
            self._pickle = pickle

        def _check_key(self, key):
            if not isinstance(key, str):
                raise ValueError("memcache keys must be str type")

        def _check_value(self, value):
            if not self._pickle and not isinstance(value, str):
                raise ValueError("memcache values must be str type")

        def add(self, key, value, expireTime=0):
            self._check_key(key)
            self._check_value(value)

            if len(key) > Memcacher.MEMCACHE_KEY_LIMIT or len(
                    str(value)) > Memcacher.MEMCACHE_VALUE_LIMIT:
                return succeed(False)
            if key not in self._cache:
                if not expireTime:
                    expireTime = 99999
                self._cache[key] = (value, self._clock + expireTime, 0)
                return succeed(True)
            else:
                return succeed(False)

        def set(self, key, value, expireTime=0):
            self._check_key(key)
            self._check_value(value)

            if len(key) > Memcacher.MEMCACHE_KEY_LIMIT or len(
                    str(value)) > Memcacher.MEMCACHE_VALUE_LIMIT:
                return succeed(False)
            if not expireTime:
                expireTime = 99999
            if key in self._cache:
                identifier = self._cache[key][2]
                identifier += 1
            else:
                identifier = 0
            self._cache[key] = (value, self._clock + expireTime, identifier)
            return succeed(True)

        def checkAndSet(self, key, value, cas, flags=0, expireTime=0):
            self._check_key(key)
            self._check_value(value)

            if len(key) > Memcacher.MEMCACHE_KEY_LIMIT or len(
                    str(value)) > Memcacher.MEMCACHE_VALUE_LIMIT:
                return succeed(False)
            if not expireTime:
                expireTime = 99999
            if key in self._cache:
                identifier = self._cache[key][2]
                if cas != str(identifier):
                    return succeed(False)
                identifier += 1
            else:
                return succeed(False)
            self._cache[key] = (value, self._clock + expireTime, identifier)
            return succeed(True)

        def get(self, key, withIdentifier=False):
            self._check_key(key)

            if len(key) > Memcacher.MEMCACHE_KEY_LIMIT:
                value, expires, identifier = (None, 0, "")
            else:
                value, expires, identifier = self._cache.get(
                    key, (None, 0, ""))
                if self._clock >= expires:
                    value = None
                    identifier = ""

            if withIdentifier:
                return succeed((0, value, str(identifier)))
            else:
                return succeed((
                    0,
                    value,
                ))

        def delete(self, key):
            self._check_key(key)

            if len(key) > Memcacher.MEMCACHE_KEY_LIMIT:
                return succeed(False)
            try:
                del self._cache[key]
                return succeed(True)
            except KeyError:
                return succeed(False)

        def incr(self, key, delta=1):
            self._check_key(key)

            if len(key) > Memcacher.MEMCACHE_KEY_LIMIT:
                return succeed(False)
            value = self._cache.get(key, None)
            if value is not None:
                value, expire, identifier = value
                try:
                    value = int(value)
                except ValueError:
                    value = None
                else:
                    value += delta
                    self._cache[key] = (
                        str(value),
                        expire,
                        identifier,
                    )
            return succeed(value)

        def decr(self, key, delta=1):
            self._check_key(key)

            if len(key) > Memcacher.MEMCACHE_KEY_LIMIT:
                return succeed(False)
            value = self._cache.get(key, None)
            if value is not None:
                value, expire, identifier = value
                try:
                    value = int(value)
                except ValueError:
                    value = None
                else:
                    value -= delta
                    if value < 0:
                        value = 0
                    self._cache[key] = (
                        str(value),
                        expire,
                        identifier,
                    )
            return succeed(value)

        def flushAll(self):
            self._cache = {}
            return succeed(True)

        def advanceClock(self, seconds):
            self._clock += seconds

    # TODO: an sqlite based cacher that can be used for multiple instance servers
    # in the absence of memcached. This is not ideal and we may want to not implement
    # this, but it is being documented for completeness.
    #
    # For now we implement a cacher that does not cache.
    class nullCacher():
        """
        A class implementing the memcache client API we care about but
        does not actually cache anything.
        """
        def add(self, key, value, expireTime=0):
            return succeed(True)

        def set(self, key, value, expireTime=0):
            return succeed(True)

        def checkAndSet(self, key, value, cas, flags=0, expireTime=0):
            return succeed(True)

        def get(self, key, withIdentifier=False):
            return succeed((
                0,
                None,
            ))

        def delete(self, key):
            return succeed(True)

        def incr(self, key, delta=1):
            return succeed(None)

        def decr(self, key, delta=1):
            return succeed(None)

        def flushAll(self):
            return succeed(True)

    def __init__(self,
                 namespace,
                 pickle=False,
                 no_invalidation=False,
                 key_normalization=True):
        """
        @param namespace: a unique namespace for this cache's keys
        @type namespace: C{str}
        @param pickle: if C{True} values will be pickled/unpickled when stored/read from the cache,
            if C{False} values will be stored directly (and therefore must be strings)
        @type pickle: C{bool}
        @param no_invalidation: if C{True} the cache is static - there will be no invalidations. This allows
            Memcacher to use the memoryCacher cache instead of nullCacher for the multi-instance case when memcached
            is not present,as there is no issue with caches in each instance getting out of sync. If C{False} the
            nullCacher will be used for the multi-instance case when memcached is not configured.
        @type no_invalidation: C{bool}
        @param key_normalization: if C{True} the key is assumed to possibly be longer than the Memcache key size and so additional
            work is done to truncate and append a hash.
        @type key_normalization: C{bool}
        """

        assert len(namespace) <= Memcacher.NAMESPACE_MAX_LENGTH, (
            "Memcacher namespace must be less than or equal to %s characters long" % (
                Memcacher.NAMESPACE_MAX_LENGTH,
            )
        )
        self._memcacheProtocol = None
        self._cachePoolHandle = namespace
        self._namespace = namespace
        self._pickle = pickle
        self._noInvalidation = no_invalidation
        self._key_normalization = key_normalization

    def _getMemcacheProtocol(self):
        if self._memcacheProtocol is not None:
            return self._memcacheProtocol

        if config.Memcached.Pools.Default.ClientEnabled:
            self._memcacheProtocol = self.getCachePool()

        elif config.ProcessType == "Single" or self._noInvalidation or self.allowTestCache:
            # The memory cacher handles python types natively, but we need to treat non-str types as an error
            # if pickling is off, so we use two global memory cachers for each pickle state
            if Memcacher.memoryCacheInstance[self._pickle] is None:
                Memcacher.memoryCacheInstance[
                    self._pickle] = Memcacher.memoryCacher(self._pickle)
            self._memcacheProtocol = Memcacher.memoryCacheInstance[
                self._pickle]

        else:
            # NB no need to pickle the null cacher as it handles python types natively
            self._memcacheProtocol = Memcacher.nullCacher()
            self._pickle = False

        return self._memcacheProtocol

    def _normalizeKey(self, key):

        if isinstance(key, unicode):
            key = key.encode("utf-8")
        assert isinstance(key, str), "Key must be a str."

        if self._key_normalization:
            hash = hashlib.md5(key).hexdigest()
            key = key[:Memcacher.TRUNCATED_KEY_LENGTH]
            return "%s-%s" % (
                key.translate(Memcacher.keyNormalizeTranslateTable),
                hash,
            )
        else:
            return key

    def add(self, key, value, expireTime=0):

        proto = self._getMemcacheProtocol()

        my_value = value
        if self._pickle:
            my_value = cPickle.dumps(value)
        self.log.debug("Adding Cache Token for %r" % (key, ))
        return proto.add('%s:%s' % (self._namespace, self._normalizeKey(key)),
                         my_value,
                         expireTime=expireTime)

    def set(self, key, value, expireTime=0):

        proto = self._getMemcacheProtocol()

        my_value = value
        if self._pickle:
            my_value = cPickle.dumps(value)
        self.log.debug("Setting Cache Token for %r" % (key, ))
        return proto.set('%s:%s' % (self._namespace, self._normalizeKey(key)),
                         my_value,
                         expireTime=expireTime)

    def checkAndSet(self, key, value, cas, flags=0, expireTime=0):

        proto = self._getMemcacheProtocol()

        my_value = value
        if self._pickle:
            my_value = cPickle.dumps(value)
        self.log.debug("Setting Cache Token for %r" % (key, ))
        return proto.checkAndSet('%s:%s' %
                                 (self._namespace, self._normalizeKey(key)),
                                 my_value,
                                 cas,
                                 expireTime=expireTime)

    def get(self, key, withIdentifier=False):
        def _gotit(result, withIdentifier):
            if withIdentifier:
                _ignore_flags, identifier, value = result
            else:
                _ignore_flags, value = result
            if self._pickle and value is not None:
                value = cPickle.loads(value)
            if withIdentifier:
                value = (identifier, value)
            return value

        self.log.debug("Getting Cache Token for %r" % (key, ))
        d = self._getMemcacheProtocol().get(
            '%s:%s' % (self._namespace, self._normalizeKey(key)),
            withIdentifier=withIdentifier)
        d.addCallback(_gotit, withIdentifier)
        return d

    def delete(self, key):
        self.log.debug("Deleting Cache Token for %r" % (key, ))
        return self._getMemcacheProtocol().delete(
            '%s:%s' % (self._namespace, self._normalizeKey(key)))

    def incr(self, key, delta=1):
        self.log.debug("Incrementing Cache Token for %r" % (key, ))
        return self._getMemcacheProtocol().incr(
            '%s:%s' % (self._namespace, self._normalizeKey(key)), delta)

    def decr(self, key, delta=1):
        self.log.debug("Decrementing Cache Token for %r" % (key, ))
        return self._getMemcacheProtocol().decr(
            '%s:%s' % (self._namespace, self._normalizeKey(key)), delta)

    def flushAll(self):
        self.log.debug("Flushing All Cache Tokens")
        return self._getMemcacheProtocol().flushAll()

    @classmethod
    def reset(cls):
        """
        Reset the memory cachers
        """
        cls.memoryCacheInstance = {True: None, False: None}
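
The key normalization above is what keeps arbitrarily long keys within
memcached's 250-byte limit while staying unique: the key is truncated, unsafe
bytes are mapped to "_", and an MD5 hex digest of the full key is appended. A
standalone Python 2 sketch of the same transformation:

import hashlib
import string

MEMCACHE_KEY_LIMIT = 250
NAMESPACE_MAX_LENGTH = 32
HASH_LENGTH = 32
TRUNCATED_KEY_LENGTH = MEMCACHE_KEY_LIMIT - NAMESPACE_MAX_LENGTH - HASH_LENGTH - 2

# Map all ctrls (0x00 - 0x1F), space (0x20) and DEL (0x7F) to "_"
translate_table = string.maketrans(
    "".join([chr(i) for i in range(33)]) + chr(0x7F), "_" * 34)

def normalize_key(key):
    digest = hashlib.md5(key).hexdigest()
    truncated = key[:TRUNCATED_KEY_LENGTH].translate(translate_table)
    return "%s-%s" % (truncated, digest)

long_key = "calendar home for user " + "x" * 300  # over the memcached limit
print(len(normalize_key(long_key)))  # bounded: at most 184 + 1 + 32 = 217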
Example #6
class CalendarUserProxyPrincipalResource(CalDAVComplianceMixIn,
                                         PermissionsMixIn,
                                         DAVResourceWithChildrenMixin,
                                         DAVPrincipalResource):
    """
    Calendar user proxy principal resource.
    """
    log = Logger()

    def __init__(self, parent, proxyType):
        """
        @param parent: the parent of this resource.
        @param proxyType: a C{str} containing the name of the resource.
        """
        if self.isCollection():
            slash = "/"
        else:
            slash = ""

        url = joinURL(parent.principalURL(), proxyType) + slash

        super(CalendarUserProxyPrincipalResource, self).__init__()
        DAVResourceWithChildrenMixin.__init__(self)

        self.parent = parent
        self.proxyType = proxyType
        self._url = url

        # FIXME: if this is supposed to be public, it needs a better name:
        self.pcollection = self.parent.parent.parent

        # Principal UID is parent's UID plus the proxy type; this we can easily
        # map back to a principal.
        self.uid = "%s#%s" % (self.parent.principalUID(), proxyType)
        self._alternate_urls = tuple(
            joinURL(url, proxyType) + slash for url in parent.alternateURIs()
            if url.startswith("/"))

    def __str__(self):
        return "%s [%s]" % (self.parent, self.proxyType)

    def _index(self):
        """
        Return the SQL database for this group principal.

        @return: the L{ProxyDB} for the principal collection.
        """
        return ProxyDBService

    def resourceType(self):
        if self.proxyType == "calendar-proxy-read":
            return davxml.ResourceType.calendarproxyread  # @UndefinedVariable
        elif self.proxyType == "calendar-proxy-write":
            return davxml.ResourceType.calendarproxywrite  # @UndefinedVariable
        elif self.proxyType == "calendar-proxy-read-for":
            return davxml.ResourceType.calendarproxyreadfor  # @UndefinedVariable
        elif self.proxyType == "calendar-proxy-write-for":
            return davxml.ResourceType.calendarproxywritefor  # @UndefinedVariable
        else:
            return super(CalendarUserProxyPrincipalResource,
                         self).resourceType()

    def isProxyType(self, read_write):
        if (read_write and self.proxyType == "calendar-proxy-write"
                or not read_write and self.proxyType == "calendar-proxy-read"):
            return True
        else:
            return False

    def isCollection(self):
        return True

    def etag(self):
        return succeed(None)

    def deadProperties(self):
        if not hasattr(self, "_dead_properties"):
            self._dead_properties = NonePropertyStore(self)
        return self._dead_properties

    def writeProperty(self, property, request):
        assert isinstance(property, davxml.WebDAVElement)

        if property.qname() == (dav_namespace, "group-member-set"):
            return self.setGroupMemberSet(property, request)

        return super(CalendarUserProxyPrincipalResource,
                     self).writeProperty(property, request)

    @inlineCallbacks
    def setGroupMemberSet(self, new_members, request):
        # FIXME: as defined right now it is not possible to specify a
        # calendar-user-proxy group as a member of any other group since the
        # directory service does not know how to lookup these special resource
        # UIDs.
        #
        # Really, c-u-p principals should be treated the same way as any other
        # principal, so they should be allowed as members of groups.
        #
        # This implementation now raises an exception for any principal it
        # cannot find.

        # Break out the list into a set of URIs.
        members = [str(h) for h in new_members.children]

        # Map the URIs to principals and a set of UIDs.
        principals = []
        newUIDs = set()
        for uri in members:
            principal = yield self.pcollection._principalForURI(uri)
            # Invalid principals MUST result in an error.
            if principal is None or principal.principalURL() != uri:
                raise HTTPError(
                    StatusResponse(
                        responsecode.BAD_REQUEST,
                        "Attempt to use a non-existent principal %s "
                        "as a group member of %s." % (
                            uri,
                            self.principalURL(),
                        )))
            principals.append(principal)
            newUIDs.add(principal.principalUID())

        # Get the old set of UIDs
        # oldUIDs = (yield self._index().getMembers(self.uid))
        oldPrincipals = yield self.groupMembers()
        oldUIDs = [p.principalUID() for p in oldPrincipals]

        # Change membership
        yield self.setGroupMemberSetPrincipals(principals)

        # Invalidate the primary principal's cache, and the cache of any
        # principal whose membership status changed
        yield self.parent.cacheNotifier.changed()

        changedUIDs = newUIDs.symmetric_difference(oldUIDs)
        for uid in changedUIDs:
            principal = yield self.pcollection.principalForUID(uid)
            if principal:
                yield principal.cacheNotifier.changed()

        returnValue(True)

    @inlineCallbacks
    def setGroupMemberSetPrincipals(self, principals):

        # Find our pseudo-record
        record = yield self.parent.record.service.recordWithShortName(
            self._recordTypeFromProxyType(), self.parent.principalUID())
        # Set the members
        memberRecords = [p.record for p in principals]
        yield record.setMembers(memberRecords)

    ##
    # HTTP
    ##

    def htmlElement(self):
        """
        Customize HTML display of proxy groups.
        """
        return ProxyPrincipalElement(self)

    ##
    # DAV
    ##

    def displayName(self):
        return self.proxyType

    ##
    # ACL
    ##

    def alternateURIs(self):
        # FIXME: Add API to IDirectoryRecord for getting a record URI?
        return self._alternate_urls

    def principalURL(self):
        return self._url

    def principalUID(self):
        return self.uid

    def principalCollections(self):
        return self.parent.principalCollections()

    @inlineCallbacks
    def _expandMemberPrincipals(self,
                                uid=None,
                                relatives=None,
                                uids=None,
                                infinity=False):
        if uid is None:
            uid = self.principalUID()
        if relatives is None:
            relatives = set()
        if uids is None:
            uids = set()

        if uid not in uids:
            from twistedcaldav.directory.principal import DirectoryPrincipalResource
            uids.add(uid)
            principal = yield self.pcollection.principalForUID(uid)
            if isinstance(principal, CalendarUserProxyPrincipalResource):
                members = yield self._directGroupMembers()
                for member in members:
                    if member.principalUID() not in uids:
                        relatives.add(member)
                        if infinity:
                            yield self._expandMemberPrincipals(
                                member.principalUID(),
                                relatives,
                                uids,
                                infinity=infinity)
            elif isinstance(principal, DirectoryPrincipalResource):
                if infinity:
                    members = yield principal.expandedGroupMembers()
                else:
                    members = yield principal.groupMembers()
                relatives.update(members)

        returnValue(relatives)

    def _recordTypeFromProxyType(self):
        return {
            "calendar-proxy-read": DelegateRecordType.readDelegateGroup,
            "calendar-proxy-write": DelegateRecordType.writeDelegateGroup,
            "calendar-proxy-read-for": DelegateRecordType.readDelegatorGroup,
            "calendar-proxy-write-for": DelegateRecordType.writeDelegatorGroup,
        }.get(self.proxyType)

    @inlineCallbacks
    def _directGroupMembers(self):
        """
        Fault in the record representing the sub principal for this proxy type
        (either read-only or read-write), then fault in the direct members of
        that record.
        """
        memberPrincipals = []
        record = yield self.parent.record.service.recordWithShortName(
            self._recordTypeFromProxyType(), self.parent.principalUID())
        if record is not None:
            memberRecords = yield record.members()
            for record in memberRecords:
                if record is not None:
                    principal = yield self.pcollection.principalForRecord(
                        record)
                    if principal is not None:
                        if (principal.record.loginAllowed
                                or principal.record.recordType is
                                BaseRecordType.group):
                            memberPrincipals.append(principal)
        returnValue(memberPrincipals)

    def groupMembers(self):
        return self._expandMemberPrincipals()

    @inlineCallbacks
    def expandedGroupMembers(self):
        """
        Return the complete, flattened set of principals belonging to this
        group.
        """
        returnValue((yield self._expandMemberPrincipals(infinity=True)))

    def groupMemberships(self):
        # Unlikely to ever want to put a subprincipal into a group
        return succeed([])

    @inlineCallbacks
    def containsPrincipal(self, principal):
        """
        Uses proxyFor information to turn the "contains principal" question around;
        rather than expanding this principal's groups to see if the other principal
        is a member, ask the other principal if they are a proxy for this principal's
        parent resource, since this principal is a proxy principal.

        @param principal: The principal to check
        @type principal: L{DirectoryCalendarPrincipalResource}
        @return: True if principal is a proxy (of the correct type) of our parent
        @rtype: C{boolean}
        """
        readWrite = self.isProxyType(True)  # is read-write
        if principal and self.parent in (yield principal.proxyFor(readWrite)):
            returnValue(True)
        returnValue(False)
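
The proxy principal UID composed in __init__ above ("<parent UID>#<proxy
type>") is what lets other components map a proxy group back to its parent
principal. A tiny sketch of the round trip (the UID value is illustrative):

def make_proxy_uid(parent_uid, proxy_type):
    return "%s#%s" % (parent_uid, proxy_type)

def split_proxy_uid(uid):
    parent_uid, _, proxy_type = uid.partition("#")
    return parent_uid, proxy_type

uid = make_proxy_uid("5A985493-EE2C-4665-94CF-4DFEA3A89500",
                     "calendar-proxy-write")
assert split_proxy_uid(uid) == (
    "5A985493-EE2C-4665-94CF-4DFEA3A89500", "calendar-proxy-write")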
Example #7
class ProxyDB(AbstractADBAPIDatabase):
    """
    A database to maintain calendar user proxy group memberships.

    SCHEMA:

    Group Database:

    ROW: GROUPNAME, MEMBER
    """
    log = Logger()

    schema_version = "4"
    schema_type = "CALENDARUSERPROXY"

    class ProxyDBMemcacher(Memcacher):
        def __init__(self, namespace):
            super(ProxyDB.ProxyDBMemcacher, self).__init__(
                namespace,
                key_normalization=config.Memcached.ProxyDBKeyNormalization)

        def setMembers(self, guid, members):
            return self.set("members:%s" % (str(guid), ),
                            str(",".join(members)))

        def setMemberships(self, guid, memberships):
            return self.set("memberships:%s" % (str(guid), ),
                            str(",".join(memberships)))

        def getMembers(self, guid):
            def _value(value):
                if value:
                    return set(value.split(","))
                elif value is None:
                    return None
                else:
                    return set()

            d = self.get("members:%s" % (str(guid), ))
            d.addCallback(_value)
            return d

        def getMemberships(self, guid):
            def _value(value):
                if value:
                    return set(value.split(","))
                elif value is None:
                    return None
                else:
                    return set()

            d = self.get("memberships:%s" % (str(guid), ))
            d.addCallback(_value)
            return d

        def deleteMember(self, guid):
            return self.delete("members:%s" % (str(guid), ))

        def deleteMembership(self, guid):
            return self.delete("memberships:%s" % (str(guid), ))

    def __init__(self, dbID, dbapiName, dbapiArgs, **kwargs):
        AbstractADBAPIDatabase.__init__(self, dbID, dbapiName, dbapiArgs, True,
                                        **kwargs)

        self._memcacher = ProxyDB.ProxyDBMemcacher("ProxyDB")

    @inlineCallbacks
    def setGroupMembers(self, principalUID, members):
        """
        Add a group membership record.

        @param principalUID: the UID of the group principal to add.
        @param members: a list UIDs of principals that are members of this group.
        """

        # Get current members before we change them
        current_members = yield self.getMembers(principalUID)
        if current_members is None:
            current_members = ()
        current_members = set(current_members)

        # Find changes
        update_members = set(members)
        remove_members = current_members.difference(update_members)
        add_members = update_members.difference(current_members)

        yield self.changeGroupMembersInDatabase(principalUID, add_members,
                                                remove_members)

        # Update cache
        for member in itertools.chain(
                remove_members,
                add_members,
        ):
            yield self._memcacher.deleteMembership(member)
        yield self._memcacher.deleteMember(principalUID)

    @inlineCallbacks
    def setGroupMembersInDatabase(self, principalUID, members):
        """
        A blocking call to add a group membership record in the database.

        @param principalUID: the UID of the group principal to add.
        @param members: a list UIDs of principals that are members of this group.
        """
        # Remove what is there, then add it back.
        yield self._delete_from_db(principalUID)
        yield self._add_to_db(principalUID, members)

    @inlineCallbacks
    def changeGroupMembersInDatabase(self, principalUID, addMembers,
                                     removeMembers):
        """
        A blocking call to add a group membership record in the database.

        @param principalUID: the UID of the group principal to add.
        @param addMembers: a list UIDs of principals to be added as members of this group.
        @param removeMembers: a list UIDs of principals to be removed as members of this group.
        """
        # Remove what is there, then add it back.
        for member in removeMembers:
            yield self._delete_from_db_one(principalUID, member)
        for member in addMembers:
            yield self._add_to_db_one(principalUID, member)

    @inlineCallbacks
    def removeGroup(self, principalUID):
        """
        Remove a group membership record.

        @param principalUID: the UID of the group principal to remove.
        """

        # Need to get the members before we do the delete
        members = yield self.getMembers(principalUID)

        yield self._delete_from_db(principalUID)

        # Update cache
        if members:
            for member in members:
                yield self._memcacher.deleteMembership(member)
            yield self._memcacher.deleteMember(principalUID)

    def getMembers(self, principalUID):
        """
        Return the list of group member UIDs for the specified principal.

        @return: a deferred returning a C{set} of members.
        """
        def gotCachedMembers(members):
            if members is not None:
                return members

            # Cache miss; compute members and update cache
            def gotMembersFromDB(dbmembers):
                members = set([row[0].encode("utf-8") for row in dbmembers])
                d = self._memcacher.setMembers(principalUID, members)
                d.addCallback(lambda _: members)
                return d

            d = self.query("select MEMBER from GROUPS where GROUPNAME = :1",
                           (principalUID.decode("utf-8"), ))
            d.addCallback(gotMembersFromDB)
            return d

        d = self._memcacher.getMembers(principalUID)
        d.addCallback(gotCachedMembers)
        return d

    def getMemberships(self, principalUID):
        """
        Return the list of group principal UIDs the specified principal is a member of.

        @return: a deferred returning a C{set} of memberships.
        """
        def gotCachedMemberships(memberships):
            if memberships is not None:
                return memberships

            # Cache miss; compute memberships and update cache
            def gotMembershipsFromDB(dbmemberships):
                memberships = set(
                    [row[0].encode("utf-8") for row in dbmemberships])
                d = self._memcacher.setMemberships(principalUID, memberships)
                d.addCallback(lambda _: memberships)
                return d

            d = self.query("select GROUPNAME from GROUPS where MEMBER = :1",
                           (principalUID.decode("utf-8"), ))
            d.addCallback(gotMembershipsFromDB)
            return d

        d = self._memcacher.getMemberships(principalUID)
        d.addCallback(gotCachedMemberships)
        return d

    @inlineCallbacks
    def _add_to_db(self, principalUID, members):
        """
        Insert the specified entry into the database.

        @param principalUID: the UID of the group principal to add.
        @param members: a list of UIDs or principals that are members of this group.
        """
        for member in members:
            yield self.execute(
                """
                insert into GROUPS (GROUPNAME, MEMBER)
                values (:1, :2)
                """, (
                    principalUID.decode("utf-8"),
                    member,
                ))

    def _add_to_db_one(self, principalUID, memberUID):
        """
        Insert the specified entry into the database.

        @param principalUID: the UID of the group principal to add.
        @param memberUID: the UID of the principal that is being added as a member of this group.
        """
        return self.execute(
            """
            insert into GROUPS (GROUPNAME, MEMBER)
            values (:1, :2)
            """, (
                principalUID.decode("utf-8"),
                memberUID.decode("utf-8"),
            ))

    def _delete_from_db(self, principalUID):
        """
        Deletes the specified entry from the database.

        @param principalUID: the UID of the group principal to remove.
        """
        return self.execute("delete from GROUPS where GROUPNAME = :1",
                            (principalUID.decode("utf-8"), ))

    def _delete_from_db_one(self, principalUID, memberUID):
        """
        Deletes the specified entry from the database.

        @param principalUID: the UID of the group principal to remove.
        @param memberUID: the UID of the principal that is being removed as a member of this group.
        """
        return self.execute(
            "delete from GROUPS where GROUPNAME = :1 and MEMBER = :2", (
                principalUID.decode("utf-8"),
                memberUID.decode("utf-8"),
            ))

    def _delete_from_db_member(self, principalUID):
        """
        Deletes the specified member entry from the database.

        @param principalUID: the UID of the member principal to remove.
        """
        return self.execute("delete from GROUPS where MEMBER = :1",
                            (principalUID.decode("utf-8"), ))

    def _db_version(self):
        """
        @return: the schema version assigned to this index.
        """
        return ProxyDB.schema_version

    def _db_type(self):
        """
        @return: the collection type assigned to this index.
        """
        return ProxyDB.schema_type

    @inlineCallbacks
    def _db_init_data_tables(self):
        """
        Initialise the underlying database tables.
        """

        #
        # GROUPS table
        #
        yield self._create_table(
            "GROUPS",
            (
                ("GROUPNAME", "text"),
                ("MEMBER", "text"),
            ),
            ifnotexists=True,
        )

        yield self._create_index(
            "GROUPNAMES",
            "GROUPS",
            ("GROUPNAME", ),
            ifnotexists=True,
        )
        yield self._create_index(
            "MEMBERS",
            "GROUPS",
            ("MEMBER", ),
            ifnotexists=True,
        )

    @inlineCallbacks
    def open(self):
        """
        Open the database, normalizing all UUIDs in the process if necessary.
        """
        result = yield super(ProxyDB, self).open()
        yield self._maybeNormalizeUUIDs()
        returnValue(result)

    @inlineCallbacks
    def _maybeNormalizeUUIDs(self):
        """
        Normalize the UUIDs in the proxy database so they correspond to the
        normalized UUIDs in the main calendar database.
        """
        alreadyDone = yield self._db_value_for_sql(
            "select VALUE from CALDAV where KEY = 'UUIDS_NORMALIZED'")
        if alreadyDone is None:
            for (groupname, member) in ((yield self._db_all_values_for_sql(
                    "select GROUPNAME, MEMBER from GROUPS"))):
                grouplist = groupname.split("#")
                grouplist[0] = normalizeUUID(grouplist[0])
                newGroupName = "#".join(grouplist)
                newMemberName = normalizeUUID(member)
                if newGroupName != groupname or newMemberName != member:
                    yield self._db_execute(
                        """
                        update GROUPS set GROUPNAME = :1, MEMBER = :2
                        where GROUPNAME = :3 and MEMBER = :4
                    """, [newGroupName, newMemberName, groupname, member])
            yield self._db_execute("""
                insert or ignore into CALDAV (KEY, VALUE)
                values ('UUIDS_NORMALIZED', 'YES')
                """)

    @inlineCallbacks
    def _db_upgrade_data_tables(self, old_version):
        """
        Upgrade the data from an older version of the DB.
        @param old_version: existing DB's version number
        @type old_version: str
        """

        # Add index if old version is less than "4"
        if int(old_version) < 4:
            yield self._create_index(
                "GROUPNAMES",
                "GROUPS",
                ("GROUPNAME", ),
                ifnotexists=True,
            )
            yield self._create_index(
                "MEMBERS",
                "GROUPS",
                ("MEMBER", ),
                ifnotexists=True,
            )

    def _db_empty_data_tables(self):
        """
        Empty the underlying database tables.
        """

        #
        # GROUPS table
        #
        return self._db_execute("delete from GROUPS")

    @inlineCallbacks
    def clean(self):

        if not self.initialized:
            yield self.open()

        for group in [
                row[0]
                for row in (yield self.query("select GROUPNAME from GROUPS"))
        ]:
            yield self.removeGroup(group)

        yield super(ProxyDB, self).clean()

    @inlineCallbacks
    def getAllMembers(self):
        """
        Retrieve all members that have been directly delegated to
        """
        returnValue([
            row[0]
            for row in (yield self.query("select DISTINCT MEMBER from GROUPS"))
        ])
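
getMembers and getMemberships above both follow a cache-aside pattern: consult
memcached first, fall back to the SQL table on a miss, then repopulate the
cache so the next read is cheap. A synchronous sketch of that flow (the dict
and query callable stand in for the memcacher and database):

cache = {}

def get_members(uid, query_db):
    members = cache.get(uid)          # None signals a cache miss
    if members is None:
        members = set(query_db(uid))  # authoritative answer from SQL
        cache[uid] = members          # repopulate for subsequent reads
    return members

print(get_members("u1", lambda uid: ["a", "b"]))  # computed from the "DB"
print(get_members("u1", lambda uid: []))          # served from the cache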
Example #8
class LinkResource(CalDAVComplianceMixIn, WrapperResource):
    """
    This is similar to a WrapperResource except that we locate our resource dynamically. We need to deal with the
    case of a missing underlying resource (broken link) as indicated by self._linkedResource being None.
    """
    log = Logger()

    def __init__(self, parent, link_url):
        self.parent = parent
        self.linkURL = link_url
        self.loopDetect = set()
        super(LinkResource, self).__init__(self.parent.principalCollections())

    @inlineCallbacks
    def linkedResource(self, request):

        if not hasattr(self, "_linkedResource"):
            if self.linkURL in self.loopDetect:
                raise HTTPError(
                    StatusResponse(
                        responsecode.LOOP_DETECTED,
                        "Recursive link target: %s" % (self.linkURL, )))
            else:
                self.loopDetect.add(self.linkURL)
            self._linkedResource = (yield request.locateResource(self.linkURL))
            self.loopDetect.remove(self.linkURL)

        if self._linkedResource is None:
            raise HTTPError(
                StatusResponse(responsecode.NOT_FOUND,
                               "Missing link target: %s" % (self.linkURL, )))

        returnValue(self._linkedResource)

    def isCollection(self):
        return hasattr(self, "_linkedResource")

    def resourceType(self):
        return self._linkedResource.resourceType() if hasattr(
            self, "_linkedResource") else davxml.ResourceType.link

    def locateChild(self, request, segments):
        def _defer(result):
            if result is None:
                return (self, server.StopTraversal)
            else:
                return (result, segments)

        d = self.linkedResource(request)
        d.addCallback(_defer)
        return d

    @inlineCallbacks
    def renderHTTP(self, request):
        linked_to = (yield self.linkedResource(request))
        if linked_to:
            returnValue(linked_to)
        else:
            returnValue(
                http.StatusResponse(
                    responsecode.OK, "Link resource with missing target: %s" %
                    (self.linkURL, )))

    def getChild(self, name):
        return self._linkedResource.getChild(name) if hasattr(
            self, "_linkedResource") else None

    @inlineCallbacks
    def hasProperty(self, property, request):
        hosted = (yield self.linkedResource(request))
        result = (yield hosted.hasProperty(property,
                                           request)) if hosted else False
        returnValue(result)

    @inlineCallbacks
    def readProperty(self, property, request):
        hosted = (yield self.linkedResource(request))
        result = (yield hosted.readProperty(property,
                                            request)) if hosted else None
        returnValue(result)

    @inlineCallbacks
    def writeProperty(self, property, request):
        hosted = (yield self.linkedResource(request))
        result = (yield hosted.writeProperty(property,
                                             request)) if hosted else None
        returnValue(result)
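A hedged sketch of how a LinkResource is typically wired up; the parent collection and target URL here are illustrative assumptions:

def makeCalendarLink(parent):
    # The link resolves lazily: locateResource() is only called the first
    # time linkedResource() runs during a request.
    link = LinkResource(parent, "/calendars/users/alice/calendar/")
    # A target that cannot be located renders NOT_FOUND, and a link that
    # resolves back to itself (directly or through a chain of links)
    # raises LOOP_DETECTED via the loopDetect set.
    return link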
Example #9
class AbstractPropertyStore(DictMixin, object):
    """
    Base property store.
    """
    log = Logger()

    implements(IPropertyStore)

    _defaultShadowableKeys = frozenset()
    _defaultProxyOverrideKeys = frozenset()
    _defaultGlobalKeys = frozenset((
        PropertyName.fromElement(davxml.ACL),
        PropertyName.fromElement(davxml.ResourceID),
        PropertyName.fromElement(davxml.ResourceType),
        PropertyName.fromElement(davxml.GETContentType),
        PropertyName.fromElement(TwistedGETContentMD5),
        PropertyName.fromElement(TwistedQuotaRootProperty),
    ))

    def __init__(self, defaultUser, shareeUser=None, proxyUser=None):
        """
        Instantiate the property store for a user. The default user is the
        owner whose properties are read in the case of global or shadowable
        properties. The sharee user is the user whose per-user properties
        are read.

        @param defaultUser: the default user uid
        @type defaultUser: C{str}

        @param shareeUser: the per user uid or None if the same as defaultUser
        @type shareeUser: C{str}

        @param proxyUser: the proxy uid or None if no proxy
        @type proxyUser: C{str}
        """

        assert (defaultUser is not None or shareeUser is not None)
        self._defaultUser = shareeUser if defaultUser is None else defaultUser
        self._perUser = defaultUser if shareeUser is None else shareeUser
        self._proxyUser = self._perUser if proxyUser is None else proxyUser
        self._shadowableKeys = set(
            AbstractPropertyStore._defaultShadowableKeys)
        self._proxyOverrideKeys = set(
            AbstractPropertyStore._defaultProxyOverrideKeys)
        self._globalKeys = set(AbstractPropertyStore._defaultGlobalKeys)

    def __str__(self):
        return "<%s>" % (self.__class__.__name__)

    def _setDefaultUserUID(self, uid):
        self._defaultUser = uid

    def _setPerUserUID(self, uid):
        self._perUser = uid

    def _setProxyUID(self, uid):
        self._proxyUser = uid

    def setSpecialProperties(self, shadowableKeys, globalKeys,
                             proxyOverrideKeys):
        self._shadowableKeys.update(shadowableKeys)
        self._proxyOverrideKeys.update(proxyOverrideKeys)
        self._globalKeys.update(globalKeys)

    #
    # Subclasses must override these
    #

    def _getitem_uid(self, key, uid):
        raise NotImplementedError()

    def _setitem_uid(self, key, value, uid):
        raise NotImplementedError()

    def _delitem_uid(self, key, uid):
        raise NotImplementedError()

    def _keys_uid(self, uid):
        raise NotImplementedError()

    def _removeResource(self):
        raise NotImplementedError()

    def flush(self):
        raise NotImplementedError()

    def abort(self):
        raise NotImplementedError()

    #
    # Required UserDict implementations
    #

    def __getitem__(self, key):
        # Return proxy value if it exists, else fall through to normal logic
        if self._proxyUser != self._perUser and self.isProxyOverrideProperty(
                key):
            try:
                return self._getitem_uid(key, self._proxyUser)
            except KeyError:
                pass

        # Handle per-user behavior
        if self.isShadowableProperty(key):
            try:
                result = self._getitem_uid(key, self._perUser)
            except KeyError:
                result = self._getitem_uid(key, self._defaultUser)
            return result
        elif self.isGlobalProperty(key):
            return self._getitem_uid(key, self._defaultUser)
        else:
            return self._getitem_uid(key, self._perUser)

    def __setitem__(self, key, value):
        # Handle per-user behavior
        if self.isGlobalProperty(key):
            return self._setitem_uid(key, value, self._defaultUser)
        # Handle proxy behavior
        elif self._proxyUser != self._perUser and self.isProxyOverrideProperty(
                key):
            return self._setitem_uid(key, value, self._proxyUser)
        # Remainder is per user
        else:
            return self._setitem_uid(key, value, self._perUser)

    def __delitem__(self, key):
        # Delete proxy value if it exists, else fall through to normal logic
        if self._proxyUser != self._perUser and self.isProxyOverrideProperty(
                key):
            try:
                self._delitem_uid(key, self._proxyUser)
                return
            except KeyError:
                pass

        # Handle per-user behavior
        if self.isShadowableProperty(key):
            try:
                self._delitem_uid(key, self._perUser)
            except KeyError:
                # It is OK for shadowable delete to fail
                pass
        elif self.isGlobalProperty(key):
            self._delitem_uid(key, self._defaultUser)
        else:
            self._delitem_uid(key, self._perUser)

    def keys(self):

        userkeys = list(self._keys_uid(self._perUser))
        if self._defaultUser != self._perUser:
            defaultkeys = self._keys_uid(self._defaultUser)
            for key in defaultkeys:
                if self.isShadowableProperty(key) and key not in userkeys:
                    userkeys.append(key)
        return tuple(userkeys)

    def update(self, other):
        # FIXME: direct tests.
        # FIXME: support positional signature (although since strings aren't
        # valid, it should just raise an error).
        for key in other:
            self[key] = other[key]

    def isShadowableProperty(self, key):
        return key in self._shadowableKeys

    def isProxyOverrideProperty(self, key):
        return key in self._proxyOverrideKeys

    def isGlobalProperty(self, key):
        return key in self._globalKeys

    def copyAllProperties(self, other):
        """
        Copy all the properties from another store into this one. This needs to be done
        independently of the UID. Each underlying store will need to implement this.
        """
        pass
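A minimal in-memory subclass, shown only to illustrate the per-UID hooks the base class requires; the name InMemoryPropertyStore and its dict backing are assumptions for this sketch, not one of the real stores:

class InMemoryPropertyStore(AbstractPropertyStore):

    def __init__(self, defaultUser, shareeUser=None, proxyUser=None):
        super(InMemoryPropertyStore, self).__init__(
            defaultUser, shareeUser, proxyUser)
        self._data = {}  # maps (uid, PropertyName) -> property value

    def _getitem_uid(self, key, uid):
        # Raises KeyError when absent, which the base class relies on to
        # fall back from per-user to default-user values.
        return self._data[(uid, key)]

    def _setitem_uid(self, key, value, uid):
        self._data[(uid, key)] = value

    def _delitem_uid(self, key, uid):
        del self._data[(uid, key)]

    def _keys_uid(self, uid):
        return [k for (u, k) in self._data if u == uid]

    def flush(self):
        pass  # nothing buffered in memory

    def abort(self):
        self._data.clear()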
class MailGatewayTokensDatabase(AbstractSQLDatabase):
    """
    A database to maintain "plus-address" tokens for IMIP requests.

    SCHEMA:

    Token Database:

    ROW: TOKEN, ORGANIZER, ATTENDEE, ICALUID, DATESTAMP
    """
    log = Logger()

    dbType = "MAILGATEWAYTOKENS"
    dbFilename = "mailgatewaytokens.sqlite"
    dbFormatVersion = "1"

    def __init__(self, path):
        if path != ":memory:":
            path = os.path.join(path, MailGatewayTokensDatabase.dbFilename)
        super(MailGatewayTokensDatabase, self).__init__(path, True)

    def createToken(self, organizer, attendee, icaluid, token=None):
        if token is None:
            token = str(uuid.uuid4())
        self._db_execute(
            """
            insert into TOKENS (TOKEN, ORGANIZER, ATTENDEE, ICALUID, DATESTAMP)
            values (:1, :2, :3, :4, :5)
            """, token, organizer, attendee, icaluid, datetime.date.today())
        self._db_commit()
        return token

    def lookupByToken(self, token):
        results = list(
            self._db_execute(
                """
                select ORGANIZER, ATTENDEE, ICALUID from TOKENS
                where TOKEN = :1
                """, token))

        if len(results) != 1:
            return None

        return results[0]

    def getToken(self, organizer, attendee, icaluid):
        token = self._db_value_for_sql(
            """
            select TOKEN from TOKENS
            where ORGANIZER = :1 and ATTENDEE = :2 and ICALUID = :3
            """, organizer, attendee, icaluid)
        if token is not None:
            # update the datestamp on the token to keep it from being purged
            self._db_execute(
                """
                update TOKENS set DATESTAMP = :1 WHERE TOKEN = :2
                """, datetime.date.today(), token)
            return str(token)
        else:
            return None

    def getAllTokens(self):
        results = list(
            self._db_execute("""
                select TOKEN, ORGANIZER, ATTENDEE, ICALUID from TOKENS
                """))
        return results

    def deleteToken(self, token):
        self._db_execute(
            """
            delete from TOKENS where TOKEN = :1
            """, token)
        self._db_commit()

    def purgeOldTokens(self, before):
        self._db_execute(
            """
            delete from TOKENS where DATESTAMP < :1
            """, before)
        self._db_commit()

    def lowercase(self):
        """
        Lowercase mailto: addresses (and uppercase urn:uuid: addresses!) so
        they can be located via normalized names.
        """
        rows = self._db_execute("""
            select ORGANIZER, ATTENDEE from TOKENS
            """)
        for row in rows:
            organizer = row[0]
            attendee = row[1]
            if organizer.lower().startswith("mailto:"):
                self._db_execute(
                    """
                    update TOKENS set ORGANIZER = :1 WHERE ORGANIZER = :2
                    """, organizer.lower(), organizer)
            else:
                from txdav.base.datastore.util import normalizeUUIDOrNot
                self._db_execute(
                    """
                    update TOKENS set ORGANIZER = :1 WHERE ORGANIZER = :2
                    """, normalizeUUIDOrNot(organizer), organizer)
            # ATTENDEEs are always mailto: so unconditionally lower().
            self._db_execute(
                """
                update TOKENS set ATTENDEE = :1 WHERE ATTENDEE = :2
                """, attendee.lower(), attendee)
        self._db_commit()

    def _db_version(self):
        """
        @return: the schema version assigned to this index.
        """
        return MailGatewayTokensDatabase.dbFormatVersion

    def _db_type(self):
        """
        @return: the collection type assigned to this index.
        """
        return MailGatewayTokensDatabase.dbType

    def _db_init_data_tables(self, q):
        """
        Initialise the underlying database tables.
        @param q:           a database cursor to use.
        """

        #
        # TOKENS table
        #
        q.execute("""
            create table TOKENS (
                TOKEN       text,
                ORGANIZER   text,
                ATTENDEE    text,
                ICALUID     text,
                DATESTAMP   date
            )
            """)
        q.execute("""
            create index TOKENSINDEX on TOKENS (TOKEN)
            """)

    def _db_upgrade_data_tables(self, q, old_version):
        """
        Upgrade the data from an older version of the DB.
        @param q: a database cursor to use.
        @param old_version: existing DB's version number
        @type old_version: str
        """
        pass
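A short usage sketch for the token database above; the organizer, attendee, and iCalendar UID values are illustrative assumptions:

db = MailGatewayTokensDatabase(":memory:")
token = db.createToken(
    "urn:uuid:E2F45782-CA6C-4F5B-B130-655A59FE1B49",
    "mailto:attendee@example.com",
    "event-uid-1",
)
# lookupByToken() returns the (ORGANIZER, ATTENDEE, ICALUID) row, or None
# if the token is unknown (or, defensively, duplicated).
print db.lookupByToken(token)
# getToken() also refreshes DATESTAMP, protecting active tokens from
# purgeOldTokens().
assert db.getToken(
    "urn:uuid:E2F45782-CA6C-4F5B-B130-655A59FE1B49",
    "mailto:attendee@example.com",
    "event-uid-1",
) == token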
Example #11
class MemcachedUIDReserver(CachePoolUserMixIn):
    log = Logger()

    def __init__(self, index, cachePool=None):
        self.index = index
        self._cachePool = cachePool

    def _key(self, uid):
        return 'reservation:%s' % (hashlib.md5(
            '%s:%s' % (uid, self.index.resource.fp.path)).hexdigest())

    def reserveUID(self, uid):
        uid = uid.encode('utf-8')
        self.log.debug(
            "Reserving UID {uid} @ {path}",
            uid=uid,
            path=self.index.resource.fp.path,
        )

        def _handleFalse(result):
            if result is False:
                raise ReservationError(
                    "UID %s already reserved for calendar collection %s." %
                    (uid, self.index.resource))

        d = self.getCachePool().add(self._key(uid),
                                    'reserved',
                                    expireTime=config.UIDReservationTimeOut)
        d.addCallback(_handleFalse)
        return d

    def unreserveUID(self, uid):
        uid = uid.encode('utf-8')
        self.log.debug(
            "Unreserving UID {uid} @ {path}",
            uid=uid,
            path=self.index.resource.fp.path,
        )

        def _handleFalse(result):
            if result is False:
                raise ReservationError(
                    "UID %s is not reserved for calendar collection %s." %
                    (uid, self.index.resource))

        d = self.getCachePool().delete(self._key(uid))
        d.addCallback(_handleFalse)
        return d

    def isReservedUID(self, uid):
        uid = uid.encode('utf-8')
        self.log.debug(
            "Is reserved UID {uid} @ {path}",
            uid=uid,
            path=self.index.resource.fp.path,
        )

        def _checkValue((flags, value)):
            # memcached get() fires with a (flags, value) tuple; a missing
            # key gives a None value.
            return value is not None

        d = self.getCachePool().get(self._key(uid))
        d.addCallback(_checkValue)
        return d
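A hedged sketch of the reservation round-trip; index and cachePool are assumptions standing in for the collection's index and the configured memcache pool:

def reserveRoundTrip(index, cachePool):
    reserver = MemcachedUIDReserver(index, cachePool=cachePool)
    # memcached add() only succeeds for a new key, so a second
    # reserveUID() for the same UID errbacks with ReservationError.
    d = reserver.reserveUID(u"12345-67890-ABCDE")
    d.addCallback(lambda _: reserver.isReservedUID(u"12345-67890-ABCDE"))
    d.addCallback(lambda reserved: reserver.unreserveUID(u"12345-67890-ABCDE")
                  if reserved else None)
    return d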
Example #12
class AbstractCalendarIndex(AbstractSQLDatabase):
    """
    Calendar collection index abstract base class that defines the APIs for the index.
    This will be subclassed for the two types of index behaviour we need: one for
    regular calendar collections, one for schedule calendar collections.
    """
    log = Logger()

    def __init__(self, resource):
        """
        @param resource: the L{CalDAVResource} resource to
            index. C{resource} must be a calendar collection (i.e.,
            C{resource.isPseudoCalendarCollection()} returns C{True}.)
        """
        self.resource = resource
        db_filename = self.resource.fp.child(db_basename).path
        super(AbstractCalendarIndex, self).__init__(db_filename, False)

        self.resource._txn.postCommit(self._db_close)
        self.resource._txn.postAbort(self._db_close)

    def create(self):
        """
        Create the index and initialize it.
        """
        self._db()

    def reserveUID(self, uid):
        """
        Reserve a UID for this index's resource.
        @param uid: the UID to reserve
        @raise ReservationError: if C{uid} is already reserved
        """
        raise NotImplementedError

    def unreserveUID(self, uid):
        """
        Unreserve a UID for this index's resource.
        @param uid: the UID to reserve
        @raise ReservationError: if C{uid} is not reserved
        """
        raise NotImplementedError

    def isReservedUID(self, uid):
        """
        Check to see whether a UID is reserved.
        @param uid: the UID to check
        @return: True if C{uid} is reserved, False otherwise.
        """
        raise NotImplementedError

    def isAllowedUID(self, uid, *names):
        """
        Checks whether an operation that adds the specified UID to the index
        is allowed. Specifically, the operation may not violate the constraint
        that UIDs must be unique, and the UID must not be reserved.
        @param uid: the UID to check
        @param names: the names of resources being replaced or deleted by the
            operation; UIDs associated with these resources are not checked.
        @return: True if the UID is not in the index and is not reserved,
            False otherwise.
        """
        raise NotImplementedError

    def resourceNamesForUID(self, uid):
        """
        Looks up the names of the resources with the given UID.
        @param uid: the UID of the resources to look up.
        @return: a list of resource names
        """
        names = self._db_values_for_sql(
            "select NAME from RESOURCE where UID = :1", uid)

        #
        # Check that each name exists as a child of self.resource.  If not, the
        # resource record is stale.
        #
        resources = []
        for name in names:
            name_utf8 = name.encode("utf-8")
            if self.resource.getChild(name_utf8) is None:
                # Clean up
                log.error(
                    "Stale resource record found for child {name} with UID {uid} in {rsrc!r}",
                    name=name,
                    uid=uid,
                    rsrc=self.resource)
                self._delete_from_db(name, uid, False)
                self._db_commit()
            else:
                resources.append(name_utf8)

        return resources

    def resourceNameForUID(self, uid):
        """
        Looks up the name of the resource with the given UID.
        @param uid: the UID of the resource to look up.
        @return: If the resource is found, its name; C{None} otherwise.
        """
        result = None

        for name in self.resourceNamesForUID(uid):
            assert result is None, "More than one resource with UID %s in calendar collection %r" % (
                uid, self)
            result = name

        return result

    def resourceUIDForName(self, name):
        """
        Looks up the UID of the resource with the given name.
        @param name: the name of the resource to look up.
        @return: If the resource is found, the UID of the resource; C{None}
            otherwise.
        """
        uid = self._db_value_for_sql(
            "select UID from RESOURCE where NAME = :1", name)

        return uid

    def componentTypeCounts(self):
        """
        Count each type of component.
        """
        return self._db_execute(
            "select TYPE, COUNT(TYPE) from RESOURCE group by TYPE")

    def addResource(self, name, calendar, fast=False, reCreate=False):
        """
        Add a new resource or update an existing one.
        To check for an update we attempt to get an existing UID
        for the resource name. If present, the index entries for
        that UID are removed. After that the new index entries are added.
        @param name: the name of the resource to add.
        @param calendar: a L{Calendar} object representing the resource
            contents.
        @param fast: if C{True}, do not commit; otherwise commit.
        """
        oldUID = self.resourceUIDForName(name)
        if oldUID is not None:
            self._delete_from_db(name, oldUID, False)
        self._add_to_db(name, calendar, reCreate=reCreate)
        if not fast:
            self._db_commit()

    def deleteResource(self, name):
        """
        Remove this resource from the index.
        @param name: the name of the resource to delete.
        """
        uid = self.resourceUIDForName(name)
        if uid is not None:
            self._delete_from_db(name, uid)
            self._db_commit()

    def resourceExists(self, name):
        """
        Determines whether the specified resource name exists in the index.
        @param name: the name of the resource to test
        @return: True if the resource exists, False if not
        """
        uid = self._db_value_for_sql(
            "select UID from RESOURCE where NAME = :1", name)
        return uid is not None

    def resourcesExist(self, names):
        """
        Determines which of the specified resource names exist in the index.
        @param names: a C{list} containing the names of the resources to test
        @return: a C{list} of all names that exist
        """
        statement = "select NAME from RESOURCE where NAME in ("
        for ctr in range(len(names)):
            if ctr != 0:
                statement += ", "
            statement += ":%s" % (ctr, )
        statement += ")"
        results = self._db_values_for_sql(statement, *names)
        return results

    def testAndUpdateIndex(self, minDate):
        # Find out if the index is expanded far enough
        names = self.notExpandedBeyond(minDate)
        # Actually expand recurrence max
        for name in names:
            self.log.info(
                "Search falls outside range of index for {name} {date}",
                name=name,
                date=minDate)
            self.reExpandResource(name, minDate)

    def whatchanged(self, revision):

        results = [
            (name.encode("utf-8"), deleted)
            for name, deleted in self._db_execute(
                "select NAME, DELETED from REVISIONS where REVISION > :1",
                revision)
        ]
        results.sort(key=lambda x: x[1])

        changed = []
        deleted = []
        invalid = []
        for name, wasdeleted in results:
            if name:
                if wasdeleted == 'Y':
                    if revision:
                        deleted.append(name)
                else:
                    changed.append(name)
            else:
                raise SyncTokenValidException

        return (changed, deleted, invalid)

    def lastRevision(self):
        return self._db_value_for_sql("select REVISION from REVISION_SEQUENCE")

    def bumpRevision(self, fast=False):
        self._db_execute(
            """
            update REVISION_SEQUENCE set REVISION = REVISION + 1
            """, )
        self._db_commit()
        return self._db_value_for_sql(
            """
            select REVISION from REVISION_SEQUENCE
            """, )

    def indexedSearch(self, filter, useruid="", fbtype=False):
        """
        Finds resources matching the given qualifiers.
        @param filter: the L{Filter} for the calendar-query to execute.
        @return: an iterable of tuples for each resource matching the
            given C{qualifiers}. The tuples are C{(name, uid, type)}, where
            C{name} is the resource name, C{uid} is the resource UID, and
            C{type} is the resource iCalendar component type.
        """

        # Make sure we have a proper Filter element and get the partial SQL
        # statement to use.
        if isinstance(filter, Filter):
            if fbtype:
                # Lookup the useruid - try the empty (default) one if needed
                dbuseruid = self._db_value_for_sql(
                    "select PERUSERID from PERUSER where USERUID == :1",
                    useruid,
                )
            else:
                dbuseruid = ""

            qualifiers = sqlcalendarquery(filter, None, dbuseruid, fbtype)
            if qualifiers is not None:
                # Determine how far we need to extend the current expansion of
                # events. If we have an open-ended time-range we will expand one
                # year past the start. That should catch bounded recurrences - unbounded
                # will have been indexed with an "infinite" value always included.
                maxDate, isStartDate = filter.getmaxtimerange()
                if maxDate:
                    maxDate = maxDate.duplicate()
                    maxDate.setDateOnly(True)
                    if isStartDate:
                        maxDate += Duration(days=365)
                    self.testAndUpdateIndex(maxDate)
            else:
                # We cannot handle this filter in an indexed search
                raise IndexedSearchException()

        else:
            qualifiers = None

        # Perform the search
        if qualifiers is None:
            rowiter = self._db_execute("select NAME, UID, TYPE from RESOURCE")
        else:
            if fbtype:
                # For a free-busy time-range query we return all instances
                rowiter = self._db_execute(
                    "select DISTINCT RESOURCE.NAME, RESOURCE.UID, RESOURCE.TYPE, RESOURCE.ORGANIZER, TIMESPAN.FLOAT, TIMESPAN.START, TIMESPAN.END, TIMESPAN.FBTYPE, TIMESPAN.TRANSPARENT, TRANSPARENCY.TRANSPARENT"
                    + qualifiers[0], *qualifiers[1])
            else:
                rowiter = self._db_execute(
                    "select DISTINCT RESOURCE.NAME, RESOURCE.UID, RESOURCE.TYPE"
                    + qualifiers[0], *qualifiers[1])

        # Check result for missing resources
        results = []
        for row in rowiter:
            name = row[0]
            if self.resource.getChild(name.encode("utf-8")):
                if fbtype:
                    row = list(row)
                    if row[9]:
                        row[8] = row[9]
                    del row[9]
                results.append(row)
            else:
                log.error(
                    "Calendar resource {name} is missing from {rsrc!r}. Removing from index.",
                    name=name,
                    rsrc=self.resource,
                )
                self.deleteResource(name)

        return results

    def bruteForceSearch(self):
        """
        List the whole index and tests for existence, updating the index
        @return: all resources in the index
        """
        # List all resources
        rowiter = self._db_execute("select NAME, UID, TYPE from RESOURCE")

        # Check result for missing resources:

        results = []
        for row in rowiter:
            name = row[0]
            if self.resource.getChild(name.encode("utf-8")):
                results.append(row)
            else:
                log.error(
                    "Calendar resource {name} is missing from {rsrc!r}. Removing from index.",
                    name=name,
                    rsrc=self.resource,
                )

        return results

    def _db_version(self):
        """
        @return: the schema version assigned to this index.
        """
        return schema_version

    def _add_to_db(self,
                   name,
                   calendar,
                   cursor=None,
                   expand_until=None,
                   reCreate=False):
        """
        Records the given calendar resource in the index with the given name.
        Resource names and UIDs must both be unique; only one resource name may
        be associated with any given UID and vice versa.
        NB This method does not commit the changes to the db - the caller
        MUST take care of that
        @param name: the name of the resource to add.
        @param calendar: a L{Calendar} object representing the resource
            contents.
        """
        raise NotImplementedError

    def _delete_from_db(self, name, uid, dorevision=True):
        """
        Deletes the specified entry from all dbs.
        @param name: the name of the resource to delete.
        @param uid: the uid of the resource to delete.
        """
        raise NotImplementedError
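As a footnote to resourcesExist() above, a standalone sketch of the parameterized IN clause it builds; namesInClause is a hypothetical helper written only for illustration:

def namesInClause(names):
    # One numbered placeholder per name, mirroring the loop in
    # resourcesExist().
    placeholders = ", ".join(":%d" % (i,) for i in range(len(names)))
    return "select NAME from RESOURCE where NAME in (%s)" % (placeholders,)

assert namesInClause(["a.ics", "b.ics"]) == \
    "select NAME from RESOURCE where NAME in (:0, :1)"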