Code example #1
File: postgres.py Project: fabgithub/SMQTK
    class SqlTemplates(object):
        """
        Container for static PostgreSQL queries used by the containing class.
        """

        UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
            CREATE TABLE IF NOT EXISTS {table_name:s} (
              {key_col:s} BYTEA NOT NULL,
              {value_col:s} BYTEA NOT NULL,
              PRIMARY KEY ({key_col:s})
            );
        """)

        SELECT_TMPL = norm_psql_cmd_string("""
            SELECT {query:s} FROM {table_name:s};
        """)

        SELECT_LIKE_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} LIKE %(key_like)s
        """)

        UPSERT_TMPL = norm_psql_cmd_string("""
            WITH upsert AS (
              UPDATE {table_name:s}
                SET {value_col:s} = %(val)s
                WHERE {key_col:s} = %(key)s
                RETURNING *
              )
            INSERT INTO {table_name:s}
              ({key_col:s}, {value_col:s})
              SELECT %(key)s, %(val)s
                WHERE NOT EXISTS (SELECT * FROM upsert)
        """)

        DELETE_ALL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
        """)
Code example #2
    class SqlTemplates(object):
        """
        Container for static PostgreSQL queries used by the containing class.
        """

        UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
            CREATE TABLE IF NOT EXISTS {table_name:s} (
              {key_col:s} BYTEA NOT NULL,
              {value_col:s} BYTEA NOT NULL,
              PRIMARY KEY ({key_col:s})
            );
        """)

        SELECT_TMPL = norm_psql_cmd_string("""
            SELECT {query:s} FROM {table_name:s};
        """)

        SELECT_LIKE_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} LIKE %(key_like)s
        """)

        SELECT_MANY_TMPL = norm_psql_cmd_string("""
            SELECT {query:s}
              FROM {table_name:s}
             WHERE {key_col:s} IN %(key_tuple)s
        """)

        UPSERT_TMPL = norm_psql_cmd_string("""
            INSERT INTO {table_name:s} ({key_col:s}, {value_col:s})
                VALUES (%(key)s, %(val)s)
                ON CONFLICT ({key_col:s})
                    DO UPDATE
                        SET {value_col:s} = EXCLUDED.{value_col:s}
        """)

        DELETE_LIKE_TMPL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
            WHERE {key_col:s} LIKE %(key_like)s
        """)

        DELETE_ALL = norm_psql_cmd_string("""
            DELETE FROM {table_name:s}
        """)
Code example #3
    class CommandTemplates(object):
        """ Encapsulation of command templates. """

        # Upsert table for storage if desired
        #
        # Format params:
        # - table_name
        # - id_col
        # - sha1_col
        # - mime_col
        # - byte_col
        UPSERT_TABLE = norm_psql_cmd_string("""
            CREATE TABLE IF NOT EXISTS {table_name:s} (
              {id_col:s}   TEXT NOT NULL,
              {sha1_col:s} TEXT NOT NULL,
              {mime_col:s} TEXT NOT NULL,
              {byte_col:s} BYTEA NOT NULL,
              PRIMARY KEY ({id_col:s})
            );
        """)

        # Select ``col`` for a given entry ID.
        #
        # Query Format params:
        # - col
        # - table_name
        # - id_col
        #
        # Value params:
        # - id_val
        SELECT = norm_psql_cmd_string("""
            SELECT {col:s}
              FROM {table_name:s}
              WHERE {id_col:s} = %(id_val)s
            ;
        """)

        # Upsert content-type/data for a uid
        #
        # Query Format params:
        # - table_name
        # - id_col
        # - sha1_col
        # - mime_col
        # - byte_col
        #
        # Value params:
        # - id_val
        # - sha1_val
        # - mime_val
        # - byte_val
        #
        # SQL format from:
        #   https://hashrocket.com/blog/posts/upsert-records-with-postgresql-9-5
        #
        UPSERT_DATA = norm_psql_cmd_string("""
            INSERT INTO {table_name:s} ({id_col:s}, {sha1_col:s}, {mime_col:s}, {byte_col:s})
                VALUES ( %(id_val)s, %(sha1_val)s, %(mime_val)s, %(byte_val)s )
                ON CONFLICT ({id_col:s})
                    DO UPDATE
                        SET ({sha1_col:s}, {mime_col:s}, {byte_col:s})
                          = (EXCLUDED.{sha1_col:s}, EXCLUDED.{mime_col:s}, EXCLUDED.{byte_col:s})
            ;
        """)

        # Same as ``UPSERT_DATA`` but does not set the mimetype on an update.
        # This is meant to atomically update the byte data without changing the
        # existing mimetype.
        UPSERT_DATA_NO_MIME = norm_psql_cmd_string("""
            INSERT INTO {table_name:s} ({id_col:s}, {sha1_col:s}, {mime_col:s}, {byte_col:s})
                VALUES ( %(id_val)s, %(sha1_val)s, %(mime_val)s, %(byte_val)s )
                ON CONFLICT ({id_col:s})
                    DO UPDATE
                        SET ({sha1_col:s}, {byte_col:s})
                          = (EXCLUDED.{sha1_col:s}, EXCLUDED.{byte_col:s})
            ;
        """)
Code example #4
class PostgresDescriptorSet(DescriptorSet):
    """
    DescriptorSet implementation that stores DescriptorElement references in
    a PostgreSQL database.

    A ``PostgresDescriptorSet`` effectively controls the entire table. Thus
    a ``clear()`` call will remove everything from the table.

    PostgreSQL version support:
        - 9.4

    Table format:
        <uuid col>      TEXT NOT NULL
        <element col>   BYTEA NOT NULL

        <uuid_col> should be the primary key (we assume unique).

    We require that no column be labeled 'true', since that literal is
    substituted for a column name as a value-return shortcut (see
    ``has_descriptor``).

    """

    #
    # The following are SQL query templates. The ``{}`` string formatting is
    # used to fill in instance-specific identifiers (table and column names)
    # before the query is executed. The ``%()s`` placeholders are left for
    # psycopg2, which fills in values from the second dictionary argument to
    # ``cursor.execute(query, value_dict)``.
    #
    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {uuid_col:s} TEXT NOT NULL,
          {element_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({uuid_col:s})
        );
    """)

    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {col:s}
          FROM {table_name:s}
    """)

    SELECT_LIKE_TMPL = norm_psql_cmd_string("""
        SELECT {element_col:s}
          FROM {table_name:s}
         WHERE {uuid_col:s} LIKE %(uuid_like)s
    """)

    # So we can ensure we get back elements in the specified order
    #   - reference [1]
    SELECT_MANY_ORDERED_TMPL = norm_psql_cmd_string("""
        SELECT {table_name:s}.{element_col:s}
          FROM {table_name:s}
          JOIN (
            SELECT *
            FROM unnest(%(uuid_list)s) WITH ORDINALITY
          ) AS __ordering__ ({uuid_col:s}, {uuid_col:s}_order)
            ON {table_name:s}.{uuid_col:s} = __ordering__.{uuid_col:s}
          ORDER BY __ordering__.{uuid_col:s}_order
    """)

    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {element_col:s} = %(element_val)s
            WHERE {uuid_col:s} = %(uuid_val)s
            RETURNING *
          )
        INSERT INTO {table_name:s}
          ({uuid_col:s}, {element_col:s})
          SELECT %(uuid_val)s, %(element_val)s
            WHERE NOT EXISTS (SELECT * FROM upsert)
    """)

    DELETE_LIKE_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
              WHERE {uuid_col:s} LIKE %(uuid_like)s
    """)

    DELETE_MANY_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
              WHERE {uuid_col:s} IN %(uuid_tuple)s
          RETURNING {uuid_col:s}
    """)

    @classmethod
    def is_usable(cls):
        return psycopg2 is not None

    def __init__(self,
                 table_name='descriptor_set',
                 uuid_col='uid',
                 element_col='element',
                 db_name='postgres',
                 db_host=None,
                 db_port=None,
                 db_user=None,
                 db_pass=None,
                 multiquery_batch_size=1000,
                 pickle_protocol=-1,
                 read_only=False,
                 create_table=True):
        """
        Initialize set instance.

        :param table_name: Name of the table to use.
        :type table_name: str

        :param uuid_col: Name of the column containing the UUID signatures.
        :type uuid_col: str

        :param element_col: Name of the table column that will contain
            serialized elements.
        :type element_col: str

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param multiquery_batch_size: For operations that send or receive
            many elements at a time, batch the queries using this size. If
            this is None, then no batching occurs.

            The advantage of batching is that it reduces the memory impact of
            queries dealing with a very large number of elements (the full
            query for all elements need not be held in RAM), but the overall
            transaction will be somewhat slower due to being split across
            multiple transactions.
        :type multiquery_batch_size: int | None

        :param pickle_protocol: Pickling protocol to use. We will use -1 by
            default (latest version, probably binary).
        :type pickle_protocol: int

        :param read_only: Only allow read actions against this set.
            Modification actions will throw a ReadOnlyError exception.
        :type read_only: bool

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it when not set to be
            read-only. If the configured user does not have sufficient
            permissions to create the table and it does not currently exist, an
            exception will be raised.
        :type create_table: bool

        """
        super(PostgresDescriptorSet, self).__init__()

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.element_col = element_col

        self.multiquery_batch_size = multiquery_batch_size
        self.pickle_protocol = pickle_protocol
        self.read_only = bool(read_only)
        self.create_table = create_table

        # Checking parameters where necessary
        if self.multiquery_batch_size is not None:
            self.multiquery_batch_size = int(self.multiquery_batch_size)
            assert self.multiquery_batch_size > 0, \
                "A given batch size must be greater than 0 in size " \
                "(given: %d)." % self.multiquery_batch_size
        assert -1 <= self.pickle_protocol <= 2, \
            ("Given pickle protocol is not in the known valid range. Given: %s"
             % self.pickle_protocol)

        self.psql_helper = PsqlConnectionHelper(db_name, db_host, db_port,
                                                db_user, db_pass,
                                                self.multiquery_batch_size,
                                                PSQL_TABLE_CREATE_RLOCK)
        if not self.read_only and self.create_table:
            self.psql_helper.set_table_upsert_sql(
                self.UPSERT_TABLE_TMPL.format(
                    table_name=self.table_name,
                    uuid_col=self.uuid_col,
                    element_col=self.element_col,
                ))

    def get_config(self):
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "element_col": self.element_col,
            "db_name": self.psql_helper.db_name,
            "db_host": self.psql_helper.db_host,
            "db_port": self.psql_helper.db_port,
            "db_user": self.psql_helper.db_user,
            "db_pass": self.psql_helper.db_pass,
            "multiquery_batch_size": self.multiquery_batch_size,
            "pickle_protocol": self.pickle_protocol,
            "read_only": self.read_only,
            "create_table": self.create_table,
        }

    def count(self):
        """
        :return: Number of descriptor elements stored in this set.
        :rtype: int | long
        """
        # Just count UUID column to limit data read.
        q = self.SELECT_TMPL.format(
            col='count(%s)' % self.uuid_col,
            table_name=self.table_name,
        )

        def exec_hook(cur):
            cur.execute(q)

        # There's only going to be one row returned with one element in it.
        return list(
            self.psql_helper.single_execute(exec_hook,
                                            yield_result_rows=True))[0][0]

    def clear(self):
        """
        Clear this descriptor set's entries.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot clear a read-only set.")

        q = self.DELETE_LIKE_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )

        def exec_hook(cur):
            cur.execute(q, {'uuid_like': '%'})

        list(self.psql_helper.single_execute(exec_hook))

    def has_descriptor(self, uuid):
        """
        Check if a DescriptorElement with the given UUID exists in this set.

        :param uuid: UUID to query for
        :type uuid: collections.Hashable

        :return: True if a DescriptorElement with the given UUID exists in this
            set, or False if not.
        :rtype: bool

        """
        q = self.SELECT_LIKE_TMPL.format(
            # hacking return value to something simple
            element_col='true',
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )

        def exec_hook(cur):
            cur.execute(q, {'uuid_like': str(uuid)})

        # Should either yield one or zero rows
        return bool(
            list(
                self.psql_helper.single_execute(exec_hook,
                                                yield_result_rows=True)))

    def add_descriptor(self, descriptor):
        """
        Add a descriptor to this set.

        Adding the same descriptor multiple times should not add multiple copies
        of the descriptor in the set (based on UUID). Added descriptors
        overwrite set descriptors based on UUID.

        :param descriptor: Descriptor to set.
        :type descriptor: smqtk.representation.DescriptorElement

        """
        if self.read_only:
            raise ReadOnlyError("Cannot clear a read-only set.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )
        v = {
            'uuid_val':
            str(descriptor.uuid()),
            'element_val':
            psycopg2.Binary(pickle.dumps(descriptor, self.pickle_protocol))
        }

        def exec_hook(cur):
            cur.execute(q, v)

        list(self.psql_helper.single_execute(exec_hook))

    def add_many_descriptors(self, descriptors):
        """
        Add multiple descriptors at one time.

        Adding the same descriptor multiple times should not add multiple copies
        of the descriptor in the set (based on UUID). Added descriptors
        overwrite set descriptors based on UUID.

        :param descriptors: Iterable of descriptor instances to add to this
            set.
        :type descriptors:
            collections.Iterable[smqtk.representation.DescriptorElement]

        """
        if self.read_only:
            raise ReadOnlyError("Cannot clear a read-only set.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )

        # Transform input descriptor elements into query value dictionaries.
        def iter_elements():
            for d in descriptors:
                yield {
                    'uuid_val':
                    str(d.uuid()),
                    'element_val':
                    psycopg2.Binary(pickle.dumps(d, self.pickle_protocol))
                }

        def exec_hook(cur, batch):
            cur.executemany(q, batch)

        self._log.debug("Adding many descriptors")
        list(
            self.psql_helper.batch_execute(iter_elements(), exec_hook,
                                           self.multiquery_batch_size))

    def get_descriptor(self, uuid):
        """
        Get the descriptor in this set that is associated with the given UUID.

        :param uuid: UUID of the DescriptorElement to get.
        :type uuid: collections.Hashable

        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this set.

        :return: DescriptorElement associated with the queried UUID.
        :rtype: smqtk.representation.DescriptorElement

        """
        q = self.SELECT_LIKE_TMPL.format(
            element_col=self.element_col,
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        v = {'uuid_like': str(uuid)}

        def eh(c):
            c.execute(q, v)
            if c.rowcount == 0:
                raise KeyError(uuid)
            elif c.rowcount != 1:
                raise RuntimeError("Found more than one entry for the given "
                                   "uuid '%s' (got: %d)" % (uuid, c.rowcount))

        r = list(self.psql_helper.single_execute(eh, yield_result_rows=True))
        return pickle.loads(bytes(r[0][0]))

    def get_many_descriptors(self, uuids):
        """
        Get an iterator over descriptors associated to given descriptor UUIDs.

        :param uuids: Iterable of descriptor UUIDs to query for.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this set.

        :return: Iterator of descriptors associated to given uuid values.
        :rtype: __generator[smqtk.representation.DescriptorElement]

        """
        q = self.SELECT_MANY_ORDERED_TMPL.format(
            table_name=self.table_name,
            element_col=self.element_col,
            uuid_col=self.uuid_col,
        )

        # Cache UUIDs received in order so we can check when we miss one in
        # order to raise a KeyError.
        uuid_order = []

        def iterelems():
            for uid in uuids:
                uuid_order.append(uid)
                yield str(uid)

        def exec_hook(cur, batch):
            v = {'uuid_list': batch}
            # self._log.debug('query: %s', cur.mogrify(q, v))
            cur.execute(q, v)

        self._log.debug("Getting many descriptors")
        # The SELECT_MANY_ORDERED_TMPL query ensures that elements returned are
        #   in the UUID order given to this method. Thus, if the iterated UUIDs
        #   and iterated return rows do not exactly line up, the query join
        #   failed to match a query UUID to something in the database.
        #   - We also check that the number of rows we got back is the same
        #     as elements yielded, else there were trailing UUIDs that did not
        #     match anything in the database.
        g = self.psql_helper.batch_execute(iterelems(),
                                           exec_hook,
                                           self.multiquery_batch_size,
                                           yield_result_rows=True)
        i = 0
        for r, expected_uuid in zip(g, uuid_order):
            d = pickle.loads(bytes(r[0]))
            if d.uuid() != expected_uuid:
                raise KeyError(expected_uuid)
            yield d
            i += 1

        if len(uuid_order) != i:
            # just report the first one that's bad
            raise KeyError(uuid_order[i])

    def remove_descriptor(self, uuid):
        """
        Remove a descriptor from this set by the given UUID.

        :param uuid: UUID of the DescriptorElement to remove.
        :type uuid: collections.Hashable

        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this set.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only set.")

        q = self.DELETE_LIKE_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        v = {'uuid_like': str(uuid)}

        def execute(c):
            c.execute(q, v)
            # Nothing was deleted if rowcount == 0
            # (otherwise rowcount is 1 when a row was deleted).
            if c.rowcount == 0:
                raise KeyError(uuid)

        list(self.psql_helper.single_execute(execute))

    def remove_many_descriptors(self, uuids):
        """
        Remove descriptors associated to given descriptor UUIDs from this set.

        :param uuids: Iterable of descriptor UUIDs to remove.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this set.

        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only set.")

        q = self.DELETE_MANY_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        str_uuid_set = set(str(uid) for uid in uuids)
        v = {'uuid_tuple': tuple(str_uuid_set)}

        def execute(c):
            c.execute(q, v)

            # Check query UUIDs against rows that would actually be deleted.
            deleted_uuid_set = set(r[0] for r in c.fetchall())
            for uid in str_uuid_set:
                if uid not in deleted_uuid_set:
                    raise KeyError(uid)

        list(self.psql_helper.single_execute(execute))

    def iterkeys(self):
        """
        Return an iterator over set descriptor keys, which are their UUIDs.
        :rtype: collections.Iterator[collections.Hashable]
        """
        # Getting UUID through the element because the UUID might not be a
        # string type, and the true type is encoded with the DescriptorElement
        # instance.
        for d in self.iterdescriptors():
            yield d.uuid()

    def iterdescriptors(self):
        """
        Return an iterator over set descriptor element instances.
        :rtype: collections.Iterator[smqtk.representation.DescriptorElement]
        """
        def execute(c):
            c.execute(
                self.SELECT_TMPL.format(col=self.element_col,
                                        table_name=self.table_name))

        #: :type: __generator
        execution_results = self.psql_helper.single_execute(
            execute, yield_result_rows=True, named=True)
        for r in execution_results:
            d = pickle.loads(bytes(r[0]))
            yield d

    def iteritems(self):
        """
        Return an iterator over set descriptor key and instance pairs.
        :rtype: collections.Iterator[(collections.Hashable,
                                      smqtk.representation.DescriptorElement)]
        """
        for d in self.iterdescriptors():
            yield d.uuid(), d
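
The ordering trick in ``SELECT_MANY_ORDERED_TMPL`` is what lets
``get_many_descriptors`` detect misses by zipping results against the input
UUIDs: ``unnest(...) WITH ORDINALITY`` turns the parameter array into rows
tagged with their 1-based input position, and the join plus ``ORDER BY``
returns matches in that input order. A standalone sketch of the mechanism,
assuming a psycopg2 connection ``conn`` (the table and values are
illustrative; psycopg2 adapts the Python list to a PostgreSQL array):

    with conn, conn.cursor() as cur:
        cur.execute("""
            SELECT t.v, o.k_order
              FROM (VALUES ('a', 10), ('b', 20)) AS t (k, v)
              JOIN unnest(%(k_list)s) WITH ORDINALITY AS o (k, k_order)
                ON t.k = o.k
             ORDER BY o.k_order
        """, {'k_list': ['b', 'a']})
        # [(20, 1), (10, 2)] -- input order, not table order
        print(cur.fetchall())
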
Code example #5
File: postgres.py Project: fabgithub/SMQTK
class PostgresDescriptorElement(DescriptorElement):
    """
    Descriptor element whose vector is stored in a Postgres database.

    We assume a PostgreSQL version of at least 9.4 (the versions tested
    against).

    Efficient connection pooling may be achieved via external utilities like
    PgBouncer.

    """

    ARRAY_DTYPE = numpy.float64

    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {type_col:s} TEXT NOT NULL,
          {uuid_col:s} TEXT NOT NULL,
          {binary_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({type_col:s}, {uuid_col:s})
        );
    """)

    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {binary_col:s}
          FROM {table_name:s}
          WHERE {type_col:s} = %(type_val)s
            AND {uuid_col:s} = %(uuid_val)s
        ;
    """)

    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {binary_col:s} = %(binary_val)s
            WHERE {type_col:s} = %(type_val)s
              AND {uuid_col:s} = %(uuid_val)s
            RETURNING *
          )
        INSERT INTO {table_name:s} ({type_col:s}, {uuid_col:s}, {binary_col:s})
          SELECT %(type_val)s, %(uuid_val)s, %(binary_val)s
            WHERE NOT EXISTS (SELECT * FROM upsert);
    """)

    @classmethod
    def is_usable(cls):
        if psycopg2 is None:
            cls.get_logger().warning("Not usable. Requires psycopg2 module")
            return False
        return True

    def __init__(self,
                 type_str,
                 uuid,
                 table_name='descriptors',
                 uuid_col='uid',
                 type_col='type_str',
                 binary_col='vector',
                 db_name='postgres',
                 db_host=None,
                 db_port=None,
                 db_user=None,
                 db_pass=None,
                 create_table=True):
        """
        Initialize new PostgresDescriptorElement attached to some database
        credentials.

        We require that storage tables treat uuid AND type string columns as
        primary keys. The type and uuid columns should be of the 'text' type.
        The binary column should be of the 'bytea' type.

        Default argument values assume a local PostgreSQL database with a table
        created via the
        ``etc/smqtk/postgres/descriptor_element/example_table_init.sql``
        file (relative to the SMQTK source tree or install root).

        NOTES:
            - Not all uuid types used here are necessarily of the ``uuid.UUID``
              type, thus the recommendation to use a ``text`` type for the
              column. For certain specific use cases they may be proper
              ``uuid.UUID`` instances or strings, but this cannot be generally
              assumed.

        :param type_str: Type of descriptor. This is usually the name of the
            content descriptor that generated this vector.
        :type type_str: str

        :param uuid: Unique ID reference of the descriptor.
        :type uuid: collections.Hashable

        :param table_name: String label of the database table to use.
        :type table_name: str

        :param uuid_col: The column label for descriptor UUID storage.
        :type uuid_col: str

        :param type_col: The column label for descriptor type string storage.
        :type type_col: str

        :param binary_col: The column label for descriptor vector binary
            storage.
        :type binary_col: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it. If the configured
            user does not have sufficient permissions to create the table and it
            does not currently exist, an exception will be raised.
        :type create_table: bool

        """
        super(PostgresDescriptorElement, self).__init__(type_str, uuid)

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.type_col = type_col
        self.binary_col = binary_col
        self.create_table = create_table

        self.db_name = db_name
        self.db_host = db_host
        self.db_port = db_port
        self.db_user = db_user
        self.db_pass = db_pass

    def get_config(self):
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "type_col": self.type_col,
            "binary_col": self.binary_col,
            "create_table": self.create_table,
            "db_name": self.db_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "db_user": self.db_user,
            "db_pass": self.db_pass,
        }

    def _get_psql_connection(self):
        """
        :return: A new connection to the configured database
        :rtype: psycopg2._psycopg.connection
        """
        return psycopg2.connect(
            database=self.db_name,
            user=self.db_user,
            password=self.db_pass,
            host=self.db_host,
            port=self.db_port,
        )

    def _ensure_table(self, cursor):
        """
        Execute the create-table-if-not-exists query on the given psql
        connection cursor.

        :param cursor: Connection active cursor.

        """
        if self.create_table:
            q_table_upsert = self.UPSERT_TABLE_TMPL.format(**dict(
                table_name=self.table_name,
                type_col=self.type_col,
                uuid_col=self.uuid_col,
                binary_col=self.binary_col,
            ))
            with PSQL_TABLE_CREATE_RLOCK:
                cursor.execute(q_table_upsert)
                cursor.connection.commit()

    def has_vector(self):
        """
        Check if the target database has a vector for our keys.

        This also returns True if we have a cached vector since there must have
        been a source vector to draw from if there is a cache of it.

        If a vector is cached, this resets the cache expiry timeout.

        :return: Whether or not this container currently has a descriptor vector
            stored.
        :rtype: bool

        """
        # Very similar to the vector query, but replacing the vector binary
        # return with a true/null return. Saves a little time compared to
        # fetching and testing the vector itself.
        # OLD: return self.vector() is not None

        # Using static value 'true' for binary "column" to reduce data return
        # volume.
        q_select = self.SELECT_TMPL.format(
            **{
                'binary_col': 'true',
                'table_name': self.table_name,
                'type_col': self.type_col,
                'uuid_col': self.uuid_col,
            })
        q_select_values = {
            "type_val": self.type(),
            "uuid_val": str(self.uuid())
        }

        conn = self._get_psql_connection()
        cur = conn.cursor()

        try:
            self._ensure_table(cur)
            cur.execute(q_select, q_select_values)
            r = cur.fetchone()
            # For server cleaning (e.g. pgbouncer)
            conn.commit()
            return bool(r)
        except:
            conn.rollback()
            raise
        finally:
            cur.close()
            conn.close()

    def vector(self):
        """
        Return this element's vector, or None if we don't have one.

        :return: Get the stored descriptor vector as a numpy array. This returns
            None if there is no vector stored in this container.
        :rtype: numpy.core.multiarray.ndarray or None

        """
        q_select = self.SELECT_TMPL.format(
            **{
                "binary_col": self.binary_col,
                "table_name": self.table_name,
                "type_col": self.type_col,
                "uuid_col": self.uuid_col,
            })
        q_select_values = {
            "type_val": self.type(),
            "uuid_val": str(self.uuid())
        }

        conn = self._get_psql_connection()
        cur = conn.cursor()
        try:
            self._ensure_table(cur)
            cur.execute(q_select, q_select_values)

            r = cur.fetchone()
            conn.commit()

            if not r:
                return None
            else:
                b = r[0]
                v = numpy.frombuffer(b, self.ARRAY_DTYPE)
                return v
        except:
            conn.rollback()
            raise
        finally:
            cur.close()
            conn.close()

    def set_vector(self, new_vec):
        """
        Set the contained vector.

        If this container already stores a descriptor vector, this will
        overwrite it.

        If we are configured to use caching, and one has not been cached yet,
        then we cache the vector and start a thread to monitor access times and
        to remove the cache if the access timeout has expired.

        If a vector was already cached, this new vector replaces the old one,
        the vector database-side is replaced, and the cache expiry timeout is
        reset.

        :raises ValueError: ``new_vec`` was not a numpy ndarray.

        :param new_vec: New vector to contain. This must be a numpy array.
        :type new_vec: numpy.core.multiarray.ndarray

        """
        if not isinstance(new_vec, numpy.core.multiarray.ndarray):
            raise ValueError(
                "Input array for setting was not a numpy.ndarray! "
                "(given: %s)" % type(new_vec))

        if new_vec.dtype != self.ARRAY_DTYPE:
            new_vec = new_vec.astype(self.ARRAY_DTYPE)

        q_upsert = self.UPSERT_TMPL.strip().format(
            **{
                "table_name": self.table_name,
                "binary_col": self.binary_col,
                "type_col": self.type_col,
                "uuid_col": self.uuid_col,
            })
        q_upsert_values = {
            "binary_val": psycopg2.Binary(new_vec),
            "type_val": self.type(),
            "uuid_val": str(self.uuid()),
        }

        conn = self._get_psql_connection()
        cur = conn.cursor()
        try:
            self._ensure_table(cur)
            cur.execute(q_upsert, q_upsert_values)
            conn.commit()
        except:
            conn.rollback()
            raise
        finally:
            cur.close()
            conn.close()
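
Each accessor above opens a fresh connection and hand-rolls the
commit/rollback/close sequence. For comparison, a sketch of the same
transaction discipline using psycopg2's context managers (``with conn``
commits on success or rolls back on error, but does not close the
connection):

    conn = psycopg2.connect(database='postgres')
    try:
        with conn, conn.cursor() as cur:
            cur.execute("SELECT 1")
            row = cur.fetchone()
    finally:
        conn.close()
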
Code example #6
File: postgres.py Project: spongezhang/SMQTK
class PostgresDescriptorElement(DescriptorElement):
    """
    Descriptor element whose vector is stored in a Postgres database.

    We assume a PostgreSQL version of at least 9.4 (the versions tested
    against).

    Efficient connection pooling may be achieved via external utilities like
    PgBouncer.

    """

    ARRAY_DTYPE = numpy.float64

    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {type_col:s} TEXT NOT NULL,
          {uuid_col:s} TEXT NOT NULL,
          {binary_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({type_col:s}, {uuid_col:s})
        );
    """)

    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {binary_col:s}
          FROM {table_name:s}
          WHERE {type_col:s} = %(type_val)s
            AND {uuid_col:s} = %(uuid_val)s
        ;
    """)

    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {binary_col:s} = %(binary_val)s
            WHERE {type_col:s} = %(type_val)s
              AND {uuid_col:s} = %(uuid_val)s
            RETURNING *
          )
        INSERT INTO {table_name:s} ({type_col:s}, {uuid_col:s}, {binary_col:s})
          SELECT %(type_val)s, %(uuid_val)s, %(binary_val)s
            WHERE NOT EXISTS (SELECT * FROM upsert);
    """)

    @classmethod
    def is_usable(cls):
        if psycopg2 is None:
            cls.get_logger().warning("Not usable. Requires psycopg2 module")
            return False
        return True

    def __init__(self,
                 type_str,
                 uuid,
                 table_name='descriptors',
                 uuid_col='uid',
                 type_col='type_str',
                 binary_col='vector',
                 db_name='postgres',
                 db_host=None,
                 db_port=None,
                 db_user=None,
                 db_pass=None,
                 create_table=True):
        """
        Initialize new PostgresDescriptorElement attached to some database
        credentials.

        We require that storage tables treat uuid AND type string columns as
        primary keys. The type and uuid columns should be of the 'text' type.
        The binary column should be of the 'bytea' type.

        Default argument values assume a local PostgreSQL database with a table
        created via the
        ``etc/smqtk/postgres/descriptor_element/example_table_init.sql``
        file (relative to the SMQTK source tree or install root).

        NOTES:
            - Not all uuid types used here are necessarily of the ``uuid.UUID``
              type, thus the recommendation to use a ``text`` type for the
              column. For certain specific use cases they may be proper
              ``uuid.UUID`` instances or strings, but this cannot be generally
              assumed.

        :param type_str: Type of descriptor. This is usually the name of the
            content descriptor that generated this vector.
        :type type_str: str

        :param uuid: Unique ID reference of the descriptor.
        :type uuid: collections.Hashable

        :param table_name: String label of the database table to use.
        :type table_name: str

        :param uuid_col: The column label for descriptor UUID storage.
        :type uuid_col: str

        :param type_col: The column label for descriptor type string storage.
        :type type_col: str

        :param binary_col: The column label for descriptor vector binary
            storage.
        :type binary_col: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it. If the configured
            user does not have sufficient permissions to create the table and it
            does not currently exist, an exception will be raised.
        :type create_table: bool

        """
        super(PostgresDescriptorElement, self).__init__(type_str, uuid)

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.type_col = type_col
        self.binary_col = binary_col
        self.create_table = create_table

        self.db_name = db_name
        self.db_host = db_host
        self.db_port = db_port
        self.db_user = db_user
        self.db_pass = db_pass

        self._psql_helper = None

    def __getstate__(self):
        """
        Construct serialization state.

        Because the psql_helper contains a lock, it cannot be serialized.
        This is OK because we create the helper on demand.  The cost of
        discarding the helper upon serialization is that it must be recreated
        once after deserialization, which is acceptable.

        """
        state = super(PostgresDescriptorElement, self).__getstate__()
        state.update({
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "type_col": self.type_col,
            "binary_col": self.binary_col,
            "create_table": self.create_table,
            "db_name": self.db_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "db_user": self.db_user,
            "db_pass": self.db_pass,
        })
        return state

    def __setstate__(self, state):
        # Base DescriptorElement parts
        super(PostgresDescriptorElement, self).__setstate__(state)
        # Our parts
        self.table_name = state['table_name']
        self.uuid_col = state['uuid_col']
        self.type_col = state['type_col']
        self.binary_col = state['binary_col']
        self.create_table = state['create_table']
        self.db_name = state['db_name']
        self.db_host = state['db_host']
        self.db_port = state['db_port']
        self.db_user = state['db_user']
        self.db_pass = state['db_pass']
        self._psql_helper = None

    def _get_psql_helper(self):
        """
        Internal method to create on demand the PSQL connection helper class.
        :return: PsqlConnectionHelper utility.
        :rtype: PsqlConnectionHelper
        """
        # Create the helper on demand.  This keeps instances picklable and
        # also covers elements deserialized from states saved before this
        # attribute existed (``__setstate__`` resets it to None).
        if self._psql_helper is None:
            # Only using a transport iteration size of 1 since this element is
            # only meant to refer to a single entry in the associated table.
            self._psql_helper = PsqlConnectionHelper(
                self.db_name,
                self.db_host,
                self.db_port,
                self.db_user,
                self.db_pass,
                itersize=1,
                table_upsert_lock=PSQL_TABLE_CREATE_RLOCK)
            # Register table upsert command
            if self.create_table:
                self._psql_helper.set_table_upsert_sql(
                    self.UPSERT_TABLE_TMPL.format(
                        table_name=self.table_name,
                        type_col=self.type_col,
                        uuid_col=self.uuid_col,
                        binary_col=self.binary_col,
                    ))
        return self._psql_helper

    def get_config(self):
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "type_col": self.type_col,
            "binary_col": self.binary_col,
            "create_table": self.create_table,
            "db_name": self.db_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "db_user": self.db_user,
            "db_pass": self.db_pass,
        }

    def has_vector(self):
        """
        Check if the target database has a vector for our keys.

        This also returns True if we have a cached vector since there must have
        been a source vector to draw from if there is a cache of it.

        If a vector is cached, this resets the cache expiry timeout.

        :return: Whether or not this container currently has a descriptor vector
            stored.
        :rtype: bool

        """
        # Very similar to the vector query, but replacing the vector binary
        # return with a true/null return. Saves a little time compared to
        # fetching and testing the vector itself.
        # OLD: return self.vector() is not None

        # Using static value 'true' for binary "column" to reduce data return
        # volume.
        q_select = self.SELECT_TMPL.format(
            **{
                'binary_col': 'true',
                'table_name': self.table_name,
                'type_col': self.type_col,
                'uuid_col': self.uuid_col,
            })
        q_select_values = {
            "type_val": self.type(),
            "uuid_val": str(self.uuid())
        }

        def cb(cursor):
            cursor.execute(q_select, q_select_values)

        # Should either yield one or zero rows.
        psql_helper = self._get_psql_helper()
        return bool(
            list(psql_helper.single_execute(cb, yield_result_rows=True)))

    def vector(self):
        """
        Return this element's vector, or None if we don't have one.

        :return: Get the stored descriptor vector as a numpy array. This returns
            None if there is no vector stored in this container.
        :rtype: numpy.ndarray or None

        """
        q_select = self.SELECT_TMPL.format(
            **{
                "binary_col": self.binary_col,
                "table_name": self.table_name,
                "type_col": self.type_col,
                "uuid_col": self.uuid_col,
            })
        q_select_values = {
            "type_val": self.type(),
            "uuid_val": str(self.uuid())
        }

        # query execution callback
        # noinspection PyProtectedMember
        def cb(cursor):
            # type: (psycopg2._psycopg.cursor) -> None
            cursor.execute(q_select, q_select_values)

        # This should only fetch a single row.  It cannot yield more than one
        # due to the use of primary keys.
        psql_helper = self._get_psql_helper()
        r = list(psql_helper.single_execute(cb, yield_result_rows=True))
        if not r:
            return None
        else:
            b = r[0][0]
            v = numpy.frombuffer(b, self.ARRAY_DTYPE)
            return v

    def set_vector(self, new_vec):
        """
        Set the contained vector.

        If this container already stores a descriptor vector, this will
        overwrite it.

        If we are configured to use caching, and one has not been cached yet,
        then we cache the vector and start a thread to monitor access times and
        to remove the cache if the access timeout has expired.

        If a vector was already cached, this new vector replaces the old one,
        the vector database-side is replaced, and the cache expiry timeout is
        reset.

        :raises ValueError: ``new_vec`` could not be converted into a vector
            of the expected dtype.

        :param new_vec: New vector to contain. This is converted to a numpy
            array if it is not one already.
        :type new_vec: numpy.ndarray

        :returns: Self.
        :rtype: PostgresDescriptorElement

        """
        if not isinstance(new_vec, numpy.ndarray):
            new_vec = numpy.copy(new_vec)

        if new_vec.dtype != self.ARRAY_DTYPE:
            try:
                new_vec = new_vec.astype(self.ARRAY_DTYPE)
            except TypeError:
                raise ValueError("Could not convert input to a vector of type "
                                 "%s." % self.ARRAY_DTYPE)

        q_upsert = self.UPSERT_TMPL.strip().format(
            **{
                "table_name": self.table_name,
                "binary_col": self.binary_col,
                "type_col": self.type_col,
                "uuid_col": self.uuid_col,
            })
        q_upsert_values = {
            "binary_val": psycopg2.Binary(new_vec),
            "type_val": self.type(),
            "uuid_val": str(self.uuid()),
        }

        # query execution callback
        # noinspection PyProtectedMember
        def cb(cursor):
            # type: (psycopg2._psycopg.cursor) -> None
            cursor.execute(q_upsert, q_upsert_values)

        # No return but need to force iteration.
        psql_helper = self._get_psql_helper()
        list(psql_helper.single_execute(cb, yield_result_rows=False))
        return self
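
Relative to the previous example, this version routes all I/O through a
lazily built ``PsqlConnectionHelper`` and adds ``__getstate__`` /
``__setstate__`` so instances survive pickling despite the helper's
unpicklable lock. A sketch of the round trip (the type string and uuid are
illustrative; no database access happens until a vector method is called):

    import pickle

    e = PostgresDescriptorElement('my_descriptor', 'uuid-0')
    blob = pickle.dumps(e)   # the helper (and its lock) is excluded from state
    e2 = pickle.loads(blob)
    assert e2._psql_helper is None  # rebuilt on demand by _get_psql_helper()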