class SqlTemplates(object):
    """
    Static PostgreSQL query templates shared by the enclosing class.

    ``{name:s}`` fields are substituted with ``str.format`` (table and
    column identifiers), while ``%(name)s`` fields are left untouched for
    psycopg2 to bind as values when the query is executed.
    """

    # Create the key/value table when it does not exist yet; the key column
    # is the primary key.
    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {key_col:s} BYTEA NOT NULL,
          {value_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({key_col:s})
        );
    """)

    # Select an arbitrary expression (``query``) over every row.
    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {query:s} FROM {table_name:s};
    """)

    # Select an expression over rows whose key matches a LIKE pattern.
    SELECT_LIKE_TMPL = norm_psql_cmd_string("""
        SELECT {query:s}
        FROM {table_name:s}
        WHERE {key_col:s} like %(key_like)s
    """)

    # Update-then-insert upsert built from a writable CTE, which avoids
    # ``ON CONFLICT`` (PostgreSQL 9.5+).
    # NOTE(review): not atomic under concurrent writers of the same key.
    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {value_col:s} = %(val)s
            WHERE {key_col:s} = %(key)s
            RETURNING *
        )
        INSERT INTO {table_name:s} ({key_col:s}, {value_col:s})
          SELECT %(key)s, %(val)s
            WHERE NOT EXISTS (SELECT * FROM upsert)
    """)

    # Remove every row from the table.
    DELETE_ALL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
    """)
class SqlTemplates(object):
    """
    Static PostgreSQL query templates shared by the enclosing class.

    ``{name:s}`` fields are substituted with ``str.format`` (table and
    column identifiers), while ``%(name)s`` fields are bound by psycopg2 as
    query values at execution time.
    """

    # Create the key/value table when absent; the key column is the primary
    # key.
    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {key_col:s} BYTEA NOT NULL,
          {value_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({key_col:s})
        );
    """)

    # Select an arbitrary expression (``query``) over every row.
    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {query:s} FROM {table_name:s};
    """)

    # Select an expression over rows whose key matches a LIKE pattern.
    SELECT_LIKE_TMPL = norm_psql_cmd_string("""
        SELECT {query:s}
        FROM {table_name:s}
        WHERE {key_col:s} LIKE %(key_like)s
    """)

    # Select an expression over rows whose key is in a tuple of keys
    # (psycopg2 renders a Python tuple as a parenthesised SQL list).
    SELECT_MANY_TMPL = norm_psql_cmd_string("""
        SELECT {query:s}
        FROM {table_name:s}
        WHERE {key_col:s} IN %(key_tuple)s
    """)

    # Single-statement upsert keyed on the primary key column.
    UPSERT_TMPL = norm_psql_cmd_string("""
        INSERT INTO {table_name:s} ({key_col:s}, {value_col:s})
            VALUES (%(key)s, %(val)s)
            ON CONFLICT ({key_col:s})
                DO UPDATE
                    SET {value_col:s} = EXCLUDED.{value_col:s}
    """)

    # Delete rows whose key matches a LIKE pattern.
    DELETE_LIKE_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
        WHERE {key_col:s} LIKE %(key_like)s
    """)

    # Remove every row from the table.
    DELETE_ALL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
    """)
class CommandTemplates(object):
    """
    Encapsulation of command templates.

    ``{name:s}`` fields are filled with ``str.format``; ``%(name)s`` fields
    are bound by psycopg2 when the command is executed.
    """

    # Create the storage table if it does not exist yet.
    #
    # Format params: table_name, id_col, sha1_col, mime_col, byte_col
    UPSERT_TABLE = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {id_col:s}   TEXT NOT NULL,
          {sha1_col:s} TEXT NOT NULL,
          {mime_col:s} TEXT NOT NULL,
          {byte_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({id_col:s})
        );
    """)

    # Fetch column ``col`` of the row with the given ID.
    #
    # Format params: col, table_name, id_col
    # Value params:  id_val
    SELECT = norm_psql_cmd_string("""
        SELECT {col:s}
          FROM {table_name:s}
          WHERE {id_col:s} = %(id_val)s
        ;
    """)

    # Insert a row, or overwrite its SHA1, content-type and bytes when the ID
    # already exists.
    #
    # Format params: table_name, id_col, sha1_col, mime_col, byte_col
    # Value params:  id_val, sha1_val, mime_val, byte_val
    #
    # SQL format from:
    #   https://hashrocket.com/blog/posts/upsert-records-with-postgresql-9-5
    UPSERT_DATA = norm_psql_cmd_string("""
        INSERT INTO {table_name:s}
          ({id_col:s}, {sha1_col:s}, {mime_col:s}, {byte_col:s})
          VALUES (
            %(id_val)s, %(sha1_val)s, %(mime_val)s, %(byte_val)s
          )
          ON CONFLICT ({id_col:s})
            DO UPDATE
              SET ({sha1_col:s}, {mime_col:s}, {byte_col:s})
                = (EXCLUDED.{sha1_col:s}, EXCLUDED.{mime_col:s}, EXCLUDED.{byte_col:s})
        ;
    """)

    # Like ``UPSERT_DATA``, except that on conflict only the SHA1 and bytes
    # are replaced.  The stored mimetype is left untouched, so the byte data
    # can be updated atomically.
    UPSERT_DATA_NO_MIME = norm_psql_cmd_string("""
        INSERT INTO {table_name:s}
          ({id_col:s}, {sha1_col:s}, {mime_col:s}, {byte_col:s})
          VALUES (
            %(id_val)s, %(sha1_val)s, %(mime_val)s, %(byte_val)s
          )
          ON CONFLICT ({id_col:s})
            DO UPDATE
              SET ({sha1_col:s}, {byte_col:s})
                = (EXCLUDED.{sha1_col:s}, EXCLUDED.{byte_col:s})
        ;
    """)
class PostgresDescriptorSet(DescriptorSet):
    """
    DescriptorSet implementation that stores DescriptorElement references in
    a PostgreSQL database.

    A ``PostgresDescriptorSet`` effectively controls the entire table. Thus a
    ``clear()`` call will remove everything from the table.

    PostgreSQL version support:
        - 9.4

    Table format:
        <uuid col>      TEXT NOT NULL
        <element col>   BYTEA NOT NULL

    <uuid_col> should be the primary key (we assume unique).

    Neither configured column label may be ``true``.  ``has_descriptor``
    substitutes that literal for the element column as a cheap return-value
    shortcut.
    """

    #
    # The following are SQL query templates.  The ``{}`` fields are filled in
    # with ``str.format`` (instance-specific table/column names) before the
    # query is executed.  The ``%()s`` fields are filled in by psycopg2 from
    # the dictionary given as the second argument to
    # ``cursor.execute(query, value_dict)``.
    #

    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {uuid_col:s} TEXT NOT NULL,
          {element_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({uuid_col:s})
        );
    """)

    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {col:s}
          FROM {table_name:s}
    """)

    # Pattern match on the UUID column.  ``%`` and ``_`` in the bound value
    # are LIKE wildcards.  Kept for compatibility; single-UUID lookups in
    # this class use the exact-match templates below.
    SELECT_LIKE_TMPL = norm_psql_cmd_string("""
        SELECT {element_col:s}
          FROM {table_name:s}
          WHERE {uuid_col:s} like %(uuid_like)s
    """)

    # Exact-match lookup on the primary key.  UUID strings containing ``%``
    # or ``_`` are matched literally.
    SELECT_EQ_TMPL = norm_psql_cmd_string("""
        SELECT {element_col:s}
          FROM {table_name:s}
          WHERE {uuid_col:s} = %(uuid_val)s
    """)

    # So we can ensure we get back elements in specified order
    # (``unnest ... with ordinality``; see reference [1]).
    SELECT_MANY_ORDERED_TMPL = norm_psql_cmd_string("""
        SELECT {table_name:s}.{element_col:s}
          FROM {table_name:s}
          JOIN (
            SELECT *
            FROM unnest(%(uuid_list)s) with ordinality
          ) AS __ordering__ ({uuid_col:s}, {uuid_col:s}_order)
            ON {table_name:s}.{uuid_col:s} = __ordering__.{uuid_col:s}
          ORDER BY __ordering__.{uuid_col:s}_order
    """)

    # Update-then-insert upsert via a writable CTE (no ``ON CONFLICT``, which
    # requires PostgreSQL 9.5+).
    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {element_col:s} = %(element_val)s
            WHERE {uuid_col:s} = %(uuid_val)s
            RETURNING *
          )
        INSERT INTO {table_name:s} ({uuid_col:s}, {element_col:s})
          SELECT %(uuid_val)s, %(element_val)s
          WHERE NOT EXISTS (SELECT * FROM upsert)
    """)

    # Pattern delete; ``clear()`` binds ``'%'`` to remove every row.
    DELETE_LIKE_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
          WHERE {uuid_col:s} like %(uuid_like)s
    """)

    # Exact-match delete of a single UUID.
    DELETE_EQ_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
          WHERE {uuid_col:s} = %(uuid_val)s
    """)

    # Delete every UUID in a (non-empty) tuple and return the UUIDs that were
    # actually deleted.  The RETURNING column follows the configured
    # ``uuid_col`` instead of a hard-coded name.
    DELETE_MANY_TMPL = norm_psql_cmd_string("""
        DELETE FROM {table_name:s}
          WHERE {uuid_col:s} in %(uuid_tuple)s
          RETURNING {uuid_col:s}
    """)

    @classmethod
    def is_usable(cls):
        return psycopg2 is not None

    def __init__(self, table_name='descriptor_set', uuid_col='uid',
                 element_col='element',
                 db_name='postgres', db_host=None, db_port=None, db_user=None,
                 db_pass=None, multiquery_batch_size=1000, pickle_protocol=-1,
                 read_only=False, create_table=True):
        """
        Initialize set instance.

        :param table_name: Name of the table to use.
        :type table_name: str

        :param uuid_col: Name of the column containing the UUID signatures.
        :type uuid_col: str

        :param element_col: Name of the table column that will contain
            serialized elements.
        :type element_col: str

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested
            yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param multiquery_batch_size: For queries that handle sending or
            receiving many queries at a time, batch queries based on this
            size. If this is None, then no batching occurs.

            The advantage of batching is that it reduces the memory impact
            for queries dealing with a very large number of elements (don't
            have to store the full query for all elements in RAM). The
            disadvantage is that the overall operation is somewhat slower,
            because the query is split into multiple transactions.
        :type multiquery_batch_size: int | None

        :param pickle_protocol: Pickling protocol to use. We will use -1 by
            default (latest version, probably binary).
        :type pickle_protocol: int

        :param read_only: Only allow read actions against this set.
            Modification actions will throw a ReadOnlyError exception.
        :type read_only: bool

        :param create_table: If this instance should try to create the
            storing table before actions are performed against it when not
            set to be read-only. If the configured user does not have
            sufficient permissions to create the table and it does not
            currently exist, an exception will be raised.
        :type create_table: bool
        """
        super(PostgresDescriptorSet, self).__init__()

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.element_col = element_col

        self.multiquery_batch_size = multiquery_batch_size
        self.pickle_protocol = pickle_protocol
        self.read_only = bool(read_only)
        self.create_table = create_table

        # Checking parameters where necessary
        if self.multiquery_batch_size is not None:
            self.multiquery_batch_size = int(self.multiquery_batch_size)
            assert self.multiquery_batch_size > 0, \
                "A given batch size must be greater than 0 in size " \
                "(given: %d)." % self.multiquery_batch_size
        assert -1 <= self.pickle_protocol <= 2, \
            ("Given pickle protocol is not in the known valid range. Given: %s"
             % self.pickle_protocol)

        self.psql_helper = PsqlConnectionHelper(db_name, db_host, db_port,
                                                db_user, db_pass,
                                                self.multiquery_batch_size,
                                                PSQL_TABLE_CREATE_RLOCK)
        if not self.read_only and self.create_table:
            self.psql_helper.set_table_upsert_sql(
                self.UPSERT_TABLE_TMPL.format(
                    table_name=self.table_name,
                    uuid_col=self.uuid_col,
                    element_col=self.element_col,
                ))

    def get_config(self):
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "element_col": self.element_col,

            "db_name": self.psql_helper.db_name,
            "db_host": self.psql_helper.db_host,
            "db_port": self.psql_helper.db_port,
            "db_user": self.psql_helper.db_user,
            "db_pass": self.psql_helper.db_pass,

            "multiquery_batch_size": self.multiquery_batch_size,
            "pickle_protocol": self.pickle_protocol,
            "read_only": self.read_only,
            "create_table": self.create_table,
        }

    def count(self):
        """
        :return: Number of descriptor elements stored in this set.
        :rtype: int | long
        """
        # Just count UUID column to limit data read.
        q = self.SELECT_TMPL.format(
            col='count(%s)' % self.uuid_col,
            table_name=self.table_name,
        )

        def exec_hook(cur):
            cur.execute(q)

        # There's only going to be one row returned with one element in it.
        return list(self.psql_helper.single_execute(
            exec_hook, yield_result_rows=True
        ))[0][0]

    def clear(self):
        """
        Clear this descriptor set's entries.

        :raises ReadOnlyError: This set is read-only.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot clear a read-only set.")

        q = self.DELETE_LIKE_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )

        def exec_hook(cur):
            # '%' matches every UUID, so every row is removed.
            cur.execute(q, {'uuid_like': '%'})

        list(self.psql_helper.single_execute(exec_hook))

    def has_descriptor(self, uuid):
        """
        Check if a DescriptorElement with the given UUID exists in this set.

        :param uuid: UUID to query for
        :type uuid: collections.Hashable

        :return: True if a DescriptorElement with the given UUID exists in
            this set, or False if not.
        :rtype: bool
        """
        q = self.SELECT_EQ_TMPL.format(
            # hacking return value to something simple
            element_col='true',
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )

        def exec_hook(cur):
            cur.execute(q, {'uuid_val': str(uuid)})

        # Should either yield one or zero rows
        return bool(list(self.psql_helper.single_execute(
            exec_hook, yield_result_rows=True
        )))

    def add_descriptor(self, descriptor):
        """
        Add a descriptor to this set.

        Adding the same descriptor multiple times should not add multiple
        copies of the descriptor in the set (based on UUID). Added descriptors
        overwrite set descriptors based on UUID.

        :param descriptor: Descriptor to set.
        :type descriptor: smqtk.representation.DescriptorElement

        :raises ReadOnlyError: This set is read-only.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot add to a read-only set.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )
        v = {
            'uuid_val': str(descriptor.uuid()),
            'element_val': psycopg2.Binary(
                pickle.dumps(descriptor, self.pickle_protocol)
            )
        }

        def exec_hook(cur):
            cur.execute(q, v)

        list(self.psql_helper.single_execute(exec_hook))

    def add_many_descriptors(self, descriptors):
        """
        Add multiple descriptors at one time.

        Adding the same descriptor multiple times should not add multiple
        copies of the descriptor in the set (based on UUID). Added descriptors
        overwrite set descriptors based on UUID.

        :param descriptors: Iterable of descriptor instances to add to this
            set.
        :type descriptors:
            collections.Iterable[smqtk.representation.DescriptorElement]

        :raises ReadOnlyError: This set is read-only.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot add to a read-only set.")

        q = self.UPSERT_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
            element_col=self.element_col,
        )

        # Lazily transform input descriptors into upsert value dictionaries
        # so that batching does not require materializing every element.
        def iter_elements():
            for d in descriptors:
                yield {
                    'uuid_val': str(d.uuid()),
                    'element_val': psycopg2.Binary(
                        pickle.dumps(d, self.pickle_protocol)
                    )
                }

        def exec_hook(cur, batch):
            cur.executemany(q, batch)

        self._log.debug("Adding many descriptors")
        list(self.psql_helper.batch_execute(iter_elements(), exec_hook,
                                            self.multiquery_batch_size))

    def get_descriptor(self, uuid):
        """
        Get the descriptor in this set that is associated with the given UUID.

        :param uuid: UUID of the DescriptorElement to get.
        :type uuid: collections.Hashable

        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this set.
        :raises RuntimeError: More than one row matched the UUID (only
            possible if the table lacks a primary key on the UUID column).

        :return: DescriptorElement associated with the queried UUID.
        :rtype: smqtk.representation.DescriptorElement
        """
        q = self.SELECT_EQ_TMPL.format(
            element_col=self.element_col,
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        v = {'uuid_val': str(uuid)}

        def eh(c):
            c.execute(q, v)
            if c.rowcount == 0:
                raise KeyError(uuid)
            elif c.rowcount != 1:
                raise RuntimeError("Found more than one entry for the given "
                                   "uuid '%s' (got: %d)"
                                   % (uuid, c.rowcount))

        r = list(self.psql_helper.single_execute(eh, yield_result_rows=True))
        # NOTE(review): unpickling database content executes arbitrary code
        # if the table can be written by an untrusted party.
        return pickle.loads(bytes(r[0][0]))

    def get_many_descriptors(self, uuids):
        """
        Get an iterator over descriptors associated to given descriptor UUIDs.

        :param uuids: Iterable of descriptor UUIDs to query for.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this set.

        :return: Iterator of descriptors associated to given uuid values.
        :rtype: __generator[smqtk.representation.DescriptorElement]
        """
        q = self.SELECT_MANY_ORDERED_TMPL.format(
            table_name=self.table_name,
            element_col=self.element_col,
            uuid_col=self.uuid_col,
        )

        # Record UUIDs in the order they are consumed, so that a UUID missing
        # from the results can be reported in a KeyError.
        uuid_order = []

        def iterelems():
            for uid in uuids:
                uuid_order.append(uid)
                yield str(uid)

        def exec_hook(cur, batch):
            v = {'uuid_list': batch}
            cur.execute(q, v)

        self._log.debug("Getting many descriptors")
        # The SELECT_MANY_ORDERED_TMPL query ensures that elements returned
        # are in the UUID order given to this method. Thus, if the iterated
        # UUIDs and iterated return rows do not exactly line up, the query
        # join failed to match a query UUID to something in the database.
        #   - We also check that the number of rows we got back is the same
        #     as elements yielded, else there were trailing UUIDs that did
        #     not match anything in the database.
        g = self.psql_helper.batch_execute(iterelems(), exec_hook,
                                           self.multiquery_batch_size,
                                           yield_result_rows=True)
        i = 0
        for r, expected_uuid in zip(g, uuid_order):
            d = pickle.loads(bytes(r[0]))
            if d.uuid() != expected_uuid:
                raise KeyError(expected_uuid)
            yield d
            i += 1

        if len(uuid_order) != i:
            # just report the first one that's bad
            raise KeyError(uuid_order[i])

    def remove_descriptor(self, uuid):
        """
        Remove a descriptor from this set by the given UUID.

        :param uuid: UUID of the DescriptorElement to remove.
        :type uuid: collections.Hashable

        :raises KeyError: The given UUID doesn't associate to a
            DescriptorElement in this set.
        :raises ReadOnlyError: This set is read-only.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only set.")

        # Exact match: a UUID string containing LIKE wildcards (e.g. '%')
        # must not delete other rows.
        q = self.DELETE_EQ_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        v = {'uuid_val': str(uuid)}

        def execute(c):
            c.execute(q, v)
            # Nothing deleted if rowcount == 0
            # (otherwise 1 when deleted a thing)
            if c.rowcount == 0:
                raise KeyError(uuid)

        list(self.psql_helper.single_execute(execute))

    def remove_many_descriptors(self, uuids):
        """
        Remove descriptors associated to given descriptor UUIDs from this set.

        An empty ``uuids`` iterable is a no-op.

        :param uuids: Iterable of descriptor UUIDs to remove.
        :type uuids: collections.Iterable[collections.Hashable]

        :raises KeyError: A given UUID doesn't associate with a
            DescriptorElement in this set.
        :raises ReadOnlyError: This set is read-only.
        """
        if self.read_only:
            raise ReadOnlyError("Cannot remove from a read-only set.")

        str_uuid_set = set(str(uid) for uid in uuids)
        if not str_uuid_set:
            # ``IN ()`` is a SQL syntax error, and there is nothing to remove.
            return

        q = self.DELETE_MANY_TMPL.format(
            table_name=self.table_name,
            uuid_col=self.uuid_col,
        )
        v = {'uuid_tuple': tuple(str_uuid_set)}

        def execute(c):
            c.execute(q, v)

            # Check query UUIDs against rows that would actually be deleted.
            deleted_uuid_set = set(r[0] for r in c.fetchall())
            for uid in str_uuid_set:
                if uid not in deleted_uuid_set:
                    raise KeyError(uid)

        list(self.psql_helper.single_execute(execute))

    def iterkeys(self):
        """
        Return an iterator over set descriptor keys, which are their UUIDs.

        :rtype: collections.Iterator[collections.Hashable]
        """
        # Getting UUID through the element because the UUID might not be a
        # string type, and the true type is encoded with the DescriptorElement
        # instance.
        for d in self.iterdescriptors():
            yield d.uuid()

    def iterdescriptors(self):
        """
        Return an iterator over set descriptor element instances.

        :rtype: collections.Iterator[smqtk.representation.DescriptorElement]
        """
        def execute(c):
            c.execute(self.SELECT_TMPL.format(col=self.element_col,
                                              table_name=self.table_name))

        #: :type: __generator
        execution_results = self.psql_helper.single_execute(
            execute, yield_result_rows=True, named=True
        )
        for r in execution_results:
            d = pickle.loads(bytes(r[0]))
            yield d

    def iteritems(self):
        """
        Return an iterator over set descriptor key and instance pairs.

        :rtype: collections.Iterator[(collections.Hashable,
                                      smqtk.representation.DescriptorElement)]
        """
        for d in self.iterdescriptors():
            yield d.uuid(), d
class PostgresDescriptorElement(DescriptorElement):
    """
    Descriptor element whose vector is stored in a Postgres database.

    We assume we will work with a Postgres version of at least 9.4 (due to
    versions tested).

    Efficient connection pooling may be achieved via external utilities like
    PgBouncer.
    """

    # Vectors are stored as the raw bytes of an array of this dtype and read
    # back with the same dtype.
    ARRAY_DTYPE = numpy.float64

    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {type_col:s} TEXT NOT NULL,
          {uuid_col:s} TEXT NOT NULL,
          {binary_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({type_col:s}, {uuid_col:s})
        );
    """)

    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {binary_col:s}
          FROM {table_name:s}
          WHERE {type_col:s} = %(type_val)s
            AND {uuid_col:s} = %(uuid_val)s
        ;
    """)

    # Update-then-insert upsert via a writable CTE (no ``ON CONFLICT``, which
    # requires PostgreSQL 9.5+).
    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {binary_col:s} = %(binary_val)s
            WHERE {type_col:s} = %(type_val)s
              AND {uuid_col:s} = %(uuid_val)s
            RETURNING *
        )
        INSERT INTO {table_name:s} ({type_col:s}, {uuid_col:s}, {binary_col:s})
          SELECT %(type_val)s, %(uuid_val)s, %(binary_val)s
            WHERE NOT EXISTS (SELECT * FROM upsert);
    """)

    @classmethod
    def is_usable(cls):
        if psycopg2 is None:
            cls.get_logger().warning("Not usable. Requires psycopg2 module")
            return False
        return True

    def __init__(self, type_str, uuid,
                 table_name='descriptors',
                 uuid_col='uid', type_col='type_str', binary_col='vector',
                 db_name='postgres', db_host=None, db_port=None, db_user=None,
                 db_pass=None, create_table=True):
        """
        Initialize new PostgresDescriptorElement attached to some database
        credentials.

        We require that storage tables treat uuid AND type string columns as
        primary keys. The type and uuid columns should be of the 'text' type.
        The binary column should be of the 'bytea' type.

        Default argument values assume a local PostgreSQL database with a table
        created via the
        ``etc/smqtk/postgres/descriptor_element/example_table_init.sql``
        file (relative to the SMQTK source tree or install root).

        NOTES:
            - Not all uuid types used here are necessarily of the ``uuid.UUID``
              type, thus the recommendation to use a ``text`` type for the
              column. For certain specific use cases they may be proper
              ``uuid.UUID`` instances or strings, but this cannot be generally
              assumed.

        :param type_str: Type of descriptor. This is usually the name of the
            content descriptor that generated this vector.
        :type type_str: str

        :param uuid: Unique ID reference of the descriptor.
        :type uuid: collections.Hashable

        :param table_name: String label of the database table to use.
        :type table_name: str

        :param uuid_col: The column label for descriptor UUID storage
        :type uuid_col: str

        :param type_col: The column label for descriptor type string storage.
        :type type_col: str

        :param binary_col: The column label for descriptor vector binary
            storage.
        :type binary_col: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested
            yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it. If the configured
            user does not have sufficient permissions to create the table and
            it does not currently exist, an exception will be raised.
        :type create_table: bool
        """
        super(PostgresDescriptorElement, self).__init__(type_str, uuid)

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.type_col = type_col
        self.binary_col = binary_col
        self.create_table = create_table

        self.db_name = db_name
        self.db_host = db_host
        self.db_port = db_port
        self.db_user = db_user
        self.db_pass = db_pass

    def get_config(self):
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "type_col": self.type_col,
            "binary_col": self.binary_col,
            "create_table": self.create_table,

            "db_name": self.db_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "db_user": self.db_user,
            "db_pass": self.db_pass,
        }

    def _get_psql_connection(self):
        """
        :return: A new connection to the configured database
        :rtype: psycopg2._psycopg.connection
        """
        return psycopg2.connect(
            database=self.db_name,
            user=self.db_user,
            password=self.db_pass,
            host=self.db_host,
            port=self.db_port,
        )

    def _format_query(self, template, binary_col=None):
        """
        Fill a query template with this element's table and column names.

        :param template: One of this class's ``*_TMPL`` strings.
        :type template: str
        :param binary_col: Expression to use in place of the binary column
            (e.g. ``'true'`` for a cheap existence check). When None, the
            configured ``binary_col`` is used.
        :type binary_col: str | None

        :return: Query string with only ``%()s`` value fields remaining.
        :rtype: str
        """
        return template.format(
            table_name=self.table_name,
            type_col=self.type_col,
            uuid_col=self.uuid_col,
            binary_col=self.binary_col if binary_col is None else binary_col,
        )

    def _key_values(self):
        """
        :return: psycopg2 value parameters identifying this element's row.
        :rtype: dict[str, str]
        """
        return {
            "type_val": self.type(),
            "uuid_val": str(self.uuid()),
        }

    def _ensure_table(self, cursor):
        """
        Execute on psql connector cursor the table create-if-not-exists query,
        when ``create_table`` is enabled, and commit it.

        :param cursor: Connection active cursor.
        """
        if self.create_table:
            q_table_upsert = self._format_query(self.UPSERT_TABLE_TMPL)
            with PSQL_TABLE_CREATE_RLOCK:
                cursor.execute(q_table_upsert)
                cursor.connection.commit()

    def _run_single_query(self, query, values, fetch_one=False):
        """
        Execute one statement on a fresh connection.

        The table is ensured first.  The transaction is committed on success
        and rolled back on any error.  The cursor and connection are always
        closed.

        :param query: Fully formatted query string.
        :type query: str
        :param values: psycopg2 value parameters for ``query``.
        :type values: dict
        :param fetch_one: Whether to fetch and return the first result row.
        :type fetch_one: bool

        :return: The first result row (or None if there were no rows) when
            ``fetch_one`` is True, else None.
        :rtype: tuple | None
        """
        conn = self._get_psql_connection()
        cur = conn.cursor()
        try:
            self._ensure_table(cur)
            cur.execute(query, values)
            row = cur.fetchone() if fetch_one else None
            # For server cleaning (e.g. pgbouncer)
            conn.commit()
            return row
        except BaseException:
            # Roll back on any interruption (including KeyboardInterrupt),
            # then let the original exception propagate.
            conn.rollback()
            raise
        finally:
            cur.close()
            conn.close()

    def has_vector(self):
        """
        Check if the target database has a vector for our keys.

        :return: Whether or not this container current has a descriptor vector
            stored.
        :rtype: bool
        """
        # Select the static value 'true' in place of the binary column. This
        # answers the existence question without transferring the vector
        # bytes.
        q_select = self._format_query(self.SELECT_TMPL, binary_col='true')
        row = self._run_single_query(q_select, self._key_values(),
                                     fetch_one=True)
        return bool(row)

    def vector(self):
        """
        Return this element's vector, or None if we don't have one.

        :return: Get the stored descriptor vector as a numpy array. This
            returns None if there is no vector stored in this container. The
            array is built with ``numpy.frombuffer`` over the fetched bytes,
            so it may be read-only.
        :rtype: numpy.ndarray or None
        """
        q_select = self._format_query(self.SELECT_TMPL)
        row = self._run_single_query(q_select, self._key_values(),
                                     fetch_one=True)
        if not row:
            return None
        return numpy.frombuffer(row[0], self.ARRAY_DTYPE)

    def set_vector(self, new_vec):
        """
        Set the contained vector.

        If this container already stores a descriptor vector, this will
        overwrite it. The vector is converted to ``ARRAY_DTYPE`` before
        storage if it is not already of that dtype.

        :param new_vec: New vector to contain. This must be a numpy array.
        :type new_vec: numpy.ndarray

        :raises ValueError: ``new_vec`` was not a numpy ndarray.
        """
        # ``numpy.ndarray`` is the public name of the same class as the
        # deprecated ``numpy.core.multiarray.ndarray``.
        if not isinstance(new_vec, numpy.ndarray):
            raise ValueError("Input array for setting was not a numpy.ndarray! "
                             "(given: %s)" % type(new_vec))

        if new_vec.dtype != self.ARRAY_DTYPE:
            new_vec = new_vec.astype(self.ARRAY_DTYPE)

        q_upsert = self._format_query(self.UPSERT_TMPL)
        q_upsert_values = self._key_values()
        # ``tobytes()`` gives C-ordered bytes even for non-contiguous views
        # (e.g. strided slices). For contiguous arrays the bytes are the same
        # as the array's buffer.
        q_upsert_values["binary_val"] = psycopg2.Binary(new_vec.tobytes())

        self._run_single_query(q_upsert, q_upsert_values)
class PostgresDescriptorElement(DescriptorElement):
    """
    Descriptor element whose vector is stored in a Postgres database.

    We assume we will work with a Postgres version of at least 9.4 (due to
    versions tested).

    Efficient connection pooling may be achieved via external utilities like
    PGBouncer.
    """

    # Element type of stored vectors. Vectors are written as the raw bytes of
    # an array of this dtype and read back with ``numpy.frombuffer``.
    ARRAY_DTYPE = numpy.float64

    # Create the storage table if it does not exist. The (type, uuid) pair is
    # the primary key.
    UPSERT_TABLE_TMPL = norm_psql_cmd_string("""
        CREATE TABLE IF NOT EXISTS {table_name:s} (
          {type_col:s} TEXT NOT NULL,
          {uuid_col:s} TEXT NOT NULL,
          {binary_col:s} BYTEA NOT NULL,
          PRIMARY KEY ({type_col:s}, {uuid_col:s})
        );
    """)

    # Select ``binary_col`` (or a literal such as ``true``) for one
    # (type, uuid) key.
    SELECT_TMPL = norm_psql_cmd_string("""
        SELECT {binary_col:s}
          FROM {table_name:s}
          WHERE {type_col:s} = %(type_val)s
            AND {uuid_col:s} = %(uuid_val)s
        ;
    """)

    # Update-then-insert upsert written as a writable CTE. This form works on
    # PostgreSQL 9.4.
    # NOTE(review): this pattern is not atomic with respect to concurrent
    # writers. Two sessions upserting the same *new* (type, uuid) key can both
    # find no row to UPDATE, and one INSERT then fails on the primary key.
    # ``INSERT ... ON CONFLICT`` avoids this but requires PostgreSQL 9.5+,
    # which is newer than the 9.4 baseline stated above.
    UPSERT_TMPL = norm_psql_cmd_string("""
        WITH upsert AS (
          UPDATE {table_name:s}
            SET {binary_col:s} = %(binary_val)s
            WHERE {type_col:s} = %(type_val)s
              AND {uuid_col:s} = %(uuid_val)s
            RETURNING *
        )
        INSERT INTO {table_name:s} ({type_col:s}, {uuid_col:s}, {binary_col:s})
          SELECT %(type_val)s, %(uuid_val)s, %(binary_val)s
            WHERE NOT EXISTS (SELECT * FROM upsert);
    """)

    @classmethod
    def is_usable(cls):
        """
        :return: False (with a logged warning) when the psycopg2 module is
            unavailable, otherwise True.
        :rtype: bool
        """
        if psycopg2 is None:
            cls.get_logger().warning("Not usable. Requires psycopg2 module")
            return False
        return True

    def __init__(self, type_str, uuid,
                 table_name='descriptors',
                 uuid_col='uid', type_col='type_str', binary_col='vector',
                 db_name='postgres', db_host=None, db_port=None, db_user=None,
                 db_pass=None, create_table=True):
        """
        Initialize new PostgresDescriptorElement attached to some database
        credentials.

        We require that storage tables treat uuid AND type string columns as
        primary keys. The type and uuid columns should be of the 'text' type.
        The binary column should be of the 'bytea' type.

        Default argument values assume a local PostgreSQL database with a table
        created via the
        ``etc/smqtk/postgres/descriptor_element/example_table_init.sql``
        file (relative to the SMQTK source tree or install root).

        NOTES:
            - Not all uuid types used here are necessarily of the ``uuid.UUID``
              type, thus the recommendation to use a ``text`` type for the
              column. For certain specific use cases they may be proper
              ``uuid.UUID`` instances or strings, but this cannot be generally
              assumed.

        :param type_str: Type of descriptor. This is usually the name of the
            content descriptor that generated this vector.
        :type type_str: str

        :param uuid: Unique ID reference of the descriptor.
        :type uuid: collections.Hashable

        :param table_name: String label of the database table to use.
        :type table_name: str

        :param uuid_col: The column label for descriptor UUID storage
        :type uuid_col: str

        :param type_col: The column label for descriptor type string storage.
        :type type_col: str

        :param binary_col: The column label for descriptor vector binary
            storage.
        :type binary_col: str

        :param db_host: Host address of the Postgres server. If None, we
            assume the server is on the local machine and use the UNIX socket.
            This might be a required field on Windows machines (not tested
            yet).
        :type db_host: str | None

        :param db_port: Port the Postgres server is exposed on. If None, we
            assume the default port (5432).
        :type db_port: int | None

        :param db_name: The name of the database to connect to.
        :type db_name: str

        :param db_user: Postgres user to connect as. If None, postgres
            defaults to using the current accessing user account name on the
            operating system.
        :type db_user: str | None

        :param db_pass: Password for the user we're connecting as. This may be
            None if no password is to be used.
        :type db_pass: str | None

        :param create_table: If this instance should try to create the storing
            table before actions are performed against it. If the configured
            user does not have sufficient permissions to create the table and
            it does not currently exist, an exception will be raised.
        :type create_table: bool
        """
        super(PostgresDescriptorElement, self).__init__(type_str, uuid)

        self.table_name = table_name
        self.uuid_col = uuid_col
        self.type_col = type_col
        self.binary_col = binary_col
        self.create_table = create_table

        self.db_name = db_name
        self.db_host = db_host
        self.db_port = db_port
        self.db_user = db_user
        self.db_pass = db_pass

        # Connection helper is created lazily by ``_get_psql_helper``.
        self._psql_helper = None

    def __getstate__(self):
        """
        Construct serialization state.

        Due to the psql_helper containing a lock, it cannot be serialized.
        This is OK due to our creation of the helper on demand. The cost
        incurred by discarding the instance upon serialization is that once
        deserialized elsewhere the helper instance will have to be created.
        Since this creation post-deserialization only happens once, this is
        acceptable.
        """
        state = super(PostgresDescriptorElement, self).__getstate__()
        state.update({
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "type_col": self.type_col,
            "binary_col": self.binary_col,
            "create_table": self.create_table,
            "db_name": self.db_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "db_user": self.db_user,
            "db_pass": self.db_pass,
        })
        return state

    def __setstate__(self, state):
        """
        Restore from ``__getstate__`` output. The connection helper is reset
        so that it is rebuilt on first use.
        """
        # Base DescriptorElement parts
        super(PostgresDescriptorElement, self).__setstate__(state)
        # Our parts
        self.table_name = state['table_name']
        self.uuid_col = state['uuid_col']
        self.type_col = state['type_col']
        self.binary_col = state['binary_col']
        self.create_table = state['create_table']
        self.db_name = state['db_name']
        self.db_host = state['db_host']
        self.db_port = state['db_port']
        self.db_user = state['db_user']
        self.db_pass = state['db_pass']
        self._psql_helper = None

    def _get_psql_helper(self):
        """
        Internal method to create on demand the PSQL connection helper class.

        :return: PsqlConnectionHelper utility.
        :rtype: PsqlConnectionHelper
        """
        # ``getattr`` with a default provides backwards compatibility with
        # elements serialized before this helper attribute existed: such
        # instances may lack ``_psql_helper`` entirely, and a plain attribute
        # access would raise AttributeError.
        if getattr(self, '_psql_helper', None) is None:
            # Only using a transport iteration size of 1 since this element is
            # only meant to refer to a single entry in the associated table.
            self._psql_helper = PsqlConnectionHelper(
                self.db_name, self.db_host, self.db_port, self.db_user,
                self.db_pass, itersize=1,
                table_upsert_lock=PSQL_TABLE_CREATE_RLOCK)
            # Register table upsert command
            if self.create_table:
                self._psql_helper.set_table_upsert_sql(
                    self.UPSERT_TABLE_TMPL.format(
                        table_name=self.table_name,
                        type_col=self.type_col,
                        uuid_col=self.uuid_col,
                        binary_col=self.binary_col,
                    ))
        return self._psql_helper

    def get_config(self):
        """
        :return: Table/column labels and connection parameters of this
            element (type string and uuid are not included).
        :rtype: dict
        """
        return {
            "table_name": self.table_name,
            "uuid_col": self.uuid_col,
            "type_col": self.type_col,
            "binary_col": self.binary_col,
            "create_table": self.create_table,
            "db_name": self.db_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "db_user": self.db_user,
            "db_pass": self.db_pass,
        }

    def has_vector(self):
        """
        Check if the target database has a vector for our keys.

        Queries the configured table for a row matching this element's type
        string and (stringified) UUID.

        :return: Whether or not this container current has a descriptor vector
            stored.
        :rtype: bool
        """
        # Very similar to vector query, but replacing vector binary return with
        # a true/null return. Save a little bit of time compared to testing
        # vector return.
        # OLD: return self.vector() is not None

        # Using static value 'true' for binary "column" to reduce data return
        # volume.
        q_select = self.SELECT_TMPL.format(**{
            'binary_col': 'true',
            'table_name': self.table_name,
            'type_col': self.type_col,
            'uuid_col': self.uuid_col,
        })
        q_select_values = {
            "type_val": self.type(),
            "uuid_val": str(self.uuid())
        }

        def cb(cursor):
            cursor.execute(q_select, q_select_values)

        # Should either yield one or zero rows.
        psql_helper = self._get_psql_helper()
        return bool(
            list(psql_helper.single_execute(cb, yield_result_rows=True)))

    def vector(self):
        """
        Return this element's vector, or None if we don't have one.

        The array is built with ``numpy.frombuffer`` over the fetched bytes, so
        it is likely read-only; copy it before modifying in place.

        :return: Get the stored descriptor vector as a numpy array. This
            returns None if there is no vector stored in this container.
        :rtype: numpy.ndarray or None
        """
        q_select = self.SELECT_TMPL.format(**{
            "binary_col": self.binary_col,
            "table_name": self.table_name,
            "type_col": self.type_col,
            "uuid_col": self.uuid_col,
        })
        q_select_values = {
            "type_val": self.type(),
            "uuid_val": str(self.uuid())
        }

        # query execution callback
        # noinspection PyProtectedMember
        def cb(cursor):
            # type: (psycopg2._psycopg.cursor) -> None
            cursor.execute(q_select, q_select_values)

        # This should only fetch a single row. Cannot yield more than one due
        # use of primary keys.
        psql_helper = self._get_psql_helper()
        r = list(psql_helper.single_execute(cb, yield_result_rows=True))
        if not r:
            return None
        else:
            b = r[0][0]
            v = numpy.frombuffer(b, self.ARRAY_DTYPE)
            return v

    def set_vector(self, new_vec):
        """
        Set the contained vector.

        If this container already stores a descriptor vector, this will
        overwrite it.

        Input that is not a numpy array is copied into one, and the array is
        converted to ``ARRAY_DTYPE`` if needed. The raw bytes of the resulting
        C-contiguous array are upserted under this element's (type, uuid) key.

        :raises ValueError: ``new_vec`` could not be converted to an array of
            ``ARRAY_DTYPE``.

        :param new_vec: New vector to contain. This should be a numpy array or
            something ``numpy.copy`` can turn into one.
        :type new_vec: numpy.ndarray

        :returns: Self.
        :rtype: PostgresDescriptorElement
        """
        if not isinstance(new_vec, numpy.ndarray):
            new_vec = numpy.copy(new_vec)
        if new_vec.dtype != self.ARRAY_DTYPE:
            try:
                new_vec = new_vec.astype(self.ARRAY_DTYPE)
            except TypeError:
                raise ValueError("Could not convert input to a vector of type "
                                 "%s." % self.ARRAY_DTYPE)
        # psycopg2.Binary reads the array through the buffer protocol, which
        # requires C-contiguous memory. Strided views (e.g. ``v[::2]``) or
        # Fortran-ordered arrays would otherwise fail at execute time. This is
        # a no-op for arrays that are already C-contiguous.
        new_vec = numpy.ascontiguousarray(new_vec)

        q_upsert = self.UPSERT_TMPL.strip().format(**{
            "table_name": self.table_name,
            "binary_col": self.binary_col,
            "type_col": self.type_col,
            "uuid_col": self.uuid_col,
        })
        q_upsert_values = {
            "binary_val": psycopg2.Binary(new_vec),
            "type_val": self.type(),
            "uuid_val": str(self.uuid()),
        }

        # query execution callback
        # noinspection PyProtectedMember
        def cb(cursor):
            # type: (psycopg2._psycopg.cursor) -> None
            cursor.execute(q_upsert, q_upsert_values)

        # No return but need to force iteration.
        psql_helper = self._get_psql_helper()
        list(psql_helper.single_execute(cb, yield_result_rows=False))
        return self