예제 #1
0
class TestPsqlConnectionHelper(unittest.TestCase):
    """
    Unit tests for ``PsqlConnectionHelper``.
    """

    def setUp(self):
        # Each test gets its own default-constructed helper instance.
        self.conn_helper = PsqlConnectionHelper()

    def test_batch_execute_on_empty_iterable(self):
        # The execution hook must never be invoked when there is nothing to
        # batch; the hook below fails loudly if it is.
        empty_input = iter(())

        # noinspection PyUnusedLocal
        def never_called_hook(cur, batch):
            raise Exception(
                "This line shouldn't be reached with an empty iterable."
            )

        # batch_execute is lazy, so drain it to force execution.
        runner = self.conn_helper.batch_execute(empty_input,
                                                never_called_hook, 1)
        for _ in runner:
            pass
예제 #2
0
파일: postgres.py 프로젝트: sanyarud/SMQTK
class PostgresKeyValueStore(KeyValueStore):
    """
    PostgreSQL-backed key-value storage.

    Keys and values are serialized with ``pickle`` and stored in ``BYTEA``
    columns. Single-key lookups compare the serialized key bytes for exact
    equality.
    """
    class SqlTemplates(object):
        """
        Container for static PostgreSQL queries used by the containing class.
        """

        UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
            CREATE TABLE IF NOT EXISTS {table_name:s} (
              {key_col:s} BYTEA NOT NULL,
              {value_col:s} BYTEA NOT NULL,
              PRIMARY KEY ({key_col:s})
            );
        """)

        SELECT_TMPL = norm_psql_cmd_string("""
            SELECT {query:s} FROM {table_name:s};
        """)

        # Exact-match lookup of a single serialized key. Equality (not LIKE)
        # is required: pickled keys are arbitrary bytes that may contain the
        # LIKE metacharacters ``%``, ``_`` and ``\``, which would make a
        # pattern match hit unrelated keys.
        SELECT_KEY_EQ_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} = %(key)s
        """)

        # Legacy pattern-match template. Retained so external references keep
        # working; no longer used by this class (see SELECT_KEY_EQ_TMPL).
        SELECT_LIKE_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} LIKE %(key_like)s
        """)

        SELECT_MANY_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} IN %(key_tuple)s
        """)

        UPSERT_TMPL = norm_psql_cmd_string("""
            WITH upsert AS (
              UPDATE {table_name:s}
                SET {value_col:s} = %(val)s
                WHERE {key_col:s} = %(key)s
                RETURNING *
              )
            INSERT INTO {table_name:s}
              ({key_col:s}, {value_col:s})
              SELECT %(key)s, %(val)s
                WHERE NOT EXISTS (SELECT * FROM upsert)
        """)

        # Exact-match delete of a single serialized key (see
        # SELECT_KEY_EQ_TMPL for why equality is used instead of LIKE).
        DELETE_KEY_EQ_TMPL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
             WHERE {key_col:s} = %(key)s
        """)

        # Legacy pattern-match template. Retained so external references keep
        # working; no longer used by this class (see DELETE_KEY_EQ_TMPL).
        DELETE_LIKE_TMPL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
            WHERE {key_col:s} LIKE %(key_like)s 
        """)

        DELETE_ALL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
        """)

    @classmethod
    def is_usable(cls):
        return psycopg2 is not None

    def __init__(self,
                 table_name="data_set",
                 key_col='key',
                 value_col='value',
                 db_name='postgres',
                 db_host=None,
                 db_port=None,
                 db_user=None,
                 db_pass=None,
                 batch_size=1000,
                 pickle_protocol=-1,
                 read_only=False,
                 create_table=True):
        """
        Initialize a PostgreSQL-backed data set instance.

        :param table_name: Name of the table to use.
        :type table_name: str

        :param key_col: Name of the column containing the UUID signatures.
        :type key_col: str

        :param value_col: Name of the table column that will contain
            serialized elements.
        :type value_col: str

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param batch_size: For queries that handle sending or
            receiving many queries at a time, batch queries based on this size.
            If this is None, then no batching occurs.

            The advantage of batching is that it reduces the memory impact for
            queries dealing with a very large number of elements (don't have to
            store the full query for all elements in RAM), but the transaction
            will be some amount slower due to splitting the query into multiple
            transactions.
        :type batch_size: int | None

        :param pickle_protocol: Pickling protocol to use. We will use -1 by
            default (latest version, probably binary). NOTE: this value is
            validated and reported by ``get_config`` but is not currently
            applied when serializing keys or values.
        :type pickle_protocol: int

        :param read_only: Only allow read actions against this index.
            Modification actions will throw a ReadOnlyError exception.
        :type read_only: bool

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it when not set to be
            read-only. If the configured user does not have sufficient
            permissions to create the table and it does not currently exist, an
            exception will be raised.
        :type create_table: bool

        """
        super(PostgresKeyValueStore, self).__init__()

        self._table_name = table_name
        self._key_col = key_col
        self._value_col = value_col

        self._batch_size = batch_size
        self._pickle_protocol = pickle_protocol
        self._read_only = bool(read_only)
        self._create_table = create_table

        # Checking parameters where necessary
        if self._batch_size is not None:
            self._batch_size = int(self._batch_size)
            assert self._batch_size > 0, \
                "A given batch size must be greater than 0 in size " \
                "(given: %d)." % self._batch_size
        assert -1 <= self._pickle_protocol <= 2, \
            ("Given pickle protocol is not in the known valid range [-1, 2]. "
             "Given: %s." % self._pickle_protocol)

        # Helper structure for SQL operations. Hand it the validated
        # (int-converted) batch size rather than the raw argument.
        self._psql_helper = PsqlConnectionHelper(
            db_name,
            db_host,
            db_port,
            db_user,
            db_pass,
            itersize=self._batch_size,
            table_upsert_lock=PSQL_TABLE_CREATE_RLOCK,
        )

        # Only set table upsert if not read-only.
        if not self._read_only and self._create_table:
            # NOT read-only, so allow table upsert.
            self._psql_helper.set_table_upsert_sql(
                self.SqlTemplates.UPSERT_TABLE_TMPL.format(
                    table_name=self._table_name,
                    key_col=self._key_col,
                    value_col=self._value_col))

    @staticmethod
    def _py_to_bin(k):
        """
        Convert a python hashable value into psycopg2.Binary via pickle.

        The default pickle protocol is used (not the instance's configured
        ``pickle_protocol``). Keys are matched against the table by their
        serialized bytes, so the serialization must stay stable.

        :param k: Python object instance to be converted into a
            ``psycopg2.Binary`` instance via ``pickle`` serialization.
        :type k: object

        :return: ``psycopg2.Binary`` buffer instance to use for insertion into
            or query against a table.
        :rtype: psycopg2.Binary

        """
        return psycopg2.Binary(pickle.dumps(k))

    @staticmethod
    def _bin_to_py(b):
        """
        Un-"translate" psycopg2.Binary value (buffer) to a python type.

        SECURITY NOTE: ``pickle.loads`` can execute arbitrary code. Only point
        this store at tables whose contents are trusted.

        :param b: ``psycopg2.Binary`` buffer instance as retrieved from a
            PostgreSQL query.
        :type b: psycopg2.Binary

        :return: Python object instance as loaded via pickle from the given
            ``psycopg2.Binary`` buffer.
        :rtype: object

        """
        return pickle.loads(bytes(b))

    def _execute_batch(self, cursor, sql, arg_list):
        """
        Execute ``sql`` once per argument mapping in ``arg_list`` via
        ``psycopg2.extras.execute_batch``.

        When no batch size is configured (``None``), the ``page_size``
        keyword is omitted so psycopg2's own default applies. psycopg2
        expects an integer page size and cannot take ``None``.

        :param cursor: Open database cursor to execute with.
        :param str sql: Parameterized SQL statement.
        :param arg_list: Sequence of parameter mappings for ``sql``.
        """
        if self._batch_size is None:
            psycopg2.extras.execute_batch(cursor, sql, arg_list)
        else:
            psycopg2.extras.execute_batch(cursor, sql, arg_list,
                                          page_size=self._batch_size)

    def get_config(self):
        """
        Return a JSON-compliant dictionary that could be passed to this class's
        ``from_config`` method to produce an instance with identical
        configuration.

        In the common case, this involves naming the keys of the dictionary
        based on the initialization argument names as if it were to be passed
        to the constructor via dictionary expansion.

        :return: JSON type compliant configuration dictionary.
        :rtype: dict

        """
        return {
            "table_name": self._table_name,
            "key_col": self._key_col,
            "value_col": self._value_col,
            "db_name": self._psql_helper.db_name,
            "db_host": self._psql_helper.db_host,
            "db_port": self._psql_helper.db_port,
            "db_user": self._psql_helper.db_user,
            "db_pass": self._psql_helper.db_pass,
            "batch_size": self._batch_size,
            "pickle_protocol": self._pickle_protocol,
            "read_only": self._read_only,
            "create_table": self._create_table,
        }

    def __repr__(self):
        """
        Return representative string for this class.

        :return: Representative string for this class.
        :rtype: str

        """
        # ``batch_size`` uses ``%s`` because None (no batching) is a valid
        # value and ``%d`` would raise TypeError on it.
        return super(PostgresKeyValueStore, self).__repr__() \
            % ("table_name: %s, key_col: %s, value_col: %s, "
               "db_name: %s, db_host: %s, db_port: %s, db_user: %s, "
               "db_pass: %s, batch_size: %s, pickle_protocol: %d, "
               "read_only: %s, create_table: %s"
               % (self._table_name, self._key_col, self._value_col,
                  self._psql_helper.db_name, self._psql_helper.db_host,
                  self._psql_helper.db_port, self._psql_helper.db_user,
                  self._psql_helper.db_pass, self._batch_size,
                  self._pickle_protocol, self._read_only, self._create_table))

    def count(self):
        """
        :return: The number of key-value relationships in this store.
        :rtype: int | long
        """
        def cb(cur):
            cur.execute(
                self.SqlTemplates.SELECT_TMPL.format(
                    query='count(%s)' % self._key_col,
                    table_name=self._table_name,
                ))

        return list(
            self._psql_helper.single_execute(cb, yield_result_rows=True))[0][0]

    def keys(self):
        """
        :return: Iterator over keys in this store.
        :rtype: collections.Iterator[collections.Hashable]
        """
        def cb(cur):
            cur.execute(
                self.SqlTemplates.SELECT_TMPL.format(
                    query=self._key_col,
                    table_name=self._table_name,
                ))

        # We can use a named cursor because this is a select statement as well
        # as server table size may be large.
        for r in self._psql_helper.single_execute(cb,
                                                  yield_result_rows=True,
                                                  named=True):
            # Convert from buffer -> string -> python
            yield self._bin_to_py(r[0])

    def values(self):
        """
        :return: Iterator over values in this store. Values are not guaranteed
            to be in any particular order.
        :rtype: collections.Iterator[object]
        """
        def cb(cur):
            cur.execute(
                self.SqlTemplates.SELECT_TMPL.format(
                    query=self._value_col,
                    table_name=self._table_name,
                ))

        for r in self._psql_helper.single_execute(cb,
                                                  yield_result_rows=True,
                                                  named=True):
            # Convert from buffer -> string -> python
            yield self._bin_to_py(r[0])

    def is_read_only(self):
        """
        :return: True if this instance is read-only and False if it is not.
        :rtype: bool
        """
        return self._read_only

    def has(self, key):
        """
        Check if this store has a value for the given key.

        :param key: Key to check for a value for.
        :type key: collections.Hashable

        :return: If this store has a value for the given key.
        :rtype: bool

        """
        super(PostgresKeyValueStore, self).has(key)

        # Select rows whose serialized key is byte-for-byte equal to the given
        # key. Any returned row means the key is present.
        q = self.SqlTemplates.SELECT_KEY_EQ_TMPL.format(
            query='true',
            table_name=self._table_name,
            key_col=self._key_col,
        )
        v = {'key': self._py_to_bin(key)}

        def cb(cur):
            cur.execute(q, v)

        return bool(
            list(self._psql_helper.single_execute(cb, yield_result_rows=True)))

    def add(self, key, value):
        """
        Add a key-value pair to this store.

        :param key: Key for the value. Must be hashable.
        :type key: collections.Hashable

        :param value: Python object to store.
        :type value: object

        :raises ReadOnlyError: If this instance is marked as read-only.

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).add(key, value)

        q = self.SqlTemplates.UPSERT_TMPL.format(
            table_name=self._table_name,
            key_col=self._key_col,
            value_col=self._value_col,
        )
        v = {
            'key': self._py_to_bin(key),
            'val': self._py_to_bin(value),
        }

        def cb(cur):
            cur.execute(q, v)

        list(self._psql_helper.single_execute(cb))
        return self

    def add_many(self, d):
        """
        Add multiple key-value pairs at a time into this store as represented in
        the provided dictionary `d`.

        :param d: Dictionary of key-value pairs to add to this store.
        :type d: dict[collections.Hashable, object]

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).add_many(d)

        q = self.SqlTemplates.UPSERT_TMPL.format(
            table_name=self._table_name,
            key_col=self._key_col,
            value_col=self._value_col,
        )

        # Iterator over transformed inputs into values for statement.
        def val_iter():
            for key, val in six.iteritems(d):
                yield {
                    'key': self._py_to_bin(key),
                    'val': self._py_to_bin(val)
                }

        def cb(cur, v_batch):
            self._execute_batch(cur, q, v_batch)

        list(self._psql_helper.batch_execute(val_iter(), cb, self._batch_size))
        return self

    def remove(self, key):
        """
        Remove a single key-value entry.

        :param key: Key to remove.
        :type key: collections.Hashable

        :raises ReadOnlyError: If this instance is marked as read-only.
        :raises KeyError: The given key is not present in this store and no
            default value given.

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).remove(key)
        if key not in self:
            raise KeyError(key)

        # Exact-match delete so no other rows are removed.
        q = self.SqlTemplates.DELETE_KEY_EQ_TMPL.format(
            table_name=self._table_name,
            key_col=self._key_col,
        )
        v = {'key': self._py_to_bin(key)}

        def cb(cursor):
            cursor.execute(q, v)

        list(self._psql_helper.single_execute(cb))
        return self

    def _check_contained_keys(self, keys):
        """
        Check if the table contains the following keys.

        :param set keys: Keys to check for.

        :return: A set of keys NOT present in the table.
        :rtype: set[collections.Hashable]
        """
        def key_like_iter():
            for k_ in keys:
                yield self._py_to_bin(k_)

        has_many_q = self.SqlTemplates.SELECT_MANY_TMPL.format(
            query=self._key_col,
            table_name=self._table_name,
            key_col=self._key_col,
        )

        # Keys found in table
        matched_keys = set()

        def cb(cursor, batch):
            cursor.execute(has_many_q, {'key_tuple': tuple(batch)})
            matched_keys.update(self._bin_to_py(r[0]) for r in cursor)

        list(
            self._psql_helper.batch_execute(key_like_iter(), cb,
                                            self._batch_size))

        return keys - matched_keys

    def remove_many(self, keys):
        """
        Remove multiple keys and associated values.

        :param keys: Iterable of keys to remove.  If this is empty this method
            does nothing.
        :type keys: collections.Iterable[collections.Hashable]

        :raises ReadOnlyError: If this instance is marked as read-only.
        :raises KeyError: The given key is not present in this store and no
            default value given.  The store is not modified if any key is
            invalid.

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).remove_many(keys)
        keys = set(keys)

        # Check that all keys requested for removal are contained in our table
        # before attempting to remove any of them.
        key_diff = self._check_contained_keys(keys)
        # If we're trying to remove a key not in our table, appropriately raise
        # a KeyError.
        if key_diff:
            if len(key_diff) == 1:
                raise KeyError(list(key_diff)[0])
            else:
                raise KeyError(key_diff)

        # Proceed with removal
        def key_bin_iter():
            """ Iterator over serialized keys to delete. """
            for k_ in keys:
                yield self._py_to_bin(k_)

        del_q = self.SqlTemplates.DELETE_KEY_EQ_TMPL.format(
            table_name=self._table_name,
            key_col=self._key_col,
        )

        def del_cb(cursor, v_batch):
            # Execute the exact-match delete once per serialized key.
            self._execute_batch(cursor, del_q, [{'key': k} for k in v_batch])

        list(
            self._psql_helper.batch_execute(key_bin_iter(), del_cb,
                                            self._batch_size))
        return self

    def get(self, key, default=NO_DEFAULT_VALUE):
        """
        Get the value for the given key.

        *NOTE:* **Implementing sub-classes are responsible for raising a
        ``KeyError`` where appropriate.**

        :param key: Key to get the value of.
        :type key: collections.Hashable

        :param default: Optional default value if the given key is not present
            in this store. This may be any value except for the
            ``NO_DEFAULT_VALUE`` constant (custom anonymous class instance).
        :type default: object

        :raises KeyError: The given key is not present in this store and no
            default value given.

        :return: Deserialized python object stored for the given key.
        :rtype: object

        """
        q = self.SqlTemplates.SELECT_KEY_EQ_TMPL.format(
            query=self._value_col,
            table_name=self._table_name,
            key_col=self._key_col,
        )
        v = {'key': self._py_to_bin(key)}

        def cb(cur):
            cur.execute(q, v)

        rows = list(
            self._psql_helper.single_execute(cb, yield_result_rows=True))
        # If no rows and no default, raise KeyError.
        if len(rows) == 0:
            if default is NO_DEFAULT_VALUE:
                raise KeyError(key)
            else:
                return default
        return self._bin_to_py(rows[0][0])

    def get_many(self, keys, default=NO_DEFAULT_VALUE):
        """
        Get the values for the given keys.

        *NOTE:* **Implementing sub-classes are responsible for raising a
        ``KeyError`` where appropriate.**

        :param keys: The keys for which associated values are requested.
        :type keys: collections.Iterable[collections.Hashable]

        :param default: Optional default value if a given key is not present
            in this store. This may be any value except for the
            ``NO_DEFAULT_VALUE`` constant (custom anonymous class instance).
        :type default: object

        :raises KeyError: A given key is not present in this store and no
            default value given.

        :return: Iterable of deserialized python objects stored for the given
            keys in the order that the corresponding keys were provided.
        :rtype: collections.Iterable

        """
        keys = list(keys)
        if not keys:
            # Nothing requested. Returning early also avoids sending
            # ``... IN ()``, which is invalid SQL.
            return

        sql_command_string = self.SqlTemplates.SELECT_MANY_TMPL.format(
            query=', '.join((self._key_col, self._value_col)),
            table_name=self._table_name,
            key_col=self._key_col)

        sql_keys = tuple(self._py_to_bin(key_) for key_ in keys)
        sql_variables = {'key_tuple': sql_keys}

        def postgres_callback(cursor):
            cursor.execute(sql_command_string, sql_variables)

        retrieved_dict = {
            self._bin_to_py(row_[0]): self._bin_to_py(row_[1])
            for row_ in self._psql_helper.single_execute(
                postgres_callback, yield_result_rows=True)
        }

        if default is NO_DEFAULT_VALUE:
            for key_ in keys:
                yield retrieved_dict[key_]
        else:
            for key_ in keys:
                yield retrieved_dict.get(key_, default)

    def clear(self):
        """
        Clear this key-value store.

        *NOTE:* **Implementing sub-classes should call this super-method. This
        super method should not be considered a critical section for thread
        safety.**

        :raises ReadOnlyError: If this instance is marked as read-only.

        """
        # Call the base-class method first, as its contract requires. That
        # method raises ReadOnlyError for read-only instances. Without this
        # call, a read-only store could still delete every row.
        super(PostgresKeyValueStore, self).clear()

        q = self.SqlTemplates.DELETE_ALL.format(table_name=self._table_name)

        def cb(cur):
            cur.execute(q)

        list(self._psql_helper.single_execute(cb))
예제 #3
0
파일: postgres.py 프로젝트: Kitware/SMQTK
class PostgresKeyValueStore (KeyValueStore):
    """
    PostgreSQL-backed key-value storage.
    """

    class SqlTemplates (object):
        """
        Container for static PostgreSQL queries used by the containing class.
        """

        UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
            CREATE TABLE IF NOT EXISTS {table_name:s} (
              {key_col:s} BYTEA NOT NULL,
              {value_col:s} BYTEA NOT NULL,
              PRIMARY KEY ({key_col:s})
            );
        """)

        SELECT_TMPL = norm_psql_cmd_string("""
            SELECT {query:s} FROM {table_name:s};
        """)

        SELECT_LIKE_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} LIKE %(key_like)s
        """)

        SELECT_MANY_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} IN %(key_tuple)s
        """)

        UPSERT_TMPL = norm_psql_cmd_string("""
            WITH upsert AS (
              UPDATE {table_name:s}
                SET {value_col:s} = %(val)s
                WHERE {key_col:s} = %(key)s
                RETURNING *
              )
            INSERT INTO {table_name:s}
              ({key_col:s}, {value_col:s})
              SELECT %(key)s, %(val)s
                WHERE NOT EXISTS (SELECT * FROM upsert)
        """)

        DELETE_LIKE_TMPL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
            WHERE {key_col:s} LIKE %(key_like)s 
        """)

        DELETE_ALL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
        """)

    @classmethod
    def is_usable(cls):
        return psycopg2 is not None

    def __init__(self, table_name="data_set",
                 key_col='key', value_col='value', db_name='postgres',
                 db_host=None, db_port=None, db_user=None, db_pass=None,
                 batch_size=1000, pickle_protocol=-1,
                 read_only=False, create_table=True):
        """
        Initialize a PostgreSQL-backed data set instance.

        :param table_name: Name of the table to use.
        :type table_name: str

        :param key_col: Name of the column containing the UUID signatures.
        :type key_col: str

        :param value_col: Name of the table column that will contain
            serialized elements.
        :type value_col: str

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5423).
        :type db_port: int | None

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param batch_size: For queries that handle sending or
            receiving many queries at a time, batch queries based on this size.
            If this is None, then no batching occurs.

            The advantage of batching is that it reduces the memory impact for
            queries dealing with a very large number of elements (don't have to
            store the full query for all elements in RAM), but the transaction
            will be some amount slower due to splitting the query into multiple
            transactions.
        :type batch_size: int | None

        :param pickle_protocol: Pickling protocol to use. We will use -1 by
            default (latest version, probably binary).
        :type pickle_protocol: int

        :param read_only: Only allow read actions against this index.
            Modification actions will throw a ReadOnlyError exceptions.
        :type read_only: bool

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it when not set to be
            read-only. If the configured user does not have sufficient
            permissions to create the table and it does not currently exist, an
            exception will be raised.
        :type create_table: bool

        """
        super(PostgresKeyValueStore, self).__init__()

        self._table_name = table_name
        self._key_col = key_col
        self._value_col = value_col

        self._batch_size = batch_size
        self._pickle_protocol = pickle_protocol
        self._read_only = bool(read_only)
        self._create_table = create_table

        # Checking parameters where necessary
        if self._batch_size is not None:
            self._batch_size = int(self._batch_size)
            assert self._batch_size > 0, \
                "A given batch size must be greater than 0 in size " \
                "(given: %d)." % self._batch_size
        assert -1 <= self._pickle_protocol <= 2, \
            ("Given pickle protocol is not in the known valid range [-1, 2]. "
             "Given: %s." % self._pickle_protocol)

        # helper structure for SQL operations.
        self._psql_helper = PsqlConnectionHelper(
            db_name, db_host, db_port, db_user, db_pass,
            itersize=batch_size,
            table_upsert_lock=PSQL_TABLE_CREATE_RLOCK,
        )

        # Only set table upsert if not read-only.
        if not self._read_only and self._create_table:
            # NOT read-only, so allow table upsert.
            self._psql_helper.set_table_upsert_sql(
                self.SqlTemplates.UPSERT_TABLE_TMPL.format(
                    table_name=self._table_name,
                    key_col=self._key_col,
                    value_col=self._value_col
                )
            )

    @staticmethod
    def _py_to_bin(k):
        """
        Convert a python hashable value into psycopg2.Binary via pickle.

        :param k: Python object instance to be converted into a
            ``psycopg2.Binary`` instance via ``pickle`` serialization.
        :type k: object

        :return: ``psycopg2.Binary`` buffer instance to use for insertion into
            or query against a table.
        :rtype: psycopg2.Binary

        """
        return psycopg2.Binary(pickle.dumps(k))

    @staticmethod
    def _bin_to_py(b):
        """
        Un-"translate" psycopg2.Binary value (buffer) to a python type.

        :param b: ``psycopg2.Binary`` buffer instance as retrieved from a
            PostgreSQL query.
        :type b: psycopg2.Binary

        :return: Python object instance as loaded via pickle from the given
            ``psycopg2.Binary`` buffer.
        :rtype: object

        """
        return pickle.loads(bytes(b))

    def get_config(self):
        """
        Return a JSON-compliant dictionary that could be passed to this class's
        ``from_config`` method to produce an instance with identical
        configuration.

        In the common case, this involves naming the keys of the dictionary
        based on the initialization argument names as if it were to be passed
        to the constructor via dictionary expansion.

        :return: JSON type compliant configuration dictionary.
        :rtype: dict

        """
        return {
            "table_name": self._table_name,
            "key_col": self._key_col,
            "value_col": self._value_col,

            "db_name": self._psql_helper.db_name,
            "db_host": self._psql_helper.db_host,
            "db_port": self._psql_helper.db_port,
            "db_user": self._psql_helper.db_user,
            "db_pass": self._psql_helper.db_pass,

            "batch_size": self._batch_size,
            "pickle_protocol": self._pickle_protocol,
            "read_only": self._read_only,
            "create_table": self._create_table,
        }

    def __repr__(self):
        """
        Return representative string for this class.

        :return: Representative string for this class.
        :rtype: str

        """
        return super(PostgresKeyValueStore, self).__repr__() \
            % ("table_name: %s, key_col: %s, value_col: %s, "
               "db_name: %s, db_host: %s, db_port: %s, db_user: %s, "
               "db_pass: %s, batch_size: %d, pickle_protocol: %d, "
               "read_only: %s, create_table: %s"
               % (self._table_name, self._key_col, self._value_col,
                  self._psql_helper.db_name, self._psql_helper.db_host,
                  self._psql_helper.db_port, self._psql_helper.db_user,
                  self._psql_helper.db_pass, self._batch_size,
                  self._pickle_protocol, self._read_only, self._create_table))

    def count(self):
        """
        :return: The number of key-value relationships in this store.
        :rtype: int | long
        """
        def cb(cur):
            cur.execute(self.SqlTemplates.SELECT_TMPL.format(
                query='count(%s)' % self._key_col,
                table_name=self._table_name,
            ))
        return list(self._psql_helper.single_execute(
            cb, yield_result_rows=True
        ))[0][0]

    def keys(self):
        """
        Lazily iterate over every key stored in the table.

        A named cursor is requested because this is a plain select and the
        table may be large.

        :return: Iterator over keys in this store.
        :rtype: collections.Iterator[collections.Hashable]
        """
        select_keys_query = self.SqlTemplates.SELECT_TMPL.format(
            query=self._key_col,
            table_name=self._table_name,
        )

        def exec_hook(cursor):
            cursor.execute(select_keys_query)

        key_rows = self._psql_helper.single_execute(
            exec_hook, yield_result_rows=True, named=True
        )
        for row in key_rows:
            # Stored binary column value -> deserialized Python key.
            yield self._bin_to_py(row[0])

    def values(self):
        """
        Lazily iterate over every stored value, deserializing each one.

        Rows are fetched through a named cursor. The query requests no
        ordering.

        :return: Iterator over values in this store. Values are not guaranteed
            to be in any particular order.
        :rtype: collections.Iterator[object]
        """
        select_values_query = self.SqlTemplates.SELECT_TMPL.format(
            query=self._value_col,
            table_name=self._table_name,
        )

        def exec_hook(cursor):
            cursor.execute(select_values_query)

        value_rows = self._psql_helper.single_execute(
            exec_hook, yield_result_rows=True, named=True
        )
        for row in value_rows:
            # Stored binary column value -> deserialized Python object.
            yield self._bin_to_py(row[0])

    def is_read_only(self):
        """
        Report whether this store refuses modification.

        :return: The instance's read-only flag: True if this instance is
            read-only and False if it is not.
        :rtype: bool
        """
        read_only_flag = self._read_only
        return read_only_flag

    def has(self, key):
        """
        Check if this store has a value for the given key.

        The lookup is an exact comparison against the serialized key
        (``key_col IN (<serialized key>)``). A ``LIKE`` comparison is not
        used because serialized key bytes may contain the ``%``, ``_`` or
        backslash bytes. ``LIKE`` interprets those as wildcards or escapes,
        which produces false positives (another key matches) or false
        negatives (the stored key fails to match itself).

        :param key: Key to check for a value for.
        :type key: collections.Hashable

        :return: If this store has a value for the given key.
        :rtype: bool

        """
        super(PostgresKeyValueStore, self).has(key)

        # Select a constant for any row whose key exactly equals the given
        # serialized key. Any returned row means the key is present.
        q = self.SqlTemplates.SELECT_MANY_TMPL.format(
            query='true',
            table_name=self._table_name,
            key_col=self._key_col,
        )
        v = {'key_tuple': (self._py_to_bin(key),)}

        def cb(cur):
            cur.execute(q, v)

        return bool(list(self._psql_helper.single_execute(
            cb, yield_result_rows=True
        )))

    def add(self, key, value):
        """
        Insert or overwrite the value stored under ``key``.

        Both key and value are serialized, then written with the upsert
        template, so an existing key has its value replaced.

        :param key: Key for the value. Must be hashable.
        :type key: collections.Hashable

        :param value: Python object to store.
        :type value: object

        :raises ReadOnlyError: If this instance is marked as read-only.

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).add(key, value)

        upsert_query = self.SqlTemplates.UPSERT_TMPL.format(
            table_name=self._table_name,
            key_col=self._key_col,
            value_col=self._value_col,
        )
        params = dict(key=self._py_to_bin(key), val=self._py_to_bin(value))

        def exec_hook(cursor):
            cursor.execute(upsert_query, params)

        # Exhaust the helper's returned iterable so the statement is run.
        for _ in self._psql_helper.single_execute(exec_hook):
            pass
        return self

    def add_many(self, d):
        """
        Upsert every key-value pair of ``d`` into this store.

        Pairs are serialized lazily and sent in batches of
        ``self._batch_size`` through ``psycopg2.extras.execute_batch``.

        :param d: Dictionary of key-value pairs to add to this store.
        :type d: dict[collections.Hashable, object]

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).add_many(d)

        upsert_query = self.SqlTemplates.UPSERT_TMPL.format(
            table_name=self._table_name,
            key_col=self._key_col,
            value_col=self._value_col,
        )

        def param_dicts():
            # One statement-parameter dict per input pair.
            for k, v in six.iteritems(d):
                yield {'key': self._py_to_bin(k), 'val': self._py_to_bin(v)}

        def exec_hook(cursor, param_batch):
            psycopg2.extras.execute_batch(cursor, upsert_query, param_batch,
                                          page_size=self._batch_size)

        batches = self._psql_helper.batch_execute(param_dicts(), exec_hook,
                                                  self._batch_size)
        for _ in batches:
            pass
        return self

    def remove(self, key):
        """
        Remove a single key-value entry.

        The delete is an exact comparison against the serialized key. A
        ``LIKE`` comparison is avoided because serialized key bytes may
        contain ``%``, ``_`` or backslash bytes. ``LIKE`` treats those as
        wildcards or escapes, which could delete unrelated rows or miss the
        intended one.

        :param key: Key to remove.
        :type key: collections.Hashable

        :raises ReadOnlyError: If this instance is marked as read-only.
        :raises KeyError: The given key is not present in this store.

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).remove(key)
        if key not in self:
            raise KeyError(key)

        q = ("DELETE FROM {table_name:s} "
             "WHERE {key_col:s} = %(key)s").format(
            table_name=self._table_name,
            key_col=self._key_col,
        )
        v = {'key': self._py_to_bin(key)}

        def cb(cursor):
            cursor.execute(q, v)

        list(self._psql_helper.single_execute(cb))
        return self

    def _check_contained_keys(self, keys):
        """
        Determine which of the given keys are absent from the table.

        Keys are serialized and looked up in batches of ``self._batch_size``
        with an ``IN`` query. Every key returned is deserialized and then
        subtracted from ``keys``.

        :param set keys: Keys to check for.

        :return: An set of keys NOT present in the table.
        :rtype: set[collections.Hashable]
        """
        select_query = self.SqlTemplates.SELECT_MANY_TMPL.format(
            query=self._key_col,
            table_name=self._table_name,
            key_col=self._key_col,
        )

        # Keys that the table reported back as present.
        found = set()

        def exec_hook(cursor, bin_batch):
            cursor.execute(select_query, {'key_tuple': tuple(bin_batch)})
            for row in cursor:
                found.add(self._bin_to_py(row[0]))

        serialized_keys = (self._py_to_bin(k) for k in keys)
        list(self._psql_helper.batch_execute(serialized_keys, exec_hook,
                                             self._batch_size))
        return keys - found

    def remove_many(self, keys):
        """
        Remove multiple keys and associated values.

        All keys are first checked for presence, so nothing is deleted when
        any key is missing. Matching rows are then deleted batch by batch with
        an exact ``IN`` comparison on the serialized keys. ``LIKE`` is not used
        because serialized key bytes may contain ``%``, ``_`` or backslash
        bytes, which ``LIKE`` interprets as wildcards or escapes.

        :param keys: Iterable of keys to remove.  If this is empty this method
            does nothing.
        :type keys: collections.Iterable[collections.Hashable]

        :raises ReadOnlyError: If this instance is marked as read-only.
        :raises KeyError: A given key is not present in this store. The store
            is not modified if any key is invalid.

        :return: Self.
        :rtype: KeyValueStore

        """
        super(PostgresKeyValueStore, self).remove_many(keys)
        keys = set(keys)

        # Check that all keys requested for removal are contained in our table
        # before attempting to remove any of them.
        key_diff = self._check_contained_keys(keys)
        # If we're trying to remove a key not in our table, appropriately raise
        # a KeyError.
        if key_diff:
            if len(key_diff) == 1:
                raise KeyError(list(key_diff)[0])
            else:
                raise KeyError(key_diff)

        # Proceed with removal
        def key_bin_iter():
            """ Iterator over serialized keys to delete. """
            for k_ in keys:
                yield self._py_to_bin(k_)

        del_q = ("DELETE FROM {table_name:s} "
                 "WHERE {key_col:s} IN %(key_tuple)s").format(
            table_name=self._table_name,
            key_col=self._key_col,
        )

        def del_cb(cursor, v_batch):
            # One exact-match delete statement per batch of serialized keys.
            cursor.execute(del_q, {'key_tuple': tuple(v_batch)})

        list(self._psql_helper.batch_execute(key_bin_iter(), del_cb,
                                             self._batch_size))
        return self

    def get(self, key, default=NO_DEFAULT_VALUE):
        """
        Get the value for the given key.

        The lookup is an exact comparison against the serialized key
        (``key_col IN (<serialized key>)``). ``LIKE`` is not used because
        serialized key bytes may contain ``%``, ``_`` or backslash bytes.
        ``LIKE`` interprets those as wildcards or escapes, which could return
        another key's value or fail to find the requested key.

        :param key: Key to get the value of.
        :type key: collections.Hashable

        :param default: Optional default value if the given key is not present
            in this store. This may be any value except for the
            ``NO_DEFAULT_VALUE`` constant (custom anonymous class instance).
        :type default: object

        :raises KeyError: The given key is not present in this store and no
            default value given.

        :return: Deserialized python object stored for the given key.
        :rtype: object

        """
        q = self.SqlTemplates.SELECT_MANY_TMPL.format(
            query=self._value_col,
            table_name=self._table_name,
            key_col=self._key_col,
        )
        v = {'key_tuple': (self._py_to_bin(key),)}

        def cb(cur):
            cur.execute(q, v)

        rows = list(self._psql_helper.single_execute(
            cb, yield_result_rows=True
        ))
        # If no rows and no default, raise KeyError.
        if not rows:
            if default is NO_DEFAULT_VALUE:
                raise KeyError(key)
            return default
        return self._bin_to_py(rows[0][0])

    def clear(self):
        """
        Clear this key-value store by deleting every row of the table.

        The parent ``clear`` is invoked first, as the ``KeyValueStore``
        contract requires. The parent is responsible for rejecting the call
        on a read-only instance. Previously this step was skipped, so a
        read-only store could be wiped.

        :raises ReadOnlyError: If this instance is marked as read-only.

        """
        super(PostgresKeyValueStore, self).clear()

        q = self.SqlTemplates.DELETE_ALL.format(table_name=self._table_name)

        def cb(cur):
            cur.execute(q)

        list(self._psql_helper.single_execute(cb))
예제 #4
0
파일: postgres.py 프로젝트: Kitware/SMQTK
class PostgresDescriptorIndex (DescriptorIndex):
    """
    DescriptorIndex implementation that stores DescriptorElement references in
    a PostgreSQL database.

    A ``PostgresDescriptorIndex`` effectively controls the entire table. Thus
    a ``clear()`` call will remove everything from the table.

    PostgreSQL version support:
        - 9.4

    Table format:
        <uuid col>      TEXT NOT NULL
        <element col>   BYTEA NOT NULL

        <uuid_col> should be the primary key (we assume unique).

    No column label may be 'true'. ``has_descriptor`` substitutes the literal
    ``true`` for the element column as a cheap return-value shortcut.

    Elements are stored as pickles and unpickled on read. The table contents
    must therefore be trusted: ``pickle.loads`` on attacker-controlled bytes
    can execute arbitrary code.

    """

    #
    # The following are SQL query templates. The string formatting using {}'s
    # is used to fill in the query before using it in an execute with instance
    # specific values. The ``%()s`` formatting is special for the execute where-
    # by psycopg2 will fill in the values appropriately as specified in a second
    # dictionary argument to ``cursor.execute(query, value_dict)``.
    #
    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {uuid_col:s} TEXT NOT NULL,
          {element_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({uuid_col:s})
        );
    """)

    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {col:s}
          FROM {table_name:s}
    """)

    # ``uuid_like`` values must be escaped with ``_escape_like`` when an exact
    # match is intended.
    SELECT_LIKE_TMPL = norm_psql_cmd_string("""
        SELECT {element_col:s}
          FROM {table_name:s}
         WHERE {uuid_col:s} like %(uuid_like)s
    """)

    # So we can ensure we get back elements in specified order
    #   - reference [1]
    SELECT_MANY_ORDERED_TMPL = norm_psql_cmd_string("""
        SELECT {table_name:s}.{element_col:s}
          FROM {table_name:s}
          JOIN (
            SELECT *
            FROM unnest(%(uuid_list)s) with ordinality
          ) AS __ordering__ ({uuid_col:s}, {uuid_col:s}_order)
            ON {table_name:s}.{uuid_col:s} = __ordering__.{uuid_col:s}
          ORDER BY __ordering__.{uuid_col:s}_order
    """)

    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {element_col:s} = %(element_val)s
            WHERE {uuid_col:s} = %(uuid_val)s
            RETURNING *
          )
        INSERT INTO {table_name:s}
          ({uuid_col:s}, {element_col:s})
          SELECT %(uuid_val)s, %(element_val)s
            WHERE NOT EXISTS (SELECT * FROM upsert)
    """)

    # ``uuid_like`` values must be escaped with ``_escape_like`` when an exact
    # match is intended; ``clear()`` deliberately passes the bare ``%``.
    DELETE_LIKE_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
              WHERE {uuid_col:s} like %(uuid_like)s
    """)

    # RETURNING uses the configured UUID column rather than a hard-coded name,
    # so a non-default ``uuid_col`` works.
    DELETE_MANY_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
              WHERE {uuid_col:s} in %(uuid_tuple)s
          RETURNING {uuid_col:s}
    """)

    @classmethod
    def is_usable(cls):
        return psycopg2 is not None

    @staticmethod
    def _escape_like(value):
        """
        Escape LIKE pattern metacharacters so ``value`` matches only itself.

        PostgreSQL's ``LIKE`` treats ``%`` and ``_`` as wildcards and uses a
        backslash as its default escape character. Each of these is prefixed
        with a backslash. The backslash is handled first so the escapes added
        afterwards are not doubled.

        :param value: Literal string to match.
        :type value: str

        :return: LIKE pattern matching exactly ``value``.
        :rtype: str
        """
        return (value.replace('\\', '\\\\')
                .replace('%', '\\%')
                .replace('_', '\\_'))

    @staticmethod
    def _load_element(raw):
        """
        Unpickle a DescriptorElement from a BYTEA column value.

        ``bytes()`` is used instead of ``str()``. ``str()`` of the
        ``memoryview`` psycopg2 returns on Python 3 is its repr, not its
        contents. On Python 2, ``bytes`` is ``str``, so behavior there is
        unchanged.

        :param raw: BYTEA value from a result row.

        :return: Deserialized descriptor element.
        :rtype: smqtk.representation.DescriptorElement
        """
        return pickle.loads(bytes(raw))

    def __init__(self, table_name='descriptor_index', uuid_col='uid',
                 element_col='element',
                 db_name='postgres', db_host=None, db_port=None, db_user=None,
                 db_pass=None, multiquery_batch_size=1000, pickle_protocol=-1,
                 read_only=False, create_table=True):
        """
        Initialize index instance.

        :param table_name: Name of the table to use.
        :type table_name: str

        :param uuid_col: Name of the column containing the UUID signatures.
        :type uuid_col: str

        :param element_col: Name of the table column that will contain
            serialized elements.
        :type element_col: str

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param multiquery_batch_size: For queries that handle sending or
            receiving many queries at a time, batch queries based on this size.
            If this is None, then no batching occurs.

            The advantage of batching is that it reduces the memory impact for
            queries dealing with a very large number of elements (don't have to
            store the full query for all elements in RAM), but the transaction
            will be some amount slower due to splitting the query into multiple
            transactions.
        :type multiquery_batch_size: int | None

        :param pickle_protocol: Pickling protocol to use. We will use -1 by
            default (latest version, probably binary). Any value from -1 up to
            ``pickle.HIGHEST_PROTOCOL`` of the running interpreter is accepted.
        :type pickle_protocol: int

        :param read_only: Only allow read actions against this index.
            Modification actions will throw a ReadOnlyError exceptions.
        :type read_only: bool

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it when not set to be
            read-only. If the configured user does not have sufficient
            permissions to create the table and it does not currently exist, an
            exception will be raised.
        :type create_table: bool

        """
        super(PostgresDescriptorIndex, self).__init__()

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.element_col = element_col

        self.multiquery_batch_size = multiquery_batch_size
        self.pickle_protocol = pickle_protocol
        self.read_only = bool(read_only)
        self.create_table = create_table

        # Checking parameters where necessary
        if self.multiquery_batch_size is not None:
            self.multiquery_batch_size = int(self.multiquery_batch_size)
            assert self.multiquery_batch_size > 0, \
                "A given batch size must be greater than 0 in size " \
                "(given: %d)." % self.multiquery_batch_size
        # Upper bound follows the running interpreter: this is 2 on Python 2,
        # which matches the previous hard-coded bound.
        assert -1 <= self.pickle_protocol <= pickle.HIGHEST_PROTOCOL, \
            ("Given pickle protocol is not in the known valid range. Given: %s"
             % self.pickle_protocol)

        self.psql_helper = PsqlConnectionHelper(db_name, db_host, db_port,
                                                db_user, db_pass,
                                                self.multiquery_batch_size,
                                                PSQL_TABLE_CREATE_RLOCK)
        if not self.read_only and self.create_table:
            self.psql_helper.set_table_upsert_sql(
                self.UPSERT_TABLE_TMPL.format(
                    table_name=self.table_name,
                    uuid_col=self.uuid_col,
                    element_col=self.element_col,
                )
            )

    def get_config(self):
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "element_col": self.element_col,

            "db_name": self.psql_helper.db_name,
            "db_host": self.psql_helper.db_host,
            "db_port": self.psql_helper.db_port,
            "db_user": self.psql_helper.db_user,
            "db_pass": self.psql_helper.db_pass,

            "multiquery_batch_size": self.multiquery_batch_size,
            "pickle_protocol": self.pickle_protocol,
            "read_only": self.read_only,
            "create_table": self.create_table,
        }

    def count(self):
        """
        :return: Number of descriptor elements stored in this index.
        :rtype: int | long
        """
        # Just count UUID column to limit data read.
        q = self.SELECT_TMPL.format(
            col='count(%s)' % self.uuid_col,
            table_name=self.table_name,
        )

        def exec_hook(cur):
            cur.execute(q)

        # There's only going to be one row returned with one element in it.
        return list(self.psql_helper.single_execute(
            exec_hook, yield_result_rows=True
        ))[0][0]

    def clear(self):
        """
        Clear this descriptor index's entries.

        :raises ReadOnlyError: If this index is read-only.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot clear a read-only index.")

        q = self.DELETE_LIKE_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )

        def exec_hook(cur):
            # An unescaped '%' is intentional here: match every row.
            cur.execute(q, {'uuid_like': '%'})

        list(self.psql_helper.single_execute(exec_hook))

    def has_descriptor(self, uuid):
        """
        Check if a DescriptorElement with the given UUID exists in this index.

        :param uuid: UUID to query for
        :type uuid: collections.Hashable

        :return: True if a DescriptorElement with the given UUID exists in this
            index, or False if not.
        :rtype: bool

        """
        q = self.SELECT_LIKE_TMPL.format(
            # hacking return value to something simple
            element_col='true',
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        # Escape so '_' / '%' inside the UUID are matched literally.
        v = {'uuid_like': self._escape_like(str(uuid))}

        def exec_hook(cur):
            cur.execute(q, v)

        # Should either yield one or zero rows
        return bool(list(self.psql_helper.single_execute(
            exec_hook, yield_result_rows=True
        )))

    def add_descriptor(self, descriptor):
        """
        Add a descriptor to this index.

        Adding the same descriptor multiple times should not add multiple copies
        of the descriptor in the index (based on UUID). Added descriptors
        overwrite indexed descriptors based on UUID.

        :param descriptor: Descriptor to index.
        :type descriptor: smqtk.representation.DescriptorElement

        :raises ReadOnlyError: If this index is read-only.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot add to a read-only index.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )
        v = {
            'uuid_val': str(descriptor.uuid()),
            'element_val': psycopg2.Binary(
                pickle.dumps(descriptor, self.pickle_protocol)
            )
        }

        def exec_hook(cur):
            cur.execute(q, v)

        list(self.psql_helper.single_execute(exec_hook))

    def add_many_descriptors(self, descriptors):
        """
        Add multiple descriptors at one time.

        Adding the same descriptor multiple times should not add multiple copies
        of the descriptor in the index (based on UUID). Added descriptors
        overwrite indexed descriptors based on UUID.

        :param descriptors: Iterable of descriptor instances to add to this
            index.
        :type descriptors:
            collections.Iterable[smqtk.representation.DescriptorElement]

        :raises ReadOnlyError: If this index is read-only.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot add to a read-only index.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )

        # Transform input descriptors into upsert parameter dictionaries.
        def iter_elements():
            for d in descriptors:
                yield {
                    'uuid_val': str(d.uuid()),
                    'element_val': psycopg2.Binary(
                        pickle.dumps(d, self.pickle_protocol)
                    )
                }

        def exec_hook(cur, batch):
            cur.executemany(q, batch)

        self._log.debug("Adding many descriptors")
        list(self.psql_helper.batch_execute(iter_elements(), exec_hook,
                                            self.multiquery_batch_size))

    def get_descriptor(self, uuid):
        """
        Get the descriptor in this index that is associated with the given UUID.

        :param uuid: UUID of the DescriptorElement to get.
        :type uuid: collections.Hashable

        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this index.
        :raises RuntimeError: More than one row matched the given UUID.

        :return: DescriptorElement associated with the queried UUID.
        :rtype: smqtk.representation.DescriptorElement

        """
        q = self.SELECT_LIKE_TMPL.format(
            element_col=self.element_col,
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        # Escape so '_' / '%' inside the UUID are matched literally.
        v = {'uuid_like': self._escape_like(str(uuid))}

        def eh(c):
            c.execute(q, v)
            if c.rowcount == 0:
                raise KeyError(uuid)
            elif c.rowcount != 1:
                raise RuntimeError("Found more than one entry for the given "
                                   "uuid '%s' (got: %d)"
                                   % (uuid, c.rowcount))

        r = list(self.psql_helper.single_execute(eh, yield_result_rows=True))
        return self._load_element(r[0][0])

    def get_many_descriptors(self, uuids):
        """
        Get an iterator over descriptors associated to given descriptor UUIDs.

        :param uuids: Iterable of descriptor UUIDs to query for.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this index.

        :return: Iterator of descriptors associated to given uuid values.
        :rtype: __generator[smqtk.representation.DescriptorElement]

        """
        q = self.SELECT_MANY_ORDERED_TMPL.format(
            table_name=self.table_name,
            element_col=self.element_col,
            uuid_col=self.uuid_col,
        )

        # Cache UUIDs received in order so we can check when we miss one in
        # order to raise a KeyError.
        uuid_order = []

        def iterelems():
            for uid in uuids:
                uuid_order.append(uid)
                yield str(uid)

        def exec_hook(cur, batch):
            v = {'uuid_list': batch}
            cur.execute(q, v)

        self._log.debug("Getting many descriptors")
        # The SELECT_MANY_ORDERED_TMPL query ensures that elements returned are
        #   in the UUID order given to this method. Thus, if the iterated UUIDs
        #   and iterated return rows do not exactly line up, the query join
        #   failed to match a query UUID to something in the database.
        #   - We also check that the number of rows we got back is the same
        #     as elements yielded, else there were trailing UUIDs that did not
        #     match anything in the database.
        g = self.psql_helper.batch_execute(iterelems(), exec_hook,
                                           self.multiquery_batch_size,
                                           yield_result_rows=True)
        i = 0
        # ``g`` is advanced before ``uuid_order`` on each step, so the batch
        # feeding a row has already been recorded in ``uuid_order``.
        for r, expected_uuid in zip(g, uuid_order):
            d = self._load_element(r[0])
            if d.uuid() != expected_uuid:
                raise KeyError(expected_uuid)
            yield d
            i += 1

        if len(uuid_order) != i:
            # just report the first one that's bad
            raise KeyError(uuid_order[i])

    def remove_descriptor(self, uuid):
        """
        Remove a descriptor from this index by the given UUID.

        :param uuid: UUID of the DescriptorElement to remove.
        :type uuid: collections.Hashable

        :raises ReadOnlyError: If this index is read-only.
        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this index.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only index.")

        q = self.DELETE_LIKE_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        # Escape so '_' / '%' inside the UUID cannot delete other rows.
        v = {'uuid_like': self._escape_like(str(uuid))}

        def execute(c):
            c.execute(q, v)
            # Nothing deleted if rowcount == 0
            # (otherwise 1 when deleted a thing)
            if c.rowcount == 0:
                raise KeyError(uuid)

        list(self.psql_helper.single_execute(execute))

    def remove_many_descriptors(self, uuids):
        """
        Remove descriptors associated to given descriptor UUIDs from this index.

        An empty ``uuids`` iterable is a no-op.

        :param uuids: Iterable of descriptor UUIDs to remove.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises ReadOnlyError: If this index is read-only.
        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this index.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only index.")

        str_uuid_set = set(str(uid) for uid in uuids)
        if not str_uuid_set:
            # Nothing to remove; also avoids emitting ``IN ()``, which is a
            # SQL syntax error.
            return

        q = self.DELETE_MANY_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        v = {'uuid_tuple': tuple(str_uuid_set)}

        def execute(c):
            c.execute(q, v)

            # Check query UUIDs against the rows the DELETE reported. Raising
            # here presumably rolls back the transaction in the helper.
            # TODO confirm with PsqlConnectionHelper.
            deleted_uuid_set = set(r[0] for r in c.fetchall())
            for uid in str_uuid_set:
                if uid not in deleted_uuid_set:
                    raise KeyError(uid)

        list(self.psql_helper.single_execute(execute))

    def iterkeys(self):
        """
        Return an iterator over indexed descriptor keys, which are their UUIDs.
        :rtype: collections.Iterator[collections.Hashable]
        """
        # Getting UUID through the element because the UUID might not be a
        # string type, and the true type is encoded with the DescriptorElement
        # instance.
        for d in self.iterdescriptors():
            yield d.uuid()

    def iterdescriptors(self):
        """
        Return an iterator over indexed descriptor element instances.
        :rtype: collections.Iterator[smqtk.representation.DescriptorElement]
        """
        def execute(c):
            c.execute(self.SELECT_TMPL.format(
                col=self.element_col,
                table_name=self.table_name
            ))

        #: :type: __generator
        execution_results = self.psql_helper.single_execute(
            execute, yield_result_rows=True, named=True
        )
        for r in execution_results:
            yield self._load_element(r[0])

    def iteritems(self):
        """
        Return an iterator over indexed descriptor key and instance pairs.
        :rtype: collections.Iterator[(collections.Hashable,
                                      smqtk.representation.DescriptorElement)]
        """
        for d in self.iterdescriptors():
            yield d.uuid(), d
예제 #5
0
class PostgresDescriptorSet(DescriptorSet):
    """
    DescriptorSet implementation that stores DescriptorElement references in
    a PostgreSQL database.

    A ``PostgresDescriptorSet`` effectively controls the entire table. Thus
    a ``clear()`` call will remove everything from the table.

    PostgreSQL version support:
        - 9.4

    Table format:
        <uuid col>      TEXT NOT NULL
        <element col>   BYTEA NOT NULL

        <uuid_col> should be the primary key (we assume unique).

    We require that no column label be 'true'. ``has_descriptor`` substitutes
    the literal ``true`` for the element column as a cheap value-return
    shortcut.

    Elements are stored pickled and are unpickled on read, so the backing
    table must only ever contain trusted data.

    """

    #
    # The following are SQL query templates. The string formatting using {}'s
    # is used to fill in the query before using it in an execute with instance
    # specific values. The ``%()s`` formatting is special for the execute
    # where-by psycopg2 will fill in the values appropriately as specified in a
    # second dictionary argument to ``cursor.execute(query, value_dict)``.
    #
    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {uuid_col:s} TEXT NOT NULL,
          {element_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({uuid_col:s})
        );
    """)

    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {col:s}
          FROM {table_name:s}
    """)

    SELECT_LIKE_TMPL = norm_psql_cmd_string("""
        SELECT {element_col:s}
          FROM {table_name:s}
         WHERE {uuid_col:s} like %(uuid_like)s
    """)

    # So we can ensure we get back elements in specified order
    #   - reference [1]
    # (``unnest(...) WITH ORDINALITY`` requires PostgreSQL 9.4+.)
    SELECT_MANY_ORDERED_TMPL = norm_psql_cmd_string("""
        SELECT {table_name:s}.{element_col:s}
          FROM {table_name:s}
          JOIN (
            SELECT *
            FROM unnest(%(uuid_list)s) with ordinality
          ) AS __ordering__ ({uuid_col:s}, {uuid_col:s}_order)
            ON {table_name:s}.{uuid_col:s} = __ordering__.{uuid_col:s}
          ORDER BY __ordering__.{uuid_col:s}_order
    """)

    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {element_col:s} = %(element_val)s
            WHERE {uuid_col:s} = %(uuid_val)s
            RETURNING *
          )
        INSERT INTO {table_name:s}
          ({uuid_col:s}, {element_col:s})
          SELECT %(uuid_val)s, %(element_val)s
            WHERE NOT EXISTS (SELECT * FROM upsert)
    """)

    DELETE_LIKE_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
              WHERE {uuid_col:s} like %(uuid_like)s
    """)

    # RETURNING the configured UUID column (not a hard-coded name) so the
    # query works for any ``uuid_col`` setting.
    DELETE_MANY_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
              WHERE {uuid_col:s} in %(uuid_tuple)s
          RETURNING {uuid_col:s}
    """)

    @classmethod
    def is_usable(cls):
        return psycopg2 is not None

    @staticmethod
    def _escape_like(value):
        """
        Escape LIKE pattern metacharacters so ``value`` only matches itself.

        PostgreSQL's ``LIKE`` treats ``%`` and ``_`` as wildcards and uses
        backslash as its default escape character. Backslash is escaped first
        so the escapes added for the wildcards are not doubled.

        :param value: Literal string to be matched exactly.
        :type value: str

        :return: Pattern string safe to use as a ``LIKE`` operand.
        :rtype: str
        """
        return (value.replace('\\', '\\\\')
                     .replace('%', '\\%')
                     .replace('_', '\\_'))

    def __init__(self,
                 table_name='descriptor_set',
                 uuid_col='uid',
                 element_col='element',
                 db_name='postgres',
                 db_host=None,
                 db_port=None,
                 db_user=None,
                 db_pass=None,
                 multiquery_batch_size=1000,
                 pickle_protocol=-1,
                 read_only=False,
                 create_table=True):
        """
        Initialize set instance.

        :param table_name: Name of the table to use.
        :type table_name: str

        :param uuid_col: Name of the column containing the UUID signatures.
        :type uuid_col: str

        :param element_col: Name of the table column that will contain
            serialized elements.
        :type element_col: str

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param multiquery_batch_size: For queries that handle sending or
            receiving many queries at a time, batch queries based on this size.
            If this is None, then no batching occurs.

            The advantage of batching is that it reduces the memory impact for
            queries dealing with a very large number of elements (don't have to
            store the full query for all elements in RAM), but the transaction
            will be some amount slower due to splitting the query into multiple
            transactions.
        :type multiquery_batch_size: int | None

        :param pickle_protocol: Pickling protocol to use. We will use -1 by
            default (latest version, probably binary).
        :type pickle_protocol: int

        :param read_only: Only allow read actions against this set.
            Modification actions will raise a ReadOnlyError exception.
        :type read_only: bool

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it when not set to be
            read-only. If the configured user does not have sufficient
            permissions to create the table and it does not currently exist, an
            exception will be raised.
        :type create_table: bool

        :raises AssertionError: ``multiquery_batch_size`` is not positive, or
            ``pickle_protocol`` is outside of ``[-1, 2]``.

        """
        super(PostgresDescriptorSet, self).__init__()

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.element_col = element_col

        self.multiquery_batch_size = multiquery_batch_size
        self.pickle_protocol = pickle_protocol
        self.read_only = bool(read_only)
        self.create_table = create_table

        # Checking parameters where necessary
        if self.multiquery_batch_size is not None:
            self.multiquery_batch_size = int(self.multiquery_batch_size)
            assert self.multiquery_batch_size > 0, \
                "A given batch size must be greater than 0 in size " \
                "(given: %d)." % self.multiquery_batch_size
        assert -1 <= self.pickle_protocol <= 2, \
            ("Given pickle protocol is not in the known valid range. Given: %s"
             % self.pickle_protocol)

        self.psql_helper = PsqlConnectionHelper(db_name, db_host, db_port,
                                                db_user, db_pass,
                                                self.multiquery_batch_size,
                                                PSQL_TABLE_CREATE_RLOCK)
        # Only writable instances register the table-creation statement.
        if not self.read_only and self.create_table:
            self.psql_helper.set_table_upsert_sql(
                self.UPSERT_TABLE_TMPL.format(
                    table_name=self.table_name,
                    uuid_col=self.uuid_col,
                    element_col=self.element_col,
                ))

    def get_config(self):
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "element_col": self.element_col,
            "db_name": self.psql_helper.db_name,
            "db_host": self.psql_helper.db_host,
            "db_port": self.psql_helper.db_port,
            "db_user": self.psql_helper.db_user,
            "db_pass": self.psql_helper.db_pass,
            "multiquery_batch_size": self.multiquery_batch_size,
            "pickle_protocol": self.pickle_protocol,
            "read_only": self.read_only,
            "create_table": self.create_table,
        }

    def count(self):
        """
        :return: Number of descriptor elements stored in this set.
        :rtype: int | long
        """
        # Just count UUID column to limit data read.
        q = self.SELECT_TMPL.format(
            col='count(%s)' % self.uuid_col,
            table_name=self.table_name,
        )

        def exec_hook(cur):
            cur.execute(q)

        # There's only going to be one row returned with one element in it.
        return list(
            self.psql_helper.single_execute(exec_hook,
                                            yield_result_rows=True))[0][0]

    def clear(self):
        """
        Clear this descriptor set's entries.

        :raises ReadOnlyError: This set is read-only.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot clear a read-only set.")

        q = self.DELETE_LIKE_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )

        def exec_hook(cur):
            # '%' intentionally matches every row.
            cur.execute(q, {'uuid_like': '%'})

        list(self.psql_helper.single_execute(exec_hook))

    def has_descriptor(self, uuid):
        """
        Check if a DescriptorElement with the given UUID exists in this set.

        :param uuid: UUID to query for
        :type uuid: collections.Hashable

        :return: True if a DescriptorElement with the given UUID exists in this
            set, or False if not.
        :rtype: bool

        """
        q = self.SELECT_LIKE_TMPL.format(
            # hacking return value to something simple
            element_col='true',
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )

        def exec_hook(cur):
            # Escape wildcards so only the exact UUID can match.
            cur.execute(q, {'uuid_like': self._escape_like(str(uuid))})

        # Should either yield one or zero rows
        return bool(
            list(
                self.psql_helper.single_execute(exec_hook,
                                                yield_result_rows=True)))

    def add_descriptor(self, descriptor):
        """
        Add a descriptor to this set.

        Adding the same descriptor multiple times should not add multiple copies
        of the descriptor in the set (based on UUID). Added descriptors
        overwrite set descriptors based on UUID.

        :param descriptor: Descriptor to set.
        :type descriptor: smqtk.representation.DescriptorElement

        :raises ReadOnlyError: This set is read-only.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot add to a read-only set.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )
        v = {
            'uuid_val':
            str(descriptor.uuid()),
            'element_val':
            psycopg2.Binary(pickle.dumps(descriptor, self.pickle_protocol))
        }

        def exec_hook(cur):
            cur.execute(q, v)

        list(self.psql_helper.single_execute(exec_hook))

    def add_many_descriptors(self, descriptors):
        """
        Add multiple descriptors at one time.

        Adding the same descriptor multiple times should not add multiple copies
        of the descriptor in the set (based on UUID). Added descriptors
        overwrite set descriptors based on UUID.

        :param descriptors: Iterable of descriptor instances to add to this
            set.
        :type descriptors:
            collections.Iterable[smqtk.representation.DescriptorElement]

        :raises ReadOnlyError: This set is read-only.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot add to a read-only set.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )

        # Lazily transform input descriptors into the parameter dictionaries
        # expected by the upsert query, so the input is consumed batch by
        # batch.
        def iter_elements():
            for d in descriptors:
                yield {
                    'uuid_val':
                    str(d.uuid()),
                    'element_val':
                    psycopg2.Binary(pickle.dumps(d, self.pickle_protocol))
                }

        def exec_hook(cur, batch):
            cur.executemany(q, batch)

        self._log.debug("Adding many descriptors")
        list(
            self.psql_helper.batch_execute(iter_elements(), exec_hook,
                                           self.multiquery_batch_size))

    def get_descriptor(self, uuid):
        """
        Get the descriptor in this set that is associated with the given UUID.

        :param uuid: UUID of the DescriptorElement to get.
        :type uuid: collections.Hashable

        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this set.

        :return: DescriptorElement associated with the queried UUID.
        :rtype: smqtk.representation.DescriptorElement

        """
        q = self.SELECT_LIKE_TMPL.format(
            element_col=self.element_col,
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        # Escape wildcards so only the exact UUID can match.
        v = {'uuid_like': self._escape_like(str(uuid))}

        def eh(c):
            c.execute(q, v)
            if c.rowcount == 0:
                raise KeyError(uuid)
            elif c.rowcount != 1:
                raise RuntimeError("Found more than one entry for the given "
                                   "uuid '%s' (got: %d)" % (uuid, c.rowcount))

        r = list(self.psql_helper.single_execute(eh, yield_result_rows=True))
        return pickle.loads(bytes(r[0][0]))

    def get_many_descriptors(self, uuids):
        """
        Get an iterator over descriptors associated to given descriptor UUIDs.

        :param uuids: Iterable of descriptor UUIDs to query for.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this set.

        :return: Iterator of descriptors associated to given uuid values.
        :rtype: __generator[smqtk.representation.DescriptorElement]

        """
        q = self.SELECT_MANY_ORDERED_TMPL.format(
            table_name=self.table_name,
            element_col=self.element_col,
            uuid_col=self.uuid_col,
        )

        # Cache UUIDs received in order so we can check when we miss one in
        # order to raise a KeyError.
        uuid_order = []

        def iterelems():
            for uid in uuids:
                uuid_order.append(uid)
                yield str(uid)

        def exec_hook(cur, batch):
            # The batch list is adapted by psycopg2 into an SQL array for
            # ``unnest``.
            v = {'uuid_list': batch}
            cur.execute(q, v)

        self._log.debug("Getting many descriptors")
        # The SELECT_MANY_ORDERED_TMPL query ensures that elements returned are
        #   in the UUID order given to this method. Thus, if the iterated UUIDs
        #   and iterated return rows do not exactly line up, the query join
        #   failed to match a query UUID to something in the database.
        #   - We also check that the number of rows we got back is the same
        #     as elements yielded, else there were trailing UUIDs that did not
        #     match anything in the database.
        #   - ``uuid_order`` is filled lazily while ``g`` consumes
        #     ``iterelems()``; ``zip`` advances ``g`` first, so the list
        #     always holds the UUIDs of rows pulled so far.
        g = self.psql_helper.batch_execute(iterelems(),
                                           exec_hook,
                                           self.multiquery_batch_size,
                                           yield_result_rows=True)
        i = 0
        for r, expected_uuid in zip(g, uuid_order):
            d = pickle.loads(bytes(r[0]))
            if d.uuid() != expected_uuid:
                raise KeyError(expected_uuid)
            yield d
            i += 1

        if len(uuid_order) != i:
            # just report the first one that's bad
            raise KeyError(uuid_order[i])

    def remove_descriptor(self, uuid):
        """
        Remove a descriptor from this set by the given UUID.

        :param uuid: UUID of the DescriptorElement to remove.
        :type uuid: collections.Hashable

        :raises ReadOnlyError: This set is read-only.
        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this set.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only set.")

        q = self.DELETE_LIKE_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        # Escape wildcards: an unescaped '%' or '_' in the UUID could
        # otherwise delete other, non-matching rows.
        v = {'uuid_like': self._escape_like(str(uuid))}

        def execute(c):
            c.execute(q, v)
            # Nothing deleted if rowcount == 0
            # (otherwise 1 when deleted a thing)
            if c.rowcount == 0:
                raise KeyError(uuid)

        list(self.psql_helper.single_execute(execute))

    def remove_many_descriptors(self, uuids):
        """
        Remove descriptors associated to given descriptor UUIDs from this set.

        An empty ``uuids`` iterable is a no-op.

        :param uuids: Iterable of descriptor UUIDs to remove.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises ReadOnlyError: This set is read-only.
        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this set.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only set.")

        str_uuid_set = set(str(uid) for uid in uuids)
        if not str_uuid_set:
            # Nothing to remove. This also avoids sending ``IN ()``, which is
            # not valid SQL.
            return

        q = self.DELETE_MANY_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        v = {'uuid_tuple': tuple(str_uuid_set)}

        def execute(c):
            c.execute(q, v)

            # Check query UUIDs against rows that would actually be deleted.
            # The RETURNING clause yields the UUIDs of the deleted rows.
            # NOTE(review): raising here presumably aborts the transaction in
            # PsqlConnectionHelper so nothing is deleted. Confirm this.
            deleted_uuid_set = set(r[0] for r in c.fetchall())
            for uid in str_uuid_set:
                if uid not in deleted_uuid_set:
                    raise KeyError(uid)

        list(self.psql_helper.single_execute(execute))

    def iterkeys(self):
        """
        Return an iterator over set descriptor keys, which are their UUIDs.
        :rtype: collections.Iterator[collections.Hashable]
        """
        # Getting UUID through the element because the UUID might not be a
        # string type, and the true type is encoded with the DescriptorElement
        # instance.
        for d in self.iterdescriptors():
            yield d.uuid()

    def iterdescriptors(self):
        """
        Return an iterator over set descriptor element instances.
        :rtype: collections.Iterator[smqtk.representation.DescriptorElement]
        """
        def execute(c):
            c.execute(
                self.SELECT_TMPL.format(col=self.element_col,
                                        table_name=self.table_name))

        #: :type: __generator
        execution_results = self.psql_helper.single_execute(
            execute, yield_result_rows=True, named=True)
        for r in execution_results:
            # BYTEA values arrive as buffer/memoryview; bytes() gets the pickle.
            d = pickle.loads(bytes(r[0]))
            yield d

    def iteritems(self):
        """
        Return an iterator over set descriptor key and instance pairs.
        :rtype: collections.Iterator[(collections.Hashable,
                                      smqtk.representation.DescriptorElement)]
        """
        for d in self.iterdescriptors():
            yield d.uuid(), d