def test_context_manager(self):
    """
    test using the context manager to access the pool

    Starts a ``ContextWriteGreenlet`` and a ``ContextRollbackGreenlet`` that
    share one pool, waits for both, then checks that each finished
    successfully and that the writer's value is ``[(test_number, )]``.
    """
    min_connections = 1
    max_connections = 5
    test_number = 42

    connection_pool = ThreadedConnectionPool(min_connections,
                                             max_connections,
                                             **_database_credentials)
    try:
        test_greenlet = ContextWriteGreenlet(connection_pool, test_number, 3.0)
        rollback_greenlet = ContextRollbackGreenlet(connection_pool, 3.0)

        test_greenlet.start()
        rollback_greenlet.start()

        # Join both greenlets before asserting so that a failed assertion
        # never leaves one of them running against a pool being closed.
        test_greenlet.join()
        rollback_greenlet.join()

        self.assertTrue(test_greenlet.successful())
        self.assertTrue(rollback_greenlet.successful())

        result = test_greenlet.value
        self.assertEqual(result, [(test_number, )])
    finally:
        # Release the pooled connections even when an assertion fails.
        connection_pool.closeall()
class PgConnectionPool:
    """Thin wrapper around ``ThreadedConnectionPool``.

    After every ``acquire`` the wrapped pool's ``minconn`` is reset to the
    number of connections currently in use, capped at ``keep_conns``.
    """

    def __init__(self, *args, min_conns=1, keep_conns=10, max_conns=10, **kwargs):
        """Create the wrapped pool; extra arguments are forwarded to it."""
        self._pool = ThreadedConnectionPool(min_conns, max_conns, *args, **kwargs)
        self._keep_conns = keep_conns

    def acquire(self):
        """Take a connection from the pool and update the pool's ``minconn``."""
        connection = self._pool.getconn()
        in_use = len(self._pool._used)
        keep = self._keep_conns
        # Same value as min(keep, in_use): never keep more than ``keep``.
        self._pool.minconn = in_use if in_use < keep else keep
        return connection

    def release(self, conn):
        """Hand ``conn`` back to the pool."""
        self._pool.putconn(conn)

    def close(self):
        """Close every pooled connection; no-op if the pool was never built."""
        try:
            wrapped = self._pool
        except AttributeError:
            return
        wrapped.closeall()

    # Closing on garbage collection reuses the guarded ``close`` above.
    __del__ = close
def test_decorator(self):
    """
    test using the decorator to access the pool

    Installs a fresh pool in the module-level ``_connection_pool``, runs a
    ``DecoratorWriteGreenlet`` and a ``DecoratorRollbackGreenlet``, waits for
    both, then checks that each finished successfully and that the writer's
    value is ``[(test_number, )]``.
    """
    global _connection_pool
    min_connections = 1
    max_connections = 5
    test_number = 42

    _connection_pool = ThreadedConnectionPool(min_connections,
                                              max_connections,
                                              **_database_credentials)
    try:
        test_greenlet = DecoratorWriteGreenlet(test_number, 3.0)
        rollback_greenlet = DecoratorRollbackGreenlet(3.0)

        test_greenlet.start()
        rollback_greenlet.start()

        # Join both greenlets before asserting so that a failed assertion
        # never leaves one of them running against a pool being closed.
        test_greenlet.join()
        rollback_greenlet.join()

        self.assertTrue(test_greenlet.successful())
        self.assertTrue(rollback_greenlet.successful())

        result = test_greenlet.value
        self.assertEqual(result, [(test_number, )])
    finally:
        # Release the pooled connections even when an assertion fails.
        _connection_pool.closeall()
class PostgresqlConnection:
    """Owns a ``ThreadedConnectionPool`` and hands out pooled cursors."""

    def __init__(self, **kwargs):
        """Build the pool.

        Required keys: ``minconn``, ``maxconn``, ``database``, ``username``,
        ``password``, ``host`` and ``port``. A missing key raises ``KeyError``.
        """
        self.pool = ThreadedConnectionPool(minconn=kwargs['minconn'],
                                           maxconn=kwargs['maxconn'],
                                           database=kwargs['database'],
                                           user=kwargs['username'],
                                           password=kwargs['password'],
                                           host=kwargs['host'],
                                           port=kwargs['port'])

    @contextmanager
    def cursor(self, auto_commit=True):
        """Yield a cursor on a pooled connection.

        When the ``with`` body finishes normally and ``auto_commit`` is true
        the transaction is committed. When the body raises, the transaction
        is rolled back and the exception propagates. In every case the cursor
        is closed and the connection goes back to the pool.
        """
        conn = self.pool.getconn()
        cursor = None
        try:
            cursor = conn.cursor()
            yield cursor
            # Commit only on the success path; committing in ``finally``
            # would persist a half-finished transaction after an error.
            if auto_commit:
                conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            if cursor is not None and not cursor.closed:
                cursor.close()
            self.pool.putconn(conn)

    def close(self):
        """Close every connection held by the pool."""
        self.pool.closeall()
class ConnectionPool:  # no test coverage
    """Pool wrapper that rebuilds its pool when the process id changes.

    Based on https://gist.github.com/jeorgen/4eea9b9211bafeb18ada
    """

    is_setup = False

    def setup(self):
        """Record the current pid, create the pool and mark the object ready."""
        self.last_seen_process_id = os.getpid()
        self._init()
        self.is_setup = True

    def _init(self):
        """(Re)create the underlying ``ThreadedConnectionPool``."""
        self._pool = ThreadedConnectionPool(1,
                                            10,
                                            database=config.db_name,
                                            **credentials)

    def _getconn(self) -> connection:
        """Return a pooled connection, rebuilding the pool if the pid changed."""
        current_pid = os.getpid()
        if not (current_pid == self.last_seen_process_id):
            # Running under a different pid than the one recorded: build a
            # fresh pool instead of reusing the existing one.
            self._init()
            log.debug(
                f"New id is {current_pid}, old id was {self.last_seen_process_id}"
            )
            self.last_seen_process_id = current_pid
        conn = self._pool.getconn()
        return conn

    def _putconn(self, conn: connection):
        """Return ``conn`` to the pool."""
        return self._pool.putconn(conn)

    def closeall(self):
        """Close every connection held by the pool."""
        self._pool.closeall()

    @contextmanager
    def get_connection(self) -> t.Generator[connection, None, None]:
        """Yield a pooled connection and give it back when the block ends."""
        # Acquire outside the ``try``: if acquisition fails there is nothing
        # to give back, and the real error must not be hidden by an
        # UnboundLocalError raised from the ``finally`` clause.
        conn = self._getconn()
        try:
            yield conn
        finally:
            self._putconn(conn)

    @contextmanager
    def get_cursor(self, commit=False) -> t.Generator[LoggingCursor, None, None]:
        """Yield a ``LoggingCursor``; commit afterwards when ``commit`` is true.

        The commit only happens when the ``with`` body finishes without
        raising. The cursor is always closed.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor(cursor_factory=LoggingCursor)
            try:
                yield cursor
                if commit:
                    conn.commit()
            finally:
                cursor.close()
class DBSaver:
    """Holds the connection pool shared by per-sensor savers."""

    def __init__(self):
        """Open a pool of 1 to 5 connections to the ``irdb`` database."""
        settings = {
            "minconn": 1,
            "maxconn": 5,
            "user": "******",
            "password": "******",
            "host": "localhost",
            "port": "5432",
            "database": "irdb",
        }
        self.pool = ThreadedConnectionPool(**settings)

    def get_saver_for_id(self, sensor_id):
        """Return a ``DBSingleSensorSaver`` for ``sensor_id`` bound to this saver."""
        saver = DBSingleSensorSaver(sensor_id, self)
        return saver

    def close(self):
        """Close every connection held by the pool."""
        pool = self.pool
        pool.closeall()
class QuastTestCase(unittest.TestCase):
    """Base test case that creates and populates a throw-away database.

    ``setUp`` creates ``quast_test``, runs ``sql/setup.sql`` and
    ``sql/populate.sql`` against it and opens a connection pool;
    ``tearDown`` closes everything and drops the database.
    """

    def setUp(self):
        self.createdb('quast_test')

        self.conn = psycopg2.connect("user=postgres dbname=quast_test")
        self.conn.autocommit = True

        # Run setup.sql
        self._run_sql_file('sql/setup.sql')

        # Run populate.sql
        self._run_sql_file('sql/populate.sql')

        self.pool = ThreadedConnectionPool(1, 10, "user=postgres dbname=quast_test")

    def _run_sql_file(self, relative_path):
        """Execute the SQL file at ``relative_path`` (relative to the parent
        of this file's directory) on ``self.conn``."""
        file_path = os.path.abspath(
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         '..',
                         relative_path))
        with open(file_path) as f:
            with self.conn.cursor() as curs:
                sql = f.read()
                curs.execute(sql)

    @staticmethod
    def createdb(name):
        """Create database ``name`` through the ``postgres`` maintenance db."""
        conn = psycopg2.connect("user=postgres dbname=postgres")
        try:
            # CREATE DATABASE cannot run inside a transaction block.
            conn.autocommit = True
            with conn.cursor() as curs:
                curs.execute("CREATE DATABASE {};".format(name))
        finally:
            # Close the admin connection instead of leaking it per test.
            conn.close()

    @staticmethod
    def dropdb(name):
        """Drop database ``name`` through the ``postgres`` maintenance db."""
        conn = psycopg2.connect("user=postgres dbname=postgres")
        try:
            # DROP DATABASE cannot run inside a transaction block.
            conn.autocommit = True
            with conn.cursor() as curs:
                curs.execute("DROP DATABASE {};".format(name))
        finally:
            conn.close()

    def tearDown(self):
        # Every connection to quast_test must be closed before it is dropped.
        self.pool.closeall()
        self.conn.close()
        self.dropdb('quast_test')
class DBPool(object):
    """Singleton holder of the PostgreSQL and Redis connection pools.

    ``DBPool()`` always returns the same instance. The pools are created the
    first time the instance is initialised and reused afterwards.
    """
    _instance_lock = threading.Lock()

    def __init__(self):
        # Python calls __init__ on every ``DBPool()`` even though __new__
        # returns the shared instance. Without this guard each call would
        # build new pools and orphan the previous ones.
        if self.__dict__.get("_initialized"):
            return
        with DBPool._instance_lock:
            if self.__dict__.get("_initialized"):
                return

            # PostgreSQL connection pool
            logger.debug(">>>>>>pg_pool start create")
            self.pg_pool = ThreadedConnectionPool(2,
                                                  5,
                                                  host=PG.host,
                                                  port=PG.port,
                                                  database=PG.name,
                                                  user=PG.user,
                                                  password=PG.pwd)
            logger.debug(">>>>>>pg_pool create success")

            # Redis connection pool
            logger.debug(">>>>>>redis_pool start create")
            self.redis_pool = redis.ConnectionPool(host=REDIS.host,
                                                   port=REDIS.port,
                                                   password=REDIS.pwd)
            logger.debug(">>>>>>redis_pool create success")

            # Set only after both pools exist so a failed attempt is retried.
            self._initialized = True

    def __new__(cls):
        # Check, take the lock, check again: only one instance is ever built.
        if not hasattr(DBPool, "_instance"):
            with DBPool._instance_lock:
                if not hasattr(DBPool, "_instance"):
                    DBPool._instance = object.__new__(cls)
        return DBPool._instance

    def __del__(self):
        # Either pool may be missing if __init__ never completed.
        if hasattr(self, "pg_pool") and self.pg_pool:
            self.pg_pool.closeall()
            self.pg_pool = None
        if hasattr(self, "redis_pool") and self.redis_pool:
            self.redis_pool.disconnect()
            self.redis_pool = None
class PgConnectionPool:
    """Wrapper over ``ThreadedConnectionPool`` that adjusts ``minconn``.

    Each ``acquire`` sets the wrapped pool's ``minconn`` to the smaller of
    ``keep_conns`` and the number of connections currently checked out.
    """

    def __init__(self, *args, min_conns=1, keep_conns=10, max_conns=10, **kwargs):
        """Build the wrapped pool, forwarding any extra arguments to it."""
        self._pool = ThreadedConnectionPool(min_conns, max_conns, *args, **kwargs)
        self._keep_conns = keep_conns

    def acquire(self):
        """Check a connection out of the pool."""
        wrapped = self._pool
        checked_out = wrapped.getconn()
        limits = (self._keep_conns, len(wrapped._used))
        wrapped.minconn = min(limits)
        return checked_out

    def release(self, conn):
        """Give ``conn`` back to the pool."""
        wrapped = self._pool
        wrapped.putconn(conn)

    def close(self):
        """Close all pooled connections if the pool exists."""
        if not hasattr(self, '_pool'):
            return
        self._pool.closeall()

    # Garbage collection closes the pool through the guarded ``close``.
    __del__ = close
class PostgreSQL:
    """
    PostgreSQL Wrapper

    Maintain a thread-safe connection pool and provide convenience
    methods for downstream, with dotted access to cursor results
    """
    _pool: AbstractConnectionPool

    def __init__(self, *, database: str, user: str, password: str, host: str, port: int = 5432) -> None:
        """
        Open the connection pool

        @param  database  Database name
        @param  user      Role to connect as
        @param  password  Password for that role
        @param  host      Server host
        @param  port      Server port
        """
        # Pass the parameters as keywords rather than interpolating them
        # into a DSN string: values containing spaces, quotes or
        # backslashes would otherwise produce a malformed or different DSN.
        self._pool = ThreadedConnectionPool(_POOL_MIN, _POOL_MAX,
                                            dbname=database,
                                            user=user,
                                            password=password,
                                            host=host,
                                            port=port,
                                            cursor_factory=NamedTupleCursor)

    def __del__(self) -> None:
        # The pool is absent when __init__ raised before assigning it.
        pool = getattr(self, "_pool", None)
        if pool is not None:
            pool.closeall()

    def transaction(self, autocommit: bool = False) -> Transaction:
        """
        Create a transaction bound to this wrapper's pool

        @param   autocommit  Forwarded to the Transaction
        @return  New Transaction
        """
        return Transaction(pool=self._pool, autocommit=autocommit)

    def execute_script(self, sql: T.Path) -> None:
        """
        Execute the given SQL script against the database

        @param  sql  Path to SQL script
        """
        with self.transaction(autocommit=True) as t:
            t.execute(sql.read_text())
class PgPool:
    """Helper around a psycopg2 ``ThreadedConnectionPool``.

    Most methods log failures with ``logger.exception`` and return a
    sentinel (``False`` or ``(None, None)``) instead of raising.
    """

    def __init__(self):
        logger.debug('initializing postgres threaded pool')
        self.host, self.port = None, None
        self.database, self.pool = None, None
        self.user, self.passwd = None, None
        logger.debug('Server Addr: {host}:{port}; Database: {db}; User: {user}; Password: {passwd}'.format(
            host=self.host, port=self.port, db=self.database, user=self.user, passwd=self.passwd
        ))

    def create_pool(self, conn_dict, limits):
        """
        Create a connection pool

        :param conn_dict: connection params dictionary with the keys
            ``Host``, ``Port``, ``Database``, ``User`` and ``Password``;
            ``Host`` and ``Port`` may be ``None`` to use the defaults
        :type conn_dict: dict
        :param limits: pool size limits with the keys ``Min`` and ``Max``
        :type limits: dict
        """
        if conn_dict["Host"] is None:
            self.host = 'localhost'
        else:
            self.host = conn_dict["Host"]
        if conn_dict["Port"] is None:
            self.port = '5432'
        else:
            self.port = conn_dict["Port"]
        self.database = conn_dict["Database"]
        self.user = conn_dict["User"]
        self.passwd = conn_dict["Password"]

        # The user/password placeholders must be real format fields;
        # otherwise the literal mask text is sent as the credentials.
        conn_params = "host='{host}' dbname='{db}' user='{user}' password='{passwd}' port='{port}'".format(
            host=self.host, db=self.database, user=self.user, passwd=self.passwd, port=self.port
        )

        try:
            logger.debug('creating pool')
            self.pool = ThreadedConnectionPool(int(limits["Min"]), int(limits["Max"]), conn_params)
        except Exception as e:
            # Exceptions have no ``.message`` attribute on Python 3;
            # log the exception object itself.
            logger.exception(e)

    def get_conn(self):
        """
        Get a connection from pool and return connection and cursor

        :return: ``(conn, cursor)`` with a ``DictCursor``, or
            ``(None, None)`` when either step fails
        """
        logger.debug('getting connection from pool')
        try:
            conn = self.pool.getconn()
            cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            return conn, cursor
        except Exception as e:
            logger.exception(e)
            return None, None

    @staticmethod
    def execute_query(cursor, query, params):
        """
        Execute a query on database

        :param cursor: cursor object
        :param query: database query
        :type query: str
        :param params: query parameters
        :type params: tuple
        :return: all rows when the first word of the query is ``select``,
            the result of ``cursor.execute`` otherwise, ``False`` on failure
        """
        logger.info('executing query')
        logger.debug('Cursor: {cursor}, Query: {query}'.format(
            cursor=cursor, query=query))

        try:
            if query.split()[0].lower() == 'select':
                cursor.execute(query, params)
                return cursor.fetchall()
            else:
                return cursor.execute(query, params)
        except Exception as e:
            logger.exception(e)
            return False

    # commit changes to db permanently
    @staticmethod
    def commit_changes(conn):
        """
        Commit changes to the database permanently

        :param conn: connection object
        :return: the result of ``conn.commit()``, or ``False`` on failure
        """
        logger.debug('commiting changes to database')
        try:
            return conn.commit()
        except Exception as e:
            logger.exception(e)
            return False

    def put_conn(self, conn):
        """
        Put connection back to the pool

        :param conn: connection object
        :return: the result of ``putconn``, or ``False`` on failure
        """
        logger.debug('putting connection {conn} back to pool'.format(conn=conn))
        try:
            return self.pool.putconn(conn)
        except Exception as e:
            logger.exception(e)
            return False

    def close_pool(self):
        """
        Closes connection pool

        :return: the result of ``closeall``, or ``False`` on failure
        """
        logger.debug('closing connections pool')
        try:
            return self.pool.closeall()
        except Exception as e:
            logger.exception(e)
            return False
class Manager:
    """A singleton class managing the whole execution engine.

    It is a registry for executors and monitor engines, manages the
    database connection pool and provides the service API to the gui
    and the rest of the world.
    """

    SINGLETON = None

    def __init__(self, dbconn):
        """Create the singleton; raises RuntimeError if one already exists."""
        if Manager.SINGLETON is not None:
            raise RuntimeError(
                "Attempt to create multiple MonitorManager instances")
        else:
            Manager.SINGLETON = self

        self.log = logging.getLogger("Manager")

        # set up database pool
        self.dbconn = dbconn
        self.pool = ThreadedConnectionPool(20, 100, dbconn)

        # prepare database to recover jobs
        self.getDao().ready_active_jobs()

        # create executors
        self.execs = {}  # the executor registry
        self.default_executor = None  # the default executor for jobs
        self._create_executors()

        # monitor registry
        self.monitors = {}  # the monitor registry
        self._create_monitors()

        self.log.info("Manager initialized and running")

    def _create_executors(self):
        """Load the executors through ExecutorDao and register them by name."""
        exdao = ExecutorDao(self.pool)
        for ex in exdao.get_executors():
            if ex.name in self.execs:
                # A duplicate is only reported; the later executor then
                # replaces the earlier one in the registry.
                self.log.critical("Duplicate executor name: %s", ex.name)
            self.execs[ex.name] = ex
            self.log.info("Registered executor %s", ex.name)
        import runner.config
        self.default_executor = self.execs[runner.config.default_executor()]

    def _create_monitors(self):
        """Create one MonitorEngine per entry yielded by ``monitor_init()``."""
        for name, nw in monitor_init():
            self.monitors[name] = MonitorEngine(name, nw)

        # Start the db monitoring task
        main = self.monitors['main']
        self.dbMonitor = main.add_task(DbPollTask(main))

    #---------------------
    # Public API
    #---------------------

    @staticmethod
    def initialize(dbconn):
        '''Initialize the manager with the given database URL.

        Returns the new Manager, or None if a manager already exists.
        '''
        if Manager.SINGLETON is not None:
            return
        return Manager(dbconn)

    @staticmethod
    def shutdown():
        '''Shut down the manager.'''
        self = Manager.SINGLETON

        allmon = self.monitors.values()
        self.log.info("Disabling monitors")
        self.monitors = {}

        for m in allmon:
            m.shutdown()

        # destroy the db connection pool
        self.log.info("Closing database session pool")
        self.pool.closeall()
        del self.pool

    @staticmethod
    def getDao():
        '''Return a DAO for this thread.'''
        return JobDao(Manager.SINGLETON.pool)

    @staticmethod
    def executor(name=None):
        '''Return the executor with the given name.

        If name is None, or not given, return the default executor.
        Raises KeyError if the provided name does not correspond to an
        executor.'''
        if name is None:
            return Manager.SINGLETON.default_executor
        else:
            return Manager.SINGLETON.execs[name]

    @staticmethod
    def executors():
        '''Return an iterable of all executors'''
        self = Manager.SINGLETON
        for ex in self.execs.values():
            yield ex

    @staticmethod
    def monitor_engines():
        '''Return an iterable of all monitor engines'''
        self = Manager.SINGLETON
        for ex in self.monitors.values():
            yield ex

    @staticmethod
    def create_job(exctor, url, simhome=None):
        '''
        Create a new job on the given executor and for
        the passed url and simhome.

        Returns the file location reported by the executor.
        '''
        fileloc = exctor.create_simulation(url, simhome)
        Manager.getDao().submit_new_job(exctor.name, fileloc)
        return fileloc

    @staticmethod
    def delete_job(job):
        '''
        Delete a job. The job can be either a SimJob, or
        a fileloc or a simid.

        The job must be PASSIVE; otherwise Forbidden is raised.
        '''
        if isinstance(job, str):
            # A string starting with '/' is treated as a fileloc,
            # any other string as a simid.
            if job.startswith('/'):
                job = Manager.get_job_by_fileloc(job)
            else:
                job = Manager.get_job_by_simid(job)
        if job.state != 'PASSIVE':
            raise Forbidden(
                message='Bad job status',
                details='Cannot delete a job which is still executing')
        xtor = Manager.executor(job.executor)
        xtor.delete_simulation(job)
        Manager.getDao().delete_job(job)

    @staticmethod
    def jobs():
        '''Return an iterable over all jobs'''
        dao = Manager.getDao()
        try:
            jobs = dao.get_jobs()
        finally:
            # Release the DAO even when the query raises.
            dao.release()
        return jobs

    @staticmethod
    def get_job_by_fileloc(fileloc):
        '''Get the job for the given fileloc'''
        return Manager.getDao().get_job_by_fileloc(fileloc)

    @staticmethod
    def get_job_by_simid(simid):
        '''Get the job for the given simid'''
        fileloc = Manager.get_simhome(simid)
        return Manager.get_job_by_fileloc(fileloc)

    @staticmethod
    def copy_simfiles(nsdfile, simfile, fileloc):
        '''Copy nsdfile and simfile into fileloc as nsd.json and sim.json.'''
        simdst = fileloc + "/sim" + ".json"
        nsddst = fileloc + "/nsd" + ".json"
        copyfile(nsdfile, nsddst)
        copyfile(simfile, simdst)
        print(fileloc)

    #
    # PR simid =  'SIMOUTPUT':<executor-name>:<homedir-basename>
    #

    @staticmethod
    def get_simid(xtor, simhome):
        '''Return the simulation PR id for the given simhome and executor

        xtor may be an Executor or an executor name. Raises NotFound if
        the name is not a registered executor.'''
        name = xtor.name if isinstance(xtor, Executor) else xtor
        if name not in Manager.SINGLETON.execs:
            raise NotFound(details='Executor cannot be determined')
        return SimJob.make_simid(name, simhome)

    @staticmethod
    def get_simhome(simid):
        '''Return the simhome for the given simulation PR id'''
        xtor, bname = SimJob.break_simid(simid)
        homedir = Manager.executor(xtor).homedir
        return os.path.join(homedir, bname)

    @staticmethod
    def create_user(user):
        '''Create a system user for the given object'''
        assert isinstance(user, User)
        dao = UserDao(Manager.SINGLETON.pool)
        dao.create_user(user)

    @staticmethod
    def get_user(username):
        '''Return a system user by this name, or None.'''
        assert isinstance(username, str)
        dao = UserDao(Manager.SINGLETON.pool)
        return dao.get_user(username)

    @staticmethod
    def get_users():
        '''Return an iterator over all system users.'''
        dao = UserDao(Manager.SINGLETON.pool)
        yield from dao.get_users()

    @staticmethod
    def update_user(username, **kwargs):
        '''Update the system user with the given username, setting
        the passed keyword arguments to the passed values.'''
        dao = UserDao(Manager.SINGLETON.pool)
        dao.update_user(username, **kwargs)

    @staticmethod
    def delete_user(username):
        '''Delete the system user with the given username.'''
        dao = UserDao(Manager.SINGLETON.pool)
        dao.delete_user(username)
class Database(DatabaseInterface):
    """PostgreSQL database backend.

    One instance is cached per database name in ``_databases``; each
    instance creates its connection pool the first time it is needed.
    """

    _databases = {}
    _connpool = None
    _list_cache = None
    _list_cache_timestamp = None
    _version_cache = {}
    flavor = Flavor(ilike=True)

    def __new__(cls, database_name='template1'):
        # Reuse the cached instance for this database name if there is one.
        if database_name in cls._databases:
            return cls._databases[database_name]
        return DatabaseInterface.__new__(cls, database_name=database_name)

    def __init__(self, database_name='template1'):
        super(Database, self).__init__(database_name=database_name)
        self._databases.setdefault(database_name, self)

    def connect(self):
        """Create the connection pool if it does not exist; return self."""
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.database_name)
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = uri.hostname and "host=%s" % uri.hostname or ''
        port = uri.port and "port=%s" % uri.port or ''
        name = "dbname=%s" % self.database_name
        user = uri.username and "user=%s" % uri.username or ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
            if uri.password else '')
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        dsn = '%s %s %s %s %s' % (host, port, name, user, password)
        self._connpool = ThreadedConnectionPool(minconn, maxconn, dsn)
        return self

    def cursor(self, autocommit=False, readonly=False):
        """Return a Cursor on a pooled connection.

        The connection is set to autocommit when ``autocommit`` is true and
        to REPEATABLE READ otherwise. ``readonly`` issues
        ``SET TRANSACTION READ ONLY`` on the new cursor.
        """
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        cursor = Cursor(self._connpool, conn, self)
        if readonly:
            cursor.execute('SET TRANSACTION READ ONLY')
        return cursor

    def close(self):
        """Close all pooled connections and forget the pool."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, cursor, database_name):
        """Create ``database_name`` and invalidate the database list cache."""
        cursor.execute('CREATE DATABASE "' + database_name + '" '
            'TEMPLATE template0 ENCODING \'unicode\'')
        cls._list_cache = None

    @classmethod
    def drop(cls, cursor, database_name):
        """Drop ``database_name`` and invalidate the database list cache."""
        cursor.execute('DROP DATABASE "' + database_name + '"')
        cls._list_cache = None

    def get_version(self, cursor):
        """Return the server version as a tuple of ints, cached per database."""
        if self.database_name not in self._version_cache:
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.database_name] = tuple(
                map(int, RE_VERSION.search(version).groups()))
        return self._version_cache[self.database_name]

    @staticmethod
    def dump(database_name):
        """Return a ``pg_dump`` custom-format dump of ``database_name``.

        Raises Exception when pg_dump exits with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load ``data`` into it with pg_restore.

        Returns True on success. Raises Exception when pg_restore fails or
        when the restored database does not pass ``cursor.test()``.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        cursor = database.cursor(autocommit=True)
        database.create(cursor, database_name)
        cursor.commit()
        cursor.close()
        database.close()

        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            # On Windows the dump is written to a temporary file that is
            # passed as an argument instead of being piped through stdin.
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_command_pipe(*args2, env=env)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.cursor()
        if not cursor.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    @staticmethod
    def list(cursor):
        """Return the names of the databases that pass ``cursor.test()``.

        The result is cached for the configured session timeout.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        uri = parse_uri(config.get('database', 'uri'))
        db_user = uri.username or os.environ.get('PGUSER')
        if not db_user and os.name == 'posix':
            db_user = pwd.getpwuid(os.getuid())[0]
        if db_user:
            # Only list databases owned by the connecting user.
            cursor.execute(
                "SELECT datname "
                "FROM pg_database "
                "WHERE datdba = ("
                "SELECT usesysid "
                "FROM pg_user "
                "WHERE usename=%s) "
                "AND datname not in "
                "('template0', 'template1', 'postgres') "
                "ORDER BY datname", (db_user, ))
        else:
            cursor.execute("SELECT datname "
                "FROM pg_database "
                "WHERE datname not in "
                "('template0', 'template1','postgres') "
                "ORDER BY datname")
        res = []
        for db_name, in cursor.fetchall():
            db_name = db_name.encode('utf-8')
            try:
                database = Database(db_name).connect()
            except Exception:
                continue
            cursor2 = database.cursor()
            if cursor2.test():
                res.append(db_name)
                cursor2.close(close=True)
            else:
                cursor2.close(close=True)
                database.close()
        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    @staticmethod
    def init(cursor):
        """Run ``init.sql`` and register the ir, res and webdav modules."""
        from trytond.modules import get_module_info
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for line in fp.read().split(';'):
                if (len(line) > 0) and (not line.isspace()):
                    cursor.execute(line)

        for module in ('ir', 'res', 'webdav'):
            state = 'uninstalled'
            if module in ('ir', 'res'):
                state = 'to install'
            info = get_module_info(module)
            cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
            module_id = cursor.fetchone()[0]
            cursor.execute(
                'INSERT INTO ir_module '
                '(id, create_uid, create_date, name, state) '
                'VALUES (%s, %s, now(), %s, %s)',
                (module_id, 0, module, state))
            for dependency in info.get('depends', []):
                cursor.execute(
                    'INSERT INTO ir_module_dependency '
                    '(create_uid, create_date, module, name) '
                    'VALUES (%s, now(), %s, %s)',
                    (0, module_id, dependency))
#!/usr/bin/python3
"""Run three slow queries on two worker threads sharing a connection pool."""
from psycopg2.pool import ThreadedConnectionPool
from multiprocessing.pool import ThreadPool

pool = ThreadedConnectionPool(0, 2, host="localhost", user="******", dbname="car_portal")

queries = ["SELECT 1 FROM pg_sleep(10)",
           "SELECT 2 FROM pg_sleep(10)",
           "SELECT 3 FROM pg_sleep(10)"]


def execute_query(query):
    """Run ``query`` on a pooled connection; return the first column of its first row."""
    # The query text doubles as the pool key for getconn/putconn.
    conn = pool.getconn(query)
    try:
        with conn.cursor() as cur:
            cur.execute(query)
            row = cur.fetchone()
            value = row[0]
    finally:
        # Give the connection back even when the query fails.
        pool.putconn(conn, query)
    return value


try:
    # The context manager shuts the worker threads down when the map is done.
    with ThreadPool(2) as thread_pool:
        results = thread_pool.map(execute_query, queries)
    print(results)
finally:
    pool.closeall()
class connection(object):
    """Pool-backed database handle.

    Attribute names not defined here are proxied to a fresh cursor: calling
    ``db.some_method(...)`` opens a cursor, calls ``some_method`` on it and
    returns the result.
    """

    def __init__(self,url=None,hstore=False,log=None,logf=None,min=1,max=5,
                 default_cursor=DictCursor):
        """Open the pool.

        The URL comes from ``url``, then the ``DATABASE_URL`` environment
        variable, then ``postgres://localhost/``. ``min`` and ``max`` are the
        pool size limits.
        """
        params = urlparse.urlparse(url or
                                   os.environ.get('DATABASE_URL') or
                                   'postgres://localhost/')
        self.pool = ThreadedConnectionPool(min,max,
                                           database=params.path[1:],
                                           user=params.username,
                                           password=params.password,
                                           host=params.hostname,
                                           port=params.port,
                    )
        self.hstore = hstore
        self.log = log
        self.logf = logf or (lambda cursor : cursor.query)
        self.default_cursor = default_cursor
        self.prepared_statement_id = 0

    def prepare(self,statement,params=None,name=None,call_type=None):
        """
            >>> db = connection()
            >>> p1 = db.prepare('SELECT name FROM doctest_t1 WHERE id = $1')
            >>> p2 = db.prepare('UPDATE doctest_t1 set name = $2 WHERE id = $1',('int','text'))
            >>> db.execute(p2,(1,'xxxxx'))
            1
            >>> db.query_one(p1,(1,))
            ['xxxxx']
            >>> db.execute(p2,(1,'aaaaa'))
            1
            >>> db.query_one(p1,(1,))
            ['aaaaa']
        """
        if not name:
            # Generate a sequential statement name when none is supplied.
            self.prepared_statement_id += 1
            name = '_pstmt_%03.3d' % self.prepared_statement_id
        if params:
            params = '(' + ','.join(params) + ')'
        else:
            params = ''
        with self.cursor() as c:
            c.execute('PREPARE %s %s AS %s' % (name,params,statement))
        if call_type is None:
            if statement.lower().startswith('select'):
                call_type = 'query'
            else:
                call_type = 'execute'
        return PreparedStatement(self,name,call_type)

    def shutdown(self):
        """Close all pooled connections; safe to call more than once."""
        # Read the instance dict directly: if __init__ failed before
        # ``pool`` was assigned, ``self.pool`` would fall through to
        # __getattr__ and return a wrapper function instead of a pool.
        pool = self.__dict__.get('pool')
        if pool:
            pool.closeall()
        self.pool = None

    def cursor(self,cursor_factory=None):
        """Return a ``cursor`` wrapper bound to this handle's pool and settings."""
        return cursor(self.pool,
                      cursor_factory or self.default_cursor,
                      self.hstore,
                      self.log,
                      self.logf)

    def __del__(self):
        self.shutdown()

    def __getattr__(self,name):
        # Do not proxy special-method probes (copy, pickle, ...) to a cursor.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        def _wrapper(*args,**kwargs):
            with self.cursor() as c:
                return getattr(c,name)(*args,**kwargs)
        return _wrapper
class Database(object):
    """Convenience wrapper around a psycopg2 ``ThreadedConnectionPool``.

    Most helpers log failures through ``self.app.logger`` and return a
    sentinel instead of raising. Every helper returns its connection to the
    pool in a ``finally`` clause.
    """

    def __init__(self, db_connection, connection_number, app, min_db_connections=1):
        """Open the pool.

        Args:
            db_connection: Connection string passed to the pool
            connection_number: Maximum number of pooled connections
            app: Object whose ``logger`` is used to report failures
            min_db_connections: Minimum number of pooled connections
        """
        self.db_connection = db_connection
        self.min_connections = min_db_connections
        self.connection_number = connection_number
        self.pool = ThreadedConnectionPool(min_db_connections,
                                           connection_number,
                                           db_connection)
        self.app = app
        self.last_seen_process_id = os.getpid()
        self.needs_change = True

    def dispose(self):
        """Close every pooled connection."""
        self.pool.closeall()

    def get_connection(self):
        """Return a pooled connection, rebuilding the pool if the pid changed."""
        if self.needs_change:
            current_pid = os.getpid()
            if not (current_pid == self.last_seen_process_id):
                self.last_seen_process_id = current_pid
                self.pool.closeall()
                self.pool = ThreadedConnectionPool(self.min_connections,
                                                   self.connection_number,
                                                   self.db_connection)
                # NOTE(review): the pid check is assumed to stop only once
                # the pool has been rebuilt - confirm against the original
                # file's indentation.
                self.needs_change = False
        return self.pool.getconn()

    def run_script(self, script):
        """Execute ``script`` and commit. Returns True on success, else False."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(script)
            conn.commit()
            return True
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(script))
            conn.rollback()
            return False
        finally:
            self.pool.putconn(conn)

    def add_tracks(self, dicts, table):
        """Insert each dictionary in ``dicts`` as a row of the tracks table.

        Args:
            dicts (list[dict]): Column name to value, one dict per row
            table: Unused; rows always go to ``tracks``

        Raises:
            Exception: Any database error, after rolling back.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            sql = "INSERT INTO tracks (%s) VALUES %s"
            for track in dicts:
                columns = track.keys()
                values = [track[column] for column in columns]
                cursor.execute(sql, (AsIs(','.join(columns)), tuple(values)))
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.pool.putconn(conn)

    def execute_update(self, sql, vars=None):
        """Execute ``sql`` and commit. Returns True on success, else False."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(sql, vars)
            conn.commit()
            return True
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            return False
        finally:
            self.pool.putconn(conn)

    def value_exists(self, table, field, value):
        """Return True if ``table`` has a row where ``field`` equals ``value``.

        Returns False when the query fails.
        """
        # Defined before the try block so the error handler can always log it.
        sql = "SELECT * FROM %s WHERE %s = %s "
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(sql, (AsIs(table), AsIs(field), value))
            results = cursor.fetchall()
            return len(results) > 0
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            return False
        finally:
            self.pool.putconn(conn)

    def delete_table(self, table):
        '''Deletes the table from the database

        Args:
            table: The name of the table to delete

        Returns:
            True - if the table was deleted, otherwise False.
        '''
        conn = self.get_connection()
        sql = "DROP TABLE public.{}".format(table)
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            conn.commit()
            return True
        except Exception:
            self.app.logger.exception(
                "The table {} could not be deleted using the SQL:\n{}".format(
                    table, sql))
            conn.rollback()
            return False
        finally:
            self.pool.putconn(conn)

    def remove_columns(self, table, columns):
        '''Deletes columns from the specified table

        Args:
            table: The name of the table
            columns(list[str]): A list of column names to delete

        Returns:
            True - if the columns were deleted, otherwise False
        '''
        conn = self.get_connection()
        sql = ""
        try:
            cursor = conn.cursor()
            for col in columns:
                sql = "ALTER TABLE public.{} DROP COLUMN {}".format(table, col)
                cursor.execute(sql)
            conn.commit()
            return True
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            return False
        finally:
            self.pool.putconn(conn)

    def add_columns(self, table, columns):
        '''Adds columns to the specified table

        Args:
            table: The name of the table
            columns: A list of dictionaries with 'name' and 'datatype'
                (text, integer or double) and optionally default
                (specifies default value). There is no check
                on the column name, so it has to be compatible
                with postgresql. A 'double' datatype is rewritten
                in place to 'double precision'.

        Returns:
            True - if the columns were added, otherwise False
        '''
        conn = self.get_connection()
        sql = ""
        try:
            cursor = conn.cursor()
            for col in columns:
                if col['datatype'] == "double":
                    col['datatype'] = "double precision"
                sql = "ALTER TABLE public.{} ADD COLUMN {} {}".format(
                    table, col['name'], col['datatype'])
                if col.get("default"):
                    val = col.get("default")
                    if col['datatype'] == "text":
                        val = "'" + val + "'"
                    # str() lets numeric defaults be used as well.
                    sql += " default " + str(val)
                cursor.execute(sql)
            conn.commit()
            return True
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            return False
        finally:
            self.pool.putconn(conn)

    def execute_query(self, sql, vars=None):
        """Run ``sql`` and return all rows as dictionaries ([] on failure)."""
        conn = self.get_connection()
        results = []
        try:
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            cursor.execute(sql, vars)
            results = cursor.fetchall()
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
        finally:
            self.pool.putconn(conn)
        return results

    def get_sql(self, sql, vars=None):
        """Return ``sql`` with ``vars`` bound (``mogrify``), or None on failure."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            return cursor.mogrify(sql, vars)
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            return None
        finally:
            self.pool.putconn(conn)

    def delete_by_id(self, table, ids):
        """Delete the rows of ``table`` whose id is in ``ids``."""
        # Defined before the try block so the error handler can always log it.
        sql = "DELETE FROM %s WHERE id = ANY (%s)"
        conn = self.get_connection()
        try:
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            cursor.execute(sql, (AsIs(table), ids))
            conn.commit()
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
        finally:
            self.pool.putconn(conn)

    def execute_delete(self, sql, vars=None):
        """Execute ``sql`` and commit. Returns True on success, else False."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(sql, vars)
            conn.commit()
            return True
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            return False
        finally:
            self.pool.putconn(conn)

    def execute_insert(self, sql, vars=None, ret_id=True):
        """Execute an INSERT and commit.

        When ``ret_id`` is true, `` RETURNING id`` is appended to ``sql`` and
        the new id is returned. Returns -1 on failure or when ``ret_id`` is
        false.
        """
        if ret_id:
            sql += " RETURNING id"
        conn = self.get_connection()
        new_id = -1
        try:
            cursor = conn.cursor()
            cursor.execute(sql, vars)
            conn.commit()
            if ret_id:
                new_id = cursor.fetchone()[0]
        except Exception:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
        finally:
            self.pool.putconn(conn)
        return new_id

    def get_tracks(self, track_ids=[], field="track_id", proxy=True):
        '''Get track info for all the supplied track ids

        Args:
            track_ids(list): A list of track_ids (or ids) to retrieve
            field(Optional[str]): The field to query the database
                (track_id by default)
            proxy(Optional[boolean]): If true (default) each track's url
                has the substitutions from the TRACK_PROXIES config
                entry applied

        Returns:
            A list of row dictionaries, empty if nothing matched.
        '''
        from app import app
        conn = self.get_connection()
        try:
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            sql = "SELECT id,track_id,url,type,short_label,color FROM tracks WHERE %s = ANY(%s)"
            cursor.execute(sql, (AsIs(field), track_ids))
            results = cursor.fetchall()
        finally:
            # Return the connection even when the query raises.
            self.pool.putconn(conn)
        if not results:
            return []
        t_p = app.config.get("TRACK_PROXIES")
        if t_p and proxy:
            for item in results:
                for p in t_p:
                    item['url'] = item['url'].replace(p, t_p[p])
        return results

    def insert_dicts_into_table(self, dicts, table):
        """Insert each dictionary in ``dicts`` as a row of ``table``.

        Raises:
            Exception: Any database error, after rolling back.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            for item in dicts:
                sql = "INSERT INTO {} (%s) VALUES %s".format(table)
                columns = item.keys()
                values = [item[column] for column in columns]
                cursor.execute(sql, (AsIs(','.join(columns)), tuple(values)))
            conn.commit()
            cursor.close()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.pool.putconn(conn)

    def update_table_with_dicts(self, dicts, table):
        '''Updates the table with the values in the supplied
        dictionaries (the keys being the column names)

        Args:
            dicts (list[dict]): A list of dictionaries containing
                key(column name) to value (new value). Each dictionary
                should also contain an id key with the id of the row.
                The dictionaries are not modified.

        Returns:
            True if the update was successful, otherwise False.
        '''
        conn = self.get_connection()
        complete = True
        try:
            cursor = conn.cursor()
            sql = "UPDATE {} SET (%s) = %s WHERE id= %s".format(table)
            for item in dicts:
                row_id = item['id']
                # Skip the id column without deleting it from the caller's
                # dict, so the same dicts can be reused or retried.
                columns = [column for column in item if column != 'id']
                values = [item[column] for column in columns]
                cursor.execute(sql, (AsIs(','.join(columns)), tuple(values), row_id))
            conn.commit()
            cursor.close()
        except Exception:
            self.app.logger.exception("update could not be run")
            conn.rollback()
            complete = False
        finally:
            self.pool.putconn(conn)
        return complete

    def insert_dict_into_table(self, dic, table):
        """Insert ``dic`` as one row of ``table``; return the new id or -1."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            sql = "INSERT INTO {} (%s) VALUES %s RETURNING id".format(table)
            columns = dic.keys()
            values = [dic[column] for column in columns]
            cursor.execute(sql, (AsIs(','.join(columns)), tuple(values)))
            conn.commit()
            new_id = cursor.fetchone()[0]
            cursor.close()
            return new_id
        except Exception:
            self.app.logger.exception("update could not be run")
            return -1
        finally:
            self.pool.putconn(conn)
class DBUtils:
    """Pooled PostgreSQL access layer for the quotes / companies tables.

    An instance owns one ``ThreadedConnectionPool``.  Every public method
    checks a connection out of the pool, runs a single statement and always
    hands the connection back.  ``DBUtils.init`` lazily creates one shared
    instance under a lock.
    """

    __singleton = None
    __lock = Lock()

    # Keys read from a company dict, in COMPANIES_UPSERT_SQL placeholder order.
    _COMPANY_FIELDS = (
        'code', 'name', 'intro', 'manage', 'ssrq', 'clrq', 'fxl', 'fxfy',
        'mgfxj', 'fxzsz', 'srkpj', 'srspj', 'srhsl', 'srzgj', 'djzql',
        'wxpszql', 'mjzjje', 'zczb',
    )

    QUOTES_UPSERT_SQL = f"""
        INSERT INTO quotes (NAME, CODE, OPENING_PRICE, LATEST_PRICE, QUOTE_CHANGE, CHANGE, VOLUME,
            TURNOVER, AMPLITUDE, TURNOVER_RATE, "PE_ratio", VOLUME_RATIO, MAX_PRICE, MIN_PRICE,
            CLOSING_PRICE, "PB_ratio", MARKET, TIME)
        VALUES ({','.join(['%s'] * 18)})
        ON CONFLICT ON CONSTRAINT quotes_pkey
        DO UPDATE SET latest_price = excluded.latest_price,
            change = excluded.change,
            volume = excluded.volume,
            turnover = excluded.turnover,
            amplitude = excluded.amplitude,
            turnover_rate = excluded.turnover_rate,
            volume_ratio = excluded.volume_ratio,
            max_price = excluded.max_price,
            min_price = excluded.min_price
    """
    COMPANIES_UPSERT_SQL = f"""
        INSERT INTO companies (code, name, intro, manage, ssrq, clrq, fxl, fxfy, mgfxj, fxzsz,
            srkpj, srspj, srhsl, srzgj, djzql, wxpszql, mjzjje, zczb)
        VALUES ({','.join(['%s'] * 18)})
        ON CONFLICT ON CONSTRAINT companies_pkey
        DO UPDATE SET fxl = excluded.fxl,
            fxfy = excluded.fxfy,
            mgfxj = excluded.mgfxj,
            fxzsz = excluded.fxzsz,
            djzql = excluded.djzql,
            wxpszql = excluded.wxpszql,
            zczb = excluded.zczb
    """
    CODES_QUERY_SQL = """
        SELECT DISTINCT code FROM quotes
    """
    MANAGE_QUERY_SQL = """
        SELECT manage FROM companies order by code
    """
    POS_VEC_UPDATE_SQL = """
        UPDATE companies SET pos_vec = %s WHERE code = %s
    """
    CODES_MANAGE_QUERY_SQL = """
        select code, manage from companies
    """
    POS_VEC_QUERY_SQL = """
        SELECT pos_vec FROM companies WHERE code = %s
    """
    NEED_UPDADE_CODE_SQL = """
        SELECT DISTINCT code FROM quotes EXCEPT SELECT code FROM companies
    """
    MAIN_TARGET_INSERT_SQL = f"""
        INSERT INTO main_target (code, date, jbmgsy, kfmgsy, xsmgsy, mgjzc, mggjj, mgwfply,
            mgjyxjl, yyzsr, mlr, gsjlr, kfjlr, yyzsrtbzz, gsjlrtbzz, kfjlrtbzz, yyzsrgdhbzz,
            gsjlrgdhbzz, kfjlrgdhbzz, jqjzcsyl, tbjzcsyl, tbzzcsyl, mll, jll, sjsl, yskyysr,
            jyxjlyysr, xsxjlyysr, zzczzl, yszkzzts, chzzts, zcfzl, ldzczfz, ldbl, sdbl)
        VALUES ({','.join(['%s'] * 35)})
        ON CONFLICT DO NOTHING
    """

    def __init__(self, database_config):
        """Open the connection pool.

        Args:
            database_config: mapping with a ``'pool'`` entry (keys ``'min'``
                and ``'max'``) and a ``'conn'`` entry whose items are passed
                as keyword arguments to ``ThreadedConnectionPool``.
        """
        self.database_config = database_config
        pool_config = self.database_config['pool']
        conn_config = self.database_config['conn']
        self.conn_pool = ThreadedConnectionPool(minconn=pool_config['min'],
                                                maxconn=pool_config['max'],
                                                **conn_config)

    def closeall(self):
        """Close every connection held by the pool."""
        self.conn_pool.closeall()

    def _run_write(self, sql, params, many=False):
        """Run one write statement in its own transaction.

        Commits on success.  On any failure the transaction is rolled back
        before the exception is re-raised, so the connection never goes back
        to the pool with an aborted transaction on it.  The connection is
        returned to the pool in every case.

        Args:
            sql: statement with ``%s`` placeholders.
            params: one parameter sequence, or a sequence of them when
                ``many`` is true.
            many: use ``executemany`` instead of ``execute``.
        """
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                if many:
                    cur.executemany(sql, params)
                else:
                    cur.execute(sql, params)
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.conn_pool.putconn(conn)

    def _run_query(self, sql, params=None, one=False):
        """Run one read statement and return its rows.

        Args:
            sql: statement with ``%s`` placeholders.
            params: optional parameter sequence.
            one: return ``fetchone()`` instead of ``fetchall()``.

        Returns:
            A single row (or ``None``) when ``one`` is true, otherwise the
            list of all rows.
        """
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params)
                return cur.fetchone() if one else cur.fetchall()
        finally:
            self.conn_pool.putconn(conn)

    def upsert_quotes(self, quotes):
        """Insert or update a batch of quote rows.

        Unique / not-null violations are reported on stdout and swallowed;
        any other error propagates.
        """
        try:
            self._run_write(self.QUOTES_UPSERT_SQL, quotes, many=True)
        except errors.UniqueViolation:
            print("当前记录已存在")
        except errors.NotNullViolation:
            print("违反非空约束")

    def upsert_company(self, company):
        """Insert or update one company.

        Args:
            company: mapping that provides every key in ``_COMPANY_FIELDS``;
                a missing key raises ``KeyError``.

        Unique / not-null violations are reported on stdout and swallowed.
        """
        # Build the row first so a malformed dict never checks out a connection.
        value = [company[field] for field in self._COMPANY_FIELDS]
        try:
            self._run_write(self.COMPANIES_UPSERT_SQL, value)
        except errors.UniqueViolation:
            print("当前记录已存在")
        except errors.NotNullViolation:
            print("违反非空约束")

    def get_all_codes(self):
        """Return all distinct company codes in ``quotes``.

        Best effort: on any error the exception is printed and ``None`` is
        returned.
        """
        try:
            return self._run_query(self.CODES_QUERY_SQL)
        except Exception as e:
            print(e)

    def get_all_manage(self):
        """Return the ``manage`` column of every company, ordered by code.

        Best effort: on any error the exception is printed and ``None`` is
        returned.
        """
        try:
            return self._run_query(self.MANAGE_QUERY_SQL)
        except Exception as e:
            print(e)

    def get_all_codes_manage(self):
        """Return ``(code, manage)`` rows for every company.

        Best effort: on any error the exception is printed and ``None`` is
        returned.
        """
        try:
            return self._run_query(self.CODES_MANAGE_QUERY_SQL)
        except Exception as e:
            print(e)

    def update_company_pos_vec(self, code, pos_vec):
        """Store a company's position vector (``pos_vec`` column).

        Best effort: on any error the transaction is rolled back and the
        exception is printed.
        """
        try:
            self._run_write(self.POS_VEC_UPDATE_SQL, (pos_vec, code))
        except Exception as e:
            print(e)

    def get_pos_vec(self, code):
        """Return the ``pos_vec`` row for ``code`` (``None`` if absent).

        Errors propagate to the caller.
        """
        return self._run_query(self.POS_VEC_QUERY_SQL, (code, ), one=True)

    def get_need_update_codes(self):
        """Return codes present in ``quotes`` but missing from ``companies``.

        Errors propagate to the caller.
        """
        return self._run_query(self.NEED_UPDADE_CODE_SQL)

    def insert_main_target(self, targets):
        """Insert a batch of main-indicator rows, skipping conflicting ones.

        Unique / not-null violations are reported on stdout and swallowed;
        any other error propagates.
        """
        try:
            self._run_write(self.MAIN_TARGET_INSERT_SQL, targets, many=True)
        except errors.UniqueViolation:
            print("当前记录已存在")
        except errors.NotNullViolation:
            print("违反非空约束")

    @staticmethod
    def init(database_config):
        """Create the shared instance on first call and return it.

        Construction errors are printed with a traceback rather than
        raised; in that case ``None`` is returned and a later call will try
        again.
        """
        with DBUtils.__lock:
            if DBUtils.__singleton is None:
                try:
                    DBUtils.__singleton = DBUtils(database_config)
                except Exception:
                    traceback.print_exc()
        return DBUtils.__singleton
def handle(self, *args, **options):
    """Run one crawl of the Mechanical Turk hit groups and record it.

    Creates a ``Crawl`` row, spawns one gevent job per hit group (each job
    receives the shared connection pool), then stores the final counters on
    the ``Crawl`` row and logs a summary with sanity warnings.

    Options read from ``options``:
        mturk_email / mturk_password: override the ``MTURK_AUTH_*`` settings.
        logconf: logging configuration passed to ``setup_logging``.
        debug: call ``setup_debug`` and print the pid to signal.
        workers: number of workers; more than 9 aborts the command.

    Raises:
        SystemExit: when more than 9 workers are requested.
    """
    self.mturk_email = getattr(settings, 'MTURK_AUTH_EMAIL', None)
    self.mturk_password = getattr(settings, 'MTURK_AUTH_PASSWORD', None)

    _start_time = time.time()
    pid = Pid('mturk_crawler', True)
    log.info('crawler started: %s;;%s', args, options)

    if options.get('mturk_email'):
        self.mturk_email = options['mturk_email']
    if options.get('mturk_password'):
        self.mturk_password = options['mturk_password']

    if options.get('logconf', None):
        self.setup_logging(options['logconf'])

    if options.get('debug', False):
        self.setup_debug()
        # Single-argument parenthesised form prints the same text on
        # Python 2 and Python 3.
        print('Current proccess pid: %s' % pid.actual_pid)
        print(('To debug, type: python -c "import os,signal; '
               'os.kill(%s, signal.SIGUSR1)"\n') % pid.actual_pid)

    self.maxworkers = options['workers']
    if self.maxworkers > 9:
        # If you want to remove this limit, don't forget to change dbpool
        # object maximum number of connections. Each worker should fetch
        # 10 hitgroups and spawn single task for every one of them, that
        # will get private connection instance. So for 9 workers it's
        # already 9x10 = 90 connections required
        #
        # Also, for too many workers, amazon isn't returning valid data
        # and retrying takes much longer than using smaller amount of
        # workers
        sys.exit('Too many workers (more than 9). Quit.')

    start_time = datetime.datetime.now()

    hits_available = tasks.hits_mainpage_total()
    groups_available = tasks.hits_groups_total()

    # create crawl object that will be filled with data later
    crawl = Crawl.objects.create(
        start_time=start_time,
        end_time=datetime.datetime.now(),
        success=True,
        hits_available=hits_available,
        hits_downloaded=0,
        groups_available=groups_available,
        groups_downloaded=groups_available)
    log.debug('fresh crawl object created: %s', crawl.id)

    # fetch those requester profiles so we could decide if their hitgroups
    # are public or not
    reqesters = RequesterProfile.objects.all_as_dict()

    dbpool = ThreadedConnectionPool(
        10, 90,
        'dbname=%s user=%s password=%s' % (
            settings.DATABASE_NAME,
            settings.DATABASE_USER,
            settings.DATABASE_PASSWORD))

    # collection of group_ids that were already processed - this should
    # protect us from duplicating data
    processed_groups = set()
    total_reward = 0
    hitgroups_iter = self.hits_iter()

    # The pool must be closed even when a pack fails half-way, otherwise up
    # to 90 server connections stay open until the process dies.
    try:
        for hg_pack in hitgroups_iter:
            jobs = []
            for hg in hg_pack:
                j = gevent.spawn(tasks.process_group, hg, crawl.id,
                                 reqesters, processed_groups, dbpool)
                jobs.append(j)
                total_reward += hg['reward'] * hg['hits_available']
            log.debug('processing pack of hitgroups objects')
            gevent.joinall(jobs, timeout=20)
            # check if all jobs ended successfully
            for job in jobs:
                if not job.ready():
                    log.error('Killing job: %s', job)
                    job.kill()

            if len(processed_groups) >= groups_available:
                log.info('Skipping empty groups.')
                # there's no need to iterate over empty groups.. break
                break

            # amazon does not like too many requests at once, so give them a
            # quick rest...
            gevent.sleep(1)
    finally:
        dbpool.closeall()

    # update crawler object
    crawl.groups_downloaded = len(processed_groups)
    crawl.end_time = datetime.datetime.now()
    crawl.save()

    work_time = time.time() - _start_time
    log.info('created crawl id: %s', crawl.id)
    log.info('total reward value: %s', total_reward)
    log.info('processed hits groups downloaded: %s', len(processed_groups))
    log.info('processed hits groups available: %s', groups_available)
    log.info('work time: %.2fsec', work_time)

    crawl_time_warning = 300
    if work_time > crawl_time_warning:
        # Report the measured duration and the threshold separately; the
        # message used to print the threshold in both places.
        log.warning(
            "Crawl took %.2f s which seems a bit too long (more than "
            "%s s), you might consider checking if correct mturk account"
            " is used.", work_time, crawl_time_warning)
    if crawl.groups_downloaded < groups_available * 0.9:
        log.warning('More than 10% of hit groups were not downloaded, '
                    'please check mturk account configuration and/or if there are '
                    'any network-related problems.')
    crawl_downloaded_pc = 0.6
    if crawl.groups_downloaded < groups_available * crawl_downloaded_pc:
        log.warning(
            "This crawl contains far too few groups downloaded to "
            "available: (%s < %s * %s) and will be considered as "
            "erroneous", crawl.groups_downloaded, groups_available,
            crawl_downloaded_pc)
class PostgreSQLRepository(Repository):
    """Todos repository stored in PostgreSQL through a psycopg2 pool.

    Call ``connect`` before any data method and ``disconnect`` when done.
    Every data method runs inside its own transaction (see ``_cursor``).
    """

    class SQL:
        """Statements used by the repository, one attribute per operation."""

        CREATE_UUID_EXTENSION = """
            CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
        """
        CREATE_TABLE = """
            CREATE TABLE IF NOT EXISTS todos (
                id uuid NOT NULL PRIMARY KEY DEFAULT uuid_generate_v1mc(),
                text text NOT NULL,
                active bool NOT NULL
            );
        """
        STATS = """
            SELECT active, count(*) FROM todos GROUP BY active;
        """
        GET = """
            SELECT text, active FROM todos WHERE id = %(id)s;
        """
        LIST = """
            SELECT id, text, active FROM todos ORDER BY id;
        """
        INSERT = """
            INSERT INTO todos(text, active) VALUES (%(text)s, %(active)s) RETURNING id;
        """
        EDIT_TEXT = """
            UPDATE todos SET text = %(text)s WHERE id = %(id)s;
        """
        ACTIVATE = """
            UPDATE todos SET active = TRUE WHERE id = %(id)s;
        """
        DEACTIVATE = """
            UPDATE todos SET active = FALSE WHERE id = %(id)s;
        """
        DELETE = """
            DELETE FROM todos WHERE id = %(id)s;
        """
        CLEAN = """
            TRUNCATE TABLE todos;
        """

    @staticmethod
    def factory(max_connections: int = 10):
        """Build a repository from ``POSTGRESQL_CONNECTION_URL``.

        Falls back to a local default URL when the variable is not set.
        """
        url = os.environ.get(
            "POSTGRESQL_CONNECTION_URL",
            "postgres://*****:*****@localhost:5432/postgres",
        )
        return PostgreSQLRepository(connection_url=url,
                                    max_connections=max_connections)

    def __init__(self, connection_url: str, max_connections: int):
        """Remember the connection settings; the pool is opened by ``connect``."""
        self._pool: Optional[ThreadedConnectionPool] = None
        self._connection_url = connection_url
        self._max_connections = max_connections
        register_uuid()

    def connect(self) -> None:
        """Open the connection pool; must not already be connected."""
        assert self._pool is None
        self._pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=self._max_connections,
            dsn=self._connection_url,
        )

    def disconnect(self) -> None:
        """Close every pooled connection and forget the pool."""
        assert self._pool is not None
        self._pool.closeall()
        self._pool = None

    def initialize(self) -> None:
        """Create the uuid extension and the ``todos`` table if missing.

        Both statements run in one transaction, each on its own cursor.
        """
        with self._connection() as conn:
            with conn:
                for statement in (self.SQL.CREATE_UUID_EXTENSION,
                                  self.SQL.CREATE_TABLE):
                    with conn.cursor() as cur:
                        cur.execute(statement)

    def stats(self) -> Stats:
        """Count active and inactive todos; a missing group counts as 0."""
        with self._cursor() as cur:
            cur.execute(self.SQL.STATS)
            counts = {record[0]: record[1] for record in cur}
        return Stats(active=counts.get(True, 0),
                     inactive=counts.get(False, 0))

    def get(self, id_: UUID) -> Optional[Todo]:
        """Return the todo with the given id, or ``None`` when absent."""
        with self._cursor() as cur:
            cur.execute(self.SQL.GET, {"id": id_})
            record = cur.fetchone()
            if record is not None:
                return Todo(id=id_, text=record[0], active=record[1])
            return None

    @rare_delay(delay=1.5, probability=0.2)
    def list(self) -> Tuple[Todo, ...]:
        """Return every todo, ordered by id."""
        with self._cursor() as cur:
            cur.execute(self.SQL.LIST)
            found = [
                Todo(id=record[0], text=record[1], active=record[2])
                for record in cur
            ]
        return tuple(found)

    @random_delay(min_delay=0.5, max_delay=2.0)
    def insert(self, text: str) -> UUID:
        """Insert a new active todo and return its generated id."""
        values = {"text": text, "active": True}
        with self._cursor() as cur:
            cur.execute(self.SQL.INSERT, values)
            return cur.fetchone()[0]

    def edit_text(self, id_: UUID, text: str) -> bool:
        """Replace a todo's text; ``True`` when a row was updated."""
        values = {"id": id_, "text": text}
        with self._cursor() as cur:
            cur.execute(self.SQL.EDIT_TEXT, values)
            return cur.rowcount > 0

    def activate(self, id_: UUID) -> bool:
        """Mark a todo active; ``True`` when a row was updated."""
        values = {"id": id_, "active": True}
        with self._cursor() as cur:
            cur.execute(self.SQL.ACTIVATE, values)
            return cur.rowcount > 0

    @rare_delay(delay=3.0, probability=0.1)
    def deactivate(self, id_: UUID) -> bool:
        """Mark a todo inactive; ``True`` when a row was updated."""
        values = {"id": id_, "active": False}
        with self._cursor() as cur:
            cur.execute(self.SQL.DEACTIVATE, values)
            return cur.rowcount > 0

    def delete(self, id_: UUID) -> bool:
        """Delete a todo; ``True`` when a row was removed."""
        with self._cursor() as cur:
            cur.execute(self.SQL.DELETE, {"id": id_})
            return cur.rowcount > 0

    def _clean(self) -> None:
        """Truncate the ``todos`` table."""
        with self._cursor() as cur:
            cur.execute(self.SQL.CLEAN)

    @contextmanager
    def _connection(self) -> ContextManager[connection]:
        """Lend a pooled connection and always give it back to the pool."""
        assert self._pool is not None
        borrowed = self._pool.getconn()
        try:
            yield borrowed
        finally:
            self._pool.putconn(borrowed)

    @contextmanager
    def _cursor(self) -> ContextManager[cursor]:
        """Yield a cursor on a pooled connection inside ``with conn:``.

        The nested ``with conn`` is the transaction scope; the connection is
        returned to the pool once the caller's block ends.
        """
        with self._connection() as conn, conn, conn.cursor() as cur:
            yield cur
class PlacesSearch(object):
    """Class for performing the GEPlaces search.

    Search can be performed either as a single scope search, where the
    search string is a city or a country, or as a double scope search,
    where the search is performed for a city in a state/country
    combination.

    Valid inputs are:
        q = santa clara
        q = Atlanta, Georgia
    """

    def __init__(self):
        """Inits GEPlaces.

        Takes the logger and the KML/JSON templates from ``SearchUtils``,
        builds the query templates from ``geconstants.Constants``, reads the
        database settings from ``GEPlacesSearch.conf`` and opens the
        connection pool.
        """
        self.utils = utils.SearchUtils()
        constants = geconstants.Constants()

        configs = self.utils.GetConfigs(
            os.path.join(geconstants.SEARCH_CONFIGS_DIR, "GEPlacesSearch.conf"))

        self._jsonp_call = self.utils.jsonp_functioncall
        self._geom = """
            <name>%s</name>
            <styleUrl>%s</styleUrl>
            <description>%s</description>
            %s\
        """
        self._json_geom = """
            {
              "name": "%s",
              "description": "%s",
              %s
            }\
        """

        self._placemark_template = self.utils.placemark_template
        self._kml_template = self.utils.kml_template
        style_template = self.utils.style_template

        self._json_template = self.utils.json_template
        self._json_placemark_template = self.utils.json_placemark_template

        self._city_query_template = (
            Template(constants.city_query))
        self._country_query_template = (
            Template(constants.country_query))
        self._city_and_country_name_query_template = (
            Template(constants.city_and_country_name_query))
        self._city_and_country_code_query_template = (
            Template(constants.city_and_country_code_query))
        self._city_and_subnational_name_query_template = (
            Template(constants.city_and_subnational_name_query))
        self._city_and_subnational_code_query_template = (
            Template(constants.city_and_subnational_code_query))

        self.logger = self.utils.logger

        self._user = configs.get("user")
        self._hostname = configs.get("host")
        self._port = configs.get("port")

        self._database = configs.get("databasename")
        if not self._database:
            self._database = constants.defaults.get("places.database")

        self._pool = ThreadedConnectionPool(
            int(configs.get("minimumconnectionpoolsize")),
            int(configs.get("maximumconnectionpoolsize")),
            database=self._database,
            user=self._user,
            host=self._hostname,
            port=int(self._port))

        self.style = style_template.substitute(
            balloonBgColor=configs.get("balloonstyle.bgcolor"),
            balloonTextColor=configs.get("balloonstyle.textcolor"),
            balloonText=configs.get("balloonstyle.text"),
            iconStyleScale=configs.get("iconstyle.scale"),
            iconStyleHref=configs.get("iconstyle.href"),
            lineStyleColor=configs.get("linestyle.color"),
            lineStyleWidth=configs.get("linestyle.width"),
            polyStyleColor=configs.get("polystyle.color"),
            polyStyleColorMode=configs.get("polystyle.colormode"),
            polyStyleFill=configs.get("polystyle.fill"),
            polyStyleOutline=configs.get("polystyle.outline"),
            listStyleHref=configs.get("iconstyle.href"))

    def RunPGSQLQuery(self, query, params):
        """Submits the query to the database and returns tuples.

        Note: variables placeholder must always be %s in query.
        Warning: NEVER use Python string concatenation (+) or string
        parameters interpolation (%) to pass variables to a SQL query
        string.
        e.g.
          SELECT vs_url FROM vs_table WHERE vs_name = 'default_ge';
          query = "SELECT vs_url FROM vs_table WHERE vs_name = %s"
          parameters = ["default_ge"]

        Args:
          query: SQL SELECT statement.
          params: sequence of parameters to populate into placeholders.
        Returns:
          Query execution status and list of query results. Rows with a
          single column are flattened to that column's value.
        Raises:
          exceptions.PoolConnectionException: no connection could be taken
            from the pool. Other psycopg2 errors are logged and reported
            through a False status.
        """
        con = None
        cursor = None

        query_results = []
        query_status = False

        self.logger.debug("Querying the database %s,at port %s,as user %s on "
                          "hostname %s", self._database, self._port,
                          self._user, self._hostname)
        try:
            con = self._pool.getconn()
            if con:
                cursor = con.cursor()
                cursor.execute(query, params)

                for row in cursor:
                    if len(row) == 1:
                        query_results.append(row[0])
                    else:
                        query_results.append(row)
                query_status = True
        except psycopg2.pool.PoolError as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
            raise exceptions.PoolConnectionException(
                "Pool Error - Unable to get a connection from the pool.")
        except psycopg2.Error as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
        finally:
            # Release the cursor before the connection goes back to the
            # pool; it used to be left open on every query.
            if cursor is not None:
                cursor.close()
            if con:
                self._pool.putconn(con)

        return query_status, query_results

    def RunCityGeocoder(self, search_query, response_type):
        """Performs a query search on the 'cities' table.

        Args:
          search_query: the query to be searched, in small case.
          response_type: Response type can be KML or JSONP, depending on
            the client.
        Returns:
          tuple containing
          total_city_results: Total number of rows returned from
            querying the database.
          city_results_list: Query results as a list of dictionaries.
        """
        city_results_list = []

        search_city_view = self.utils.GetCityView(search_query)
        params = search_query.split(",")

        accum_func = self.utils.GetAccumFunc(response_type)
        city_query = self._city_query_template.substitute(
            FUNC=accum_func, CITY_VIEW=search_city_view)
        query_status, query_results = self.RunPGSQLQuery(city_query, params)
        total_city_results = len(query_results)

        if query_status:
            for entry in query_results:
                city_results = {}

                # A stored population of 0 is displayed as "unknown".
                if entry[3] == 0:
                    city_population = "unknown"
                else:
                    city_population = str(entry[3])

                name = "%s, %s" % (entry[1], entry[4])
                styleurl = "#placemark_label"
                country_name = "Country: %s" % (entry[2])
                population = "Population: %s" % (city_population)
                description = "%s<![CDATA[<br/>]]>%s" % (country_name, population)
                geom = str(entry[0])

                city_results["name"] = name
                city_results["styleurl"] = styleurl
                city_results["description"] = description
                city_results["geom"] = geom
                city_results["geom_type"] = entry[5]

                city_results_list.append(city_results)

        return total_city_results, city_results_list

    def RunCountryGeocoder(self, search_query, response_type):
        """Performs a query search on the 'countries' table.

        Args:
          search_query: the query to be searched, in smallcase.
          response_type: Response type can be KML or JSONP, depending on
            the client.
        Returns:
          tuple containing
          total_country_results: Total number of rows returned from
            querying the database.
          country_results_list: Query results as a list of dictionaries.
        """
        country_results_list = []

        params = search_query.split(",")
        accum_func = self.utils.GetAccumFunc(response_type)
        country_query = self._country_query_template.substitute(FUNC=accum_func)
        query_status, query_results = self.RunPGSQLQuery(country_query, params)
        total_country_results = len(query_results)

        if query_status:
            for entry in query_results:
                country_results = {}

                # A stored population of 0 is displayed as "unknown".
                if entry[5] == 0:
                    country_population = "unknown"
                else:
                    country_population = str(entry[5])

                name = entry[0]
                styleurl = "#placemark_label"
                capital = "Capital: %s" % (entry[3])
                population = "Population: %s" % (country_population)
                description = "%s<![CDATA[<br/>]]>%s" % (capital, population)
                geom = str(entry[2])

                country_results["name"] = name
                country_results["styleurl"] = styleurl
                country_results["description"] = description
                country_results["geom"] = geom
                country_results["geom_type"] = entry[8]

                country_results_list.append(country_results)

        return total_country_results, country_results_list

    def SingleScopeSearch(self, search_query, response_type):
        """Performs a query search on the 'cities' and 'countries' tables.

        Input contains either city name or state name like "q=santa clara"
        or "q=California".

        Args:
          search_query: the query to be searched, in smallcase.
          response_type: Response type can be KML or JSONP, depending on
            the client.
        Returns:
          total_results: Total query results.
          single_scope_results: A list of dictionaries containing the
            'Placemarks', countries first and cities after.
        """
        country_count, country_results = self.RunCountryGeocoder(
            search_query, response_type)
        city_count, city_results = self.RunCityGeocoder(
            search_query, response_type)

        total_results = country_count + city_count
        single_scope_results = country_results + city_results

        self.logger.info("places search returned %s results", total_results)

        return total_results, single_scope_results

    def DoubleScopeSearch(self, search_query, response_type):
        """Performs a query search on the 'cities' and 'countries' tables.

        Input contains both city name and state name,
        like "q=santa clara,California".

        Args:
          search_query: the query which contains both city name and state
            name to be search in the geplaces database.
          response_type: Response type can be KML or JSONP, depending on
            the client.
        Returns:
          total_results: Total query results.
          double_scope_results: A list of dictionaries containing the
            'Placemarks'.
        """
        total_results = 0
        double_scope_results = []

        search_city_view = self.utils.GetCityView(search_query)

        params = [entry.strip() for entry in search_query.split(",")]
        # The same (city, scope) pair feeds each of the four UNIONed queries.
        params = params[0:2]*4

        accum_func = self.utils.GetAccumFunc(response_type)

        city_and_country_name_query = (
            self._city_and_country_name_query_template.substitute(
                FUNC=accum_func, CITY_VIEW=search_city_view))
        city_and_country_code_query = (
            self._city_and_country_code_query_template.substitute(
                FUNC=accum_func, CITY_VIEW=search_city_view))
        city_and_subnational_name_query = (
            self._city_and_subnational_name_query_template.substitute(
                FUNC=accum_func, CITY_VIEW=search_city_view))
        city_and_subnational_code_query = (
            self._city_and_subnational_code_query_template.substitute(
                FUNC=accum_func, CITY_VIEW=search_city_view))

        query = ("%s UNION %s UNION %s UNION %s ORDER BY population DESC"
                 % (city_and_country_name_query, city_and_country_code_query,
                    city_and_subnational_name_query,
                    city_and_subnational_code_query))

        query_status, query_results = self.RunPGSQLQuery(query, params)

        total_results += len(query_results)

        if query_status:
            for entry in query_results:
                temp_results = {}

                # A stored population of 0 is displayed as "unknown".
                if entry[1] == 0:
                    population = "unknown"
                else:
                    population = str(entry[1])

                name = "%s, %s" % (entry[2], entry[4])
                country_name = "Country: %s" % (entry[3])
                population = "Population: %s" % (population)
                description = "%s<![CDATA[<br/>]]>%s" % (country_name, population)
                styleurl = "#placemark_label"
                geom = str(entry[0])

                temp_results["name"] = name
                temp_results["styleurl"] = styleurl
                temp_results["description"] = description
                temp_results["geom"] = geom
                temp_results["geom_type"] = entry[5]

                double_scope_results.append(temp_results)

        self.logger.info("places search returned %s results", total_results)

        return total_results, double_scope_results

    def ConstructKMLResponse(self, search_results, original_query):
        """Prepares KML response.

        KML response has the below format:
          <kml>
           <Folder>
           <name/>
           <StyleURL>
                 ---
           </StyleURL>
           <Point>
                  <coordinates/>
           </Point>
           </Folder>
          </kml>

        Reads ``self.display_keys_string`` and ``self.fly_to_first_element``,
        which ``HandleSearchRequest`` sets from the request.

        Args:
          search_results: Query results from the geplaces database.
          original_query: Search query as entered by the user.
        Returns:
          kml_response: KML formatted response.
        """
        placemarks = []
        lookat_info = ""
        set_first_element_lookat = True

        # folder name should include the query parameter(q) if 'displayKeys'
        # is present in the URL otherwise not.
        if self.display_keys_string:
            folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                           % (original_query, str(len(search_results))))
        else:
            folder_name = ("Grouped results:<![CDATA[<br/>]]> (%s)"
                           % (str(len(search_results))))

        fly_to_first_element = str(self.fly_to_first_element).lower() == "true"

        for result in search_results:
            geom = self._geom % (
                result["name"],
                result["styleurl"],
                result["description"],
                result["geom"])

            # Add <LookAt> for POINT geometric types only.
            # TODO: Check if <LookAt> can be added for
            # LINESTRING and POLYGON types.
            if result["geom_type"] != "POINT":
                set_first_element_lookat = False

            if fly_to_first_element and set_first_element_lookat:
                lookat_info = self.utils.GetLookAtInfo(result["geom"])
                set_first_element_lookat = False

            placemarks.append(self._placemark_template.substitute(geom=geom))

        kml_response = self._kml_template.substitute(
            foldername=folder_name,
            style=self.style,
            lookat=lookat_info,
            placemark="".join(placemarks))

        self.logger.info("KML response successfully formatted")

        return kml_response

    def ConstructJSONPResponse(self, search_results, original_query):
        """Prepares JSONP response.

          {
            "Folder": {
              "name": "Latitude X Longitude Y",
              "Placemark": {
                "Point": {
                  "coordinates": "X,Y"
                }
              }
            }
          }

        Reads ``self.f_callback``, which ``HandleSearchRequest`` sets from
        the request.

        Args:
          search_results: Query results from the geplaces table.
          original_query: Search query as entered by the user.
        Returns:
          jsonp_response: JSONP formatted response.
        """
        folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                       % (original_query, str(len(search_results))))

        geoms = ",".join(
            self._json_geom % (
                result["name"],
                result["description"],
                # we skip the beginning and end curly braces '{'
                # that are seen in the response from querying the database.
                result["geom"][1:-1])
            for result in search_results)

        # A single placemark is emitted bare, anything else as an array.
        if len(search_results) == 1:
            search_geoms = geoms
        else:
            search_geoms = "[" + geoms + "]"

        search_placemarks = self._json_placemark_template.substitute(
            geom=search_geoms)

        json_response = self._json_template.substitute(
            foldername=folder_name, json_placemark=search_placemarks)

        # Escape single quotes from json_response.
        json_response = json_response.replace("'", "\\'")

        jsonp_response = self._jsonp_call % (self.f_callback, json_response)

        self.logger.info("JSONP response successfully formatted")

        return jsonp_response

    def HandleSearchRequest(self, environ):
        """Fetches the search tokens from form and performs the places search.

        Stores the callback name and the 'flyToFirstElement' / 'displayKeys'
        request parameters on the instance for the response builders.

        Args:
          environ: A list of environment variables as supplied by the
            WSGI interface to the places search application interface.
        Returns:
          tuple of
          search_results: A KML/JSONP formatted string which contains
            search results.
          response_type: the response type taken from the request.
        """
        search_results = ""
        search_status = False

        # Fetch all the attributes provided by the user.
        parameters = self.utils.GetParameters(environ)
        response_type = self.utils.GetResponseType(environ)

        # Retrieve the function call back name for JSONP response.
        self.f_callback = self.utils.GetCallback(parameters)

        original_query = self.utils.GetValue(parameters, "q")

        # Fetch additional query parameters 'flyToFirstElement' and
        # 'displayKeys' from URL.
        self.fly_to_first_element = self.utils.GetValue(
            parameters, "flyToFirstElement")
        self.display_keys_string = self.utils.GetValue(
            parameters, "displayKeys")

        if original_query:
            (search_status, search_results) = self.DoSearch(
                original_query, response_type)
        else:
            self.logger.debug("Empty search query received")

        if not search_status:
            folder_name = "No results were returned."
            search_results = self.utils.NoSearchResults(
                folder_name, self.style, response_type, self.f_callback)

        return (search_results, response_type)

    def DoSearch(self, original_query, response_type):
        """Performs the places search and returns the results.

        Args:
          original_query: A string containing the search query as
            entered by the user.
          response_type: Response type can be KML or JSONP, depending on
            the client.
        Returns:
          tuple containing
          search_status: Whether search could be performed.
          search_results: A KML/JSONP formatted string which contains
            search results.
        """
        search_status = False
        search_results = ""
        query_results = ""

        total_results = 0

        search_query = original_query.strip().lower()

        if search_query:
            scopes = search_query.split(",")
            if len(scopes) > 1:
                if len(scopes) > 2:
                    # Only "city,scope" is supported; drop anything after it.
                    self.logger.warning("Extra search parameters ignored:%s",
                                        ",".join(scopes[2:]))
                    search_query = ",".join(scopes[:2])
                    original_query = ",".join(original_query.split(",")[:2])

                total_results, query_results = self.DoubleScopeSearch(
                    search_query, response_type)
            else:
                total_results, query_results = self.SingleScopeSearch(
                    search_query, response_type)

        if total_results > 0:
            if response_type == "KML":
                search_results = self.ConstructKMLResponse(
                    query_results, original_query)
                search_status = True
            elif response_type == "JSONP":
                search_results = self.ConstructJSONPResponse(
                    query_results, original_query)
                search_status = True
            else:
                # This condition may not occur,
                # as response_type is either KML or JSONP
                self.logger.error("Invalid response type %s", response_type)

        return search_status, search_results

    def __del__(self):
        """Closes the connection pool created in __init__.

        Does nothing when __init__ failed before the pool was created, so
        finalizing a half-built instance does not raise AttributeError.
        """
        pool = getattr(self, "_pool", None)
        if pool is not None:
            pool.closeall()
if __name__ == '__main__':
    # Script entry point: configure CherryPy, open the database pool and
    # serve the web application on port 8000 of every interface.

    # if ((len(sys.argv)) != 2) and ((len(sys.argv)) != 3):
    #     print(".py conf.json (genCache? - optional = Y/N)")
    #     exit(-1)
    webapp = server()
    conf = {
        '/': {
            'tools.sessions.on': True,
            'tools.staticdir.root': os.path.abspath(os.getcwd()),
            'tools.gzip.on': True,
            'tools.gzip.mime_types': ['text/*', 'application/*']
        },
        '/public': {
            'tools.staticdir.on': True,
            'tools.staticdir.dir': './public'
        }
    }

    connPool = ThreadedConnectionPool(
        1, 10, "dbname=balagan user=postgres password=nothing")

    # _createHier(1)
    # exit()

    # Close the pool even when the server stops with an error, so no
    # database connection outlives the web application.
    try:
        cherrypy.server.max_request_body_size = 0
        cherrypy.server.socket_host = '0.0.0.0'
        cherrypy.server.socket_port = 8000
        cherrypy.quickstart(webapp, '/', conf)
    finally:
        connPool.closeall()
class Database:
    """PostgreSQL database client built on a psycopg2 connection pool.

    Connections checked out of the pool are tracked in ``self.conns`` under a
    caller-chosen key; every query helper takes that key to select the
    connection it runs on.
    """

    def __init__(self, config):
        """Store the connection settings. No pool is opened here.

        Args:
            config: object exposing DATABASE_HOST, DATABASE_USERNAME,
                DATABASE_PASSWORD, DATABASE_PORT and DATABASE_NAME.
        """
        self.host = config.DATABASE_HOST
        self.username = config.DATABASE_USERNAME
        self.password = config.DATABASE_PASSWORD
        self.port = config.DATABASE_PORT
        self.dbname = config.DATABASE_NAME
        self.pool = None
        self.conns = {}  # active connections from the pool, by key

    def open_pool(self, minconns=1, maxconns=None):
        """Create the connection pool if none is currently open.

        Args:
            minconns (int): minimum number of pooled connections.
            maxconns (int): maximum number of pooled connections. Defaults to
                ``minconns``.
        """
        if self.pool is None:
            maxconns = maxconns if maxconns is not None else minconns
            self.pool = ThreadedConnectionPool(minconns, maxconns,
                                               host=self.host,
                                               user=self.username,
                                               password=self.password,
                                               port=self.port,
                                               dbname=self.dbname,
                                               sslmode='disable')
            logger.success(f"Connection pool created to PostgreSQL database: {maxconns} connections available.")

    def close_pool(self):
        """Close every connection in the pool and forget the pool."""
        if self.pool:
            self.pool.closeall()
            self.pool = None
            logger.success("All connections in the pool have been closed successfully.")

    @contextmanager
    def open(self, minconns=1, maxconns=None):
        """Context manager that opens the pool and closes it on exit.

        Usage:
            with database.open() as pool:
                ...  # use pool (pool.get_connection(key))

        Args:
            minconns (int): minimum number of pooled connections.
            maxconns (int): maximum number of pooled connections. Defaults to
                ``minconns``.

        Yields:
            This Database instance.
        """
        try:
            self.open_pool(minconns, maxconns)
            yield self
        finally:
            self.close_pool()

    @contextmanager
    def connect(self, key=1):
        """Context manager that checks out connection ``key`` and returns it on exit.

        Usage:
            with database.open() as pool:
                with pool.connect(key):
                    ...  # pool.send(query, args, key=key)

        Exceptions raised inside the ``with`` body (KeyboardInterrupt
        included) are logged and suppressed.

        Args:
            key (int): key to identify the connection being opened. Required
                for proper book keeping.

        Yields:
            This Database instance.
        """
        # Check a connection out of the pool under `key`.
        self.get_connection(key=key)
        try:
            yield self
        except (Exception, KeyboardInterrupt) as e:
            logger.error(f"Error raised while managing connection from pool: {e}")
        finally:
            # Hand the connection back to the pool.
            self.put_back_connection(key=key)

    def get_connection(self, key=1):
        """Check a connection out of the pool and register it under ``key``.

        Only logs a warning when no pool is open or the key is already in
        use. A pool or database error terminates the process via
        ``sys.exit()``.

        Args:
            key (int): key to identify the connection being opened. Required
                for proper book keeping.
        """
        if not self.pool:
            logger.warning(f"No pool to the PostgreSQL database: cannot retrieve a connection. Try to .open() a pool.")
            return

        try:
            if key in self.conns:
                logger.warning(f"Pool connection [{key}] is already in use by another client. Try a different key.")
            else:
                conn = self.pool.getconn(key)
                # Book-keep the connection as active.
                self.conns[key] = conn
                logger.success(f"Connection retrieved successfully: pool connection [{key}] now in use.")
                # Run the connection "hello world".
                self.on_conn_retrieval(key)
        except psycopg2.pool.PoolError as error:
            logger.error(f"Error while retrieving connection from pool:\t{error}")
            sys.exit()
        except psycopg2.DatabaseError as error:
            logger.error(f"Error while connecting to PostgreSQL:\t{error}")
            sys.exit()

    def put_back_connection(self, key=1):
        """Reset connection ``key`` and return it to the pool.

        Args:
            key (int): key of the connection to give back.
        """
        if key in self.conns:
            conn = self.conns[key]
            conn.reset()
            self.pool.putconn(conn, key)
            self.conns.pop(key)
            logger.success(f"Connection returned successfully: pool connection [{key}] now available again.")
        else:
            logger.warning(f"Pool connection [{key}] has never been opened: cannot put it back in the pool.")

    def on_conn_retrieval(self, key):
        """Run the small "hello world" query after a connection is checked out."""
        self.connection_info(key=key)

    def send(self, query, args, success_msg='Query Success', error_msg="Query Error", cur_method=0, file=None,
             fetch_method=2, key=1):
        """Send a generic SQL query to the database and commit it.

        Args:
            query (string or Composed): SQL command string (can be a template
                with %s fields), as required by psycopg2.
            args (tuple or None): values to substitute in the query template.
            success_msg (string): message logged on successful execution.
            error_msg (string): message logged if execution raises.
            cur_method (int): cursor execution method:
                0: cursor.execute()
                1: cursor.copy_expert()
            file (file): file-like object to read or write when
                cur_method == 1.
            fetch_method (int): result retrieval method:
                0: cur.fetchone()
                2: cur.fetchall()
                Any other value fetches nothing and returns an empty list.
            key (int): key of the pool connection used for the transaction.

        Returns:
            Query results (DictCursor rows, accessible as dictionaries), an
            empty list when the query returns nothing, or None when the
            connection is not open or the query failed and was rolled back.
        """
        if key not in self.conns:
            logger.warning(f"Pool connection [{key}] has never been opened: not available for transactions.")
            return None

        conn = self.conns[key]
        try:
            with conn.cursor(cursor_factory=DictCursor) as cur:
                query = cur.mogrify(query, args) if args is not None else cur.mogrify(query)

                if cur_method == 0:
                    cur.execute(query)
                elif cur_method == 1:
                    cur.copy_expert(sql=query, file=file)

                # Start from "no rows" so an unsupported fetch_method cannot
                # leave `records` unbound after the query has been committed.
                records = []
                try:
                    if fetch_method == 0:
                        records = cur.fetchone()
                    elif fetch_method == 2:
                        records = cur.fetchall()
                except psycopg2.ProgrammingError:
                    # Statements that return no result set.
                    records = []

                conn.commit()

                if cur.rowcount >= 0:
                    success_msg += f": {cur.rowcount} rows affected."
                logger.success(success_msg)
                return records
        except Exception as e:
            conn.rollback()
            logger.error(error_msg + f":{e}. Transaction rolled-back.")
        finally:
            # (Not sure if necessary) keep the book-keeping entry pointing at
            # the connection that was just used.
            self.conns[key] = conn

    def select_rows(self, query, args=None, fetch_method=2, key=1):
        """Send a SELECT query and return its rows."""
        success_msg = "Data fetched successfully from PostgreSQL"
        error_msg = "Error while fetching data from PostgreSQL"
        return self.send(query, args, success_msg, error_msg, fetch_method=fetch_method, key=key)

    def update_rows(self, query, args=None, key=1):
        """Run a SQL query to update rows in a table."""
        success_msg = "Database updated successfully"
        error_msg = "Error while updating data in PostgreSQL"
        self.send(query, args, success_msg, error_msg, key=key)

    def insert_rows(self, query, args=None, key=1):
        """Run a SQL query to insert rows in a table."""
        success_msg = "Record inserted successfully into database"
        error_msg = "Error executing SQL query"
        self.send(query, args, success_msg, error_msg, key=key)

    def listen_on_channel(self, channel, key=1):
        """Run a LISTEN query on ``channel``.

        The channel name is concatenated into the statement as-is, so it must
        come from a trusted source.
        """
        query = "LISTEN " + channel + ";"
        success_msg = f"Successfully listening on channel {channel} for NOTIFYs"
        error_msg = "Error executing SQL LISTEN query"
        self.send(query, None, success_msg, error_msg, key=key)

    def connection_info(self, key=1):
        """Run ``SELECT version()`` and log the result."""
        query = "SELECT version();"
        success_msg = "PostgreSQL version fetched successfully"
        error_msg = "Error while fetching PostgreSQL version"
        record = self.send(query, None, success_msg, error_msg, fetch_method=0, key=key)  # fetchone()
        logger.info(f"You are connected to - {record}")

    def create_table(self, query, args=None, key=1):
        """Run a SQL query to create a table."""
        success_msg = "Table created successfully in PostgreSQL"
        error_msg = "Error while creating PostgreSQL table"
        self.send(query, args, success_msg, error_msg, key=key)

    def copy_table(self, query, file, replace=True, db_table=None, key=1):
        """Run a COPY query to copy a table to/from ``file``.

        Args:
            query (string or Composed): the COPY statement.
            file (file): file-like object handed to cursor.copy_expert().
            replace (bool): TRUNCATE ``db_table`` first when True.
            db_table (string): table to truncate; needed when ``replace`` is
                True.
            key (int): key of the pool connection used for the transaction.
        """
        if replace:
            # Empty the existing table before loading the new contents.
            query_tmp = sql.SQL("TRUNCATE {};").format(sql.Identifier(db_table))
            success_msg = "Table truncated successfully in PostgreSQL database"
            error_msg = "Error while truncating PostgreSQL table"
            self.send(query_tmp, None, success_msg, error_msg, key=key)

        success_msg = "Table copied successfully to/from PostgreSQL"
        error_msg = "Error while copying PostgreSQL table"
        self.send(query, None, success_msg, error_msg, cur_method=1, file=file, key=key)

    def copy_df(self, df, db_table, replace=True, key=1):
        """Copy a pandas dataframe to a database table using COPY.

        Inspired by:
        https://stackoverflow.com/questions/23103962/how-to-write-dataframe-to-postgres-table

        Args:
            df: dataframe to copy.
            db_table (string): destination table name.
            replace (bool): replace the table when True, append otherwise.
            key (int): key of the pool connection used for the transaction.
        """
        if key not in self.conns:
            logger.warning(f"Pool connection [{key}] has never been opened: cannot use it to copy Dataframe to database.")
            return

        try:
            # Headless, tab-separated CSV of the dataframe, held in memory.
            io_file = io.StringIO()
            df.to_csv(io_file, sep='\t', header=False, index=False)
            io_file.seek(0)

            # Let pandas create the (empty) table with the right columns and
            # types; it needs an SQLAlchemy engine, built here on top of the
            # pooled connection.
            replacement_method = 'replace' if replace else 'append'
            engine = create_engine('postgresql+psycopg2://', creator=lambda: self.conns[key])
            df.head(0).to_sql(db_table, engine, if_exists=replacement_method, index=False)

            # Load the rows with COPY instead of the slow pandas .to_sql().
            # replace=False keeps the header-only table created above.
            sql_copy_expert = sql.SQL("COPY {} FROM STDIN WITH CSV DELIMITER '\t'").format(sql.Identifier(db_table))
            self.copy_table(sql_copy_expert, file=io_file, replace=False, key=key)
            logger.success(f"DataFrame copied successfully to PostgreSQL table.")
        except (Exception, psycopg2.DatabaseError) as error:
            logger.error(f"Error while copying DataFrame to PostgreSQL table: {error}")
class Database:
    """Thin PostgreSQL helper around a psycopg2 ThreadedConnectionPool.

    Provides a transactional cursor context manager and bulk ``insert``,
    ``upsert`` and ``update`` helpers that each send a single statement
    string to the server.
    """

    def __init__(self, db_config, table_raw=None, max_connections=10):
        """Open the connection pool.

        Args:
            db_config: mapping with ``db_name``, ``db_user``, ``db_host``,
                ``db_pass`` and optionally ``db_port`` (defaults to 5432).
                The mapping is not modified.
            table_raw: stored on the instance as ``self.table_raw``.
            max_connections: upper bound of pooled connections.
        """
        from psycopg2.pool import ThreadedConnectionPool
        self.table_raw = table_raw

        # Work on a copy so that filling in the default port does not leak
        # into the caller's configuration.
        dsn_params = dict(db_config)
        if not dsn_params.get('db_port'):
            dsn_params['db_port'] = 5432

        self.pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=max_connections,
            dsn="dbname={db_name} user={db_user} host={db_host} password={db_pass} port={db_port}"  # noqa: E501
            .format(**dsn_params))

    @contextmanager
    def getcursor(self, **kwargs):
        """Yield a cursor on a pooled connection inside one transaction.

        Commits when the ``with`` body finishes, rolls back and re-raises on
        an exception, and always returns the connection to the pool.

        Args:
            **kwargs: forwarded to ``connection.cursor()``.
        """
        conn = self.pool.getconn()
        try:
            yield conn.cursor(**kwargs)
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.pool.putconn(conn)

    def close(self):
        """Close every connection in the pool."""
        self.pool.closeall()

    @staticmethod
    def _returning_clause(return_cols):
        """Build the ``RETURNING ...`` suffix, or '' when no column is requested."""
        if return_cols is None or len(return_cols) == 0 or return_cols[0] is None:
            return ''
        if not isinstance(return_cols, list):
            return_cols = [return_cols]
        return 'RETURNING ' + ','.join(return_cols)

    def insert(self, table, data_list, return_cols='id'):
        """Insert all rows with a single multi-row INSERT statement.

        Much faster (~2x in tests with 10k & 100k rows and n cols) than
        executemany().

        TODO: Is there a limit of length the query can be? If so handle it.

        Args:
            table: table name, placed in the statement as-is.
            data_list: a dict or a list of dicts; the keys of the first dict
                are used as the column list.
            return_cols: column name or list of names for RETURNING; None or
                empty disables it.

        Returns:
            The rows from ``cur.fetchall()``, None when nothing can be
            fetched, or [] when ``data_list`` is empty.

        Raises:
            ValueError: if the data is not made of dicts.
        """
        # Deep copy so the caller's data does not get modified.
        data_list = copy.deepcopy(data_list)

        if not isinstance(data_list, list):
            data_list = [data_list]

        if len(data_list) == 0:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            raise ValueError("Data must be a list of dicts")

        return_cols = self._returning_clause(return_cols)

        try:
            with self.getcursor() as cur:
                query = "INSERT INTO {table} ({fields}) VALUES {values} {return_cols}"\
                    .format(table=table,
                            fields='"{0}"'.format('", "'.join(data_list[0].keys())),
                            values=','.join(['%s'] * len(data_list)),
                            return_cols=return_cols,
                            )
                # One tuple per row; each tuple fills one `%s` placeholder.
                values = [_check_values(tuple(row.values())) for row in data_list]

                query = cur.mogrify(query, values)
                cur.execute(query)

                try:
                    return cur.fetchall()
                except Exception:
                    # Nothing to fetch (e.g. no RETURNING clause).
                    return None
        except Exception:
            logger.debug("Error inserting data: {data}".format(data=data_list))
            raise

    def upsert(self, table, data_list, on_conflict_fields, on_conflict_action='update',
               on_conflict_where=None, update_fields=None, return_cols='id'):
        """Upsert all rows with a single INSERT ... ON CONFLICT statement.

        Much faster (~6x in tests with 10k & 100k rows and n cols) than
        executemany().

        TODO: Is there a limit of length the query can be? If so handle it.

        Args:
            table: table name, placed in the statement as-is.
            data_list: a dict or a list of dicts; the keys of the first dict
                are used as the column list.
            on_conflict_fields: column name or list of names forming the
                conflict target.
            on_conflict_action: 'update' for DO UPDATE SET; anything else
                results in DO NOTHING.
            on_conflict_where: optional predicate for a partial index.
            update_fields: columns to overwrite on conflict; defaults to
                every inserted column that is not a conflict field.
            return_cols: column name or list of names for RETURNING; None or
                empty disables it.

        Returns:
            The rows from ``cur.fetchall()``, None when nothing can be
            fetched, or [] when ``data_list`` is empty.

        Raises:
            ValueError: if the data is not made of dicts, no conflict field
                is given, or every column is a conflict field on an update.
        """
        # Deep copy so the caller's data does not get modified.
        data_list = copy.deepcopy(data_list)

        if not isinstance(data_list, list):
            data_list = [data_list]

        if len(data_list) == 0:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            raise ValueError("Data must be a list of dicts")

        if not isinstance(on_conflict_fields, list):
            on_conflict_fields = [on_conflict_fields]

        if len(on_conflict_fields) == 0 or on_conflict_fields[0] is None:
            raise ValueError("Must pass in `on_conflict_fields` argument")

        # Support for partial index on table
        if on_conflict_where:
            on_conflict_where = f'WHERE {on_conflict_where}'
        else:
            on_conflict_where = ''

        return_cols = self._returning_clause(return_cols)

        if on_conflict_action == 'update':
            if not isinstance(update_fields, list):
                update_fields = [update_fields]

            # If nothing is passed in, update every inserted column that is
            # not part of the conflict target (kept in column order so the
            # generated SQL is deterministic).
            if len(update_fields) == 0 or update_fields[0] is None:
                update_fields = [field for field in data_list[0].keys()
                                 if field not in on_conflict_fields]
                # Empty here can only mean all fields are conflict fields.
                if len(update_fields) == 0:
                    raise ValueError(
                        "Not all the fields can be `on_conflict_fields` when doing an update"
                    )

            assignments = ['"{0}"="excluded"."{0}"'.format(field)
                           for field in update_fields]
            conflict_action_sql = 'UPDATE SET {update_fields}'\
                .format(update_fields=', '.join(assignments))
        else:
            # Do nothing on conflict
            conflict_action_sql = 'NOTHING'

        try:
            with self.getcursor() as cur:
                query = """INSERT INTO {table} ({insert_fields})
                           VALUES {values}
                           ON CONFLICT ({on_conflict_fields}) {on_conflict_where} DO
                           {conflict_action_sql}
                           {return_cols}
                        """.format(
                    table=table,
                    insert_fields='"{0}"'.format('","'.join(
                        data_list[0].keys())),
                    values=','.join(['%s'] * len(data_list)),
                    on_conflict_fields=','.join(on_conflict_fields),
                    on_conflict_where=on_conflict_where,
                    conflict_action_sql=conflict_action_sql,
                    return_cols=return_cols,
                )
                # One list per row; each list fills one `%s` placeholder.
                values = [_check_values(list(row.values())) for row in data_list]

                query = cur.mogrify(query, values)
                cur.execute(query)

                try:
                    return cur.fetchall()
                except Exception:
                    # Nothing to fetch (e.g. no RETURNING clause).
                    return None
        except Exception:
            logger.debug("Error upserting data: {data}".format(data=data_list))
            raise

    def update(self, table, data_list, matched_field=None, return_cols='id'):
        """Update all rows by sending the UPDATE statements in one round trip.

        Each row becomes one ``UPDATE ... WHERE matched_field=%s`` statement;
        the statements are joined with ``;`` and executed together.

        TODO: Is there a limit of length the query can be? If so handle it.
        TODO: return the data of every statement; ``fetchall()`` only sees
        the result of the last one.

        Args:
            table: table name, placed in the statement as-is.
            data_list: a dict or a list of dicts. Each dict must contain
                ``matched_field``; rows without it are logged and skipped.
            matched_field: column used in the WHERE clause. Defaults to 'id'.
            return_cols: column name or list of names for RETURNING; None or
                empty disables it.

        Returns:
            The rows from ``cur.fetchall()``, None when nothing can be
            fetched, or [] when there is nothing to update.

        Raises:
            ValueError: if the data is not made of dicts.
        """
        # Deep copy so the caller's data does not get modified (rows are
        # edited in place below).
        data_list = copy.deepcopy(data_list)

        if matched_field is None:
            # Assume the id field
            logger.info("Matched field not defined, assuming the `id` field")
            matched_field = 'id'

        if not isinstance(data_list, list):
            data_list = [data_list]

        if len(data_list) == 0:
            # No need to continue
            return []

        return_cols = self._returning_clause(return_cols)

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            raise ValueError("Data must be a list of dicts")

        try:
            with self.getcursor() as cur:
                query_list = []
                for row in data_list:
                    if row.get(matched_field) is None:
                        logger.debug(
                            "Cannot update row. Missing field {field} in data {data}"
                            .format(field=matched_field, data=row))
                        logger.error(
                            "Cannot update row. Missing field {field} in data".
                            format(field=matched_field))
                        continue

                    # The matched value goes in the WHERE clause, not in SET.
                    matched_value = row.pop(matched_field)

                    query = "UPDATE {table} SET {data} WHERE {matched_field}=%s {return_cols}"\
                        .format(table=table,
                                data=','.join("%s=%%s" % u for u in row.keys()),
                                matched_field=matched_field,
                                return_cols=return_cols
                                )
                    values = list(row.values())
                    values.append(matched_value)
                    values = _check_values(values)

                    query_list.append(cur.mogrify(query, values))

                if not query_list:
                    # Every row was skipped: executing an empty statement
                    # would make psycopg2 raise.
                    return []

                final_query = b';'.join(query_list)
                cur.execute(final_query)

                try:
                    return cur.fetchall()
                except Exception:
                    # Nothing to fetch (e.g. no RETURNING clause).
                    return None
        except Exception:
            logger.debug("Error updating data: {data}".format(data=data_list))
            raise
class Database:
    """Jumbo's abstract PostgreSQL client manager. Allows to manage multiple
    independent connections to a PostgreSQL database and to handle arbitrary
    transactions between clients and database.

    Usage:

        .. code-block:: python

            # Initialize abstract database manager
            database = jumbo.database.Database()

            # Open a connection pool to the PostgreSQL database:
            with database.open() as pool:

                # Use a connection from the pool
                with pool.connect():

                    # Execute SQL query on the database
                    pool.send(SQL_query)

            # context managers ensure connections are properly returned to
            # the pool, and that the pool is properly closed.
    """

    def __init__(self, env_path: Optional[str] = None) -> None:
        """Initializes manager to handle connections to a given PostgreSQL
        database.

        Args:
            env_path: path where to look for the jumbo.env configuration
                file. If not provided, looks for configuration file in the
                working directory of the script invoking this constructor.

        Attributes:
            config (jumbo.config.Config): jumbo's database configuration
                settings.
            pool (psycopg2.pool): connection pool to the PostgreSQL database.
                Initially closed.
        """
        # Database configuration settings
        self.config = Config(env_path)
        # Placeholder pool (no connections), flagged closed so that
        # open_pool() replaces it with a real one.
        self.pool = ThreadedConnectionPool(0, 0)
        self.pool.closed = True
        # Log configuration settings
        logger.info(f"Jumbo Connection Manager created:\n{self.config}")

    @contextmanager
    def open(self, minconns: int = 1, maxconns: Optional[int] = None) -> None:
        """Context manager opening a connection pool to the PostgreSQL
        database. The pool is automatically closed on exit and all
        connections are properly handled.

        Example:

            .. code-block:: python

                with self.open() as pool:
                    # do something with the pool

                # pool is automatically closed here

        An OperationalError raised while the pool is being opened is logged
        and re-raised. An OperationalError raised inside the ``with`` body is
        logged and suppressed.

        Args:
            minconns: minimum amount of available connections in the pool,
                created on startup.
            maxconns: maximum amount of available connections supported by
                the pool. Defaults to minconns.
        """
        # Opening happens before the guarded `yield`: swallowing a failure
        # here would leave the generator without a value to yield and
        # contextmanager would raise an unrelated RuntimeError.
        try:
            self.open_pool(minconns, maxconns)
        except OperationalError as e:
            logger.error(f"Error while opening connection pool: {e}")
            raise

        try:
            yield self
        except OperationalError as e:
            logger.error(f"Error while opening connection pool: {e}")
        finally:
            # Close all connections in the pool
            self.close_pool()

    def open_pool(self, minconns: int = 1,
                  maxconns: Optional[int] = None) -> None:
        """Initialises and opens a connection pool to the PostgreSQL database
        (psycopg2.pool.ThreadedConnectionPool).

        'minconn' new connections are created immediately. The connection
        pool will support a maximum of about 'maxconn' connections. Does
        nothing when a pool is already open.

        Args:
            minconns: minimum amount of available connections in the pool,
                created on startup.
            maxconns: maximum amount of available connections supported by
                the pool. Defaults to 'minconns'.
        """
        if not self.pool.closed:
            return

        # initialize max number of supported connections
        maxconns = maxconns if maxconns is not None else minconns

        # create a connection pool based on jumbo's configuration settings
        self.pool = ThreadedConnectionPool(
            minconns, maxconns,
            host=self.config.DATABASE_HOST,
            user=self.config.DATABASE_USERNAME,
            password=self.config.DATABASE_PASSWORD,
            port=self.config.DATABASE_PORT,
            dbname=self.config.DATABASE_NAME,
            sslmode='disable')

        logger.info(f"Connection pool created to PostgreSQL database: "
                    f"{maxconns} connections available.")

    def close_pool(self) -> None:
        """Closes all connections in the pool, making it unusable by clients.
        Does nothing when the pool is already closed.
        """
        if not self.pool.closed:
            self.pool.closeall()
            logger.info("All connections in the pool have been closed "
                        "successfully.")

    @contextmanager
    def connect(self, key: int = 1) -> None:
        """Context manager opening a connection to the PostgreSQL database
        using a connection [key] from the pool. The connection is
        automatically returned to the pool on exit.

        Example:

            .. code-block:: python

                # check-out connection from the pool
                with self.connect(key):
                    # do something with the connection e.g.
                    # self.send(sql_query, key)

                # connection is automatically closed here

        Exceptions raised inside the ``with`` body (KeyboardInterrupt
        included) are logged and suppressed.

        Args:
            key (optional): key to identify the connection being opened.
                Required for proper book keeping.
        """
        # check-out an available connection from the pool
        self.get_connection(key=key)
        try:
            yield self
        except (Exception, KeyboardInterrupt) as e:
            logger.error(f"Error raised during connection [{key}] "
                         f"transactions: {e}")
        finally:
            # return the connection to the pool
            self.put_back_connection(key=key)

    def get_connection(self, key: int = 1) -> None:
        """Connect to a Postgres database using an available connection from
        pool. The connection is assigned to 'key' on checkout.

        Only logs a warning when no pool is open or the key is already in
        use. A PoolError terminates the process via ``sys.exit()``.

        Args:
            key (optional): key to assign to the connection being opened.
                Required for proper book keeping.
        """
        if self.pool.closed:
            logger.warning(f"No pool to the PostgreSQL database: cannot "
                           f"retrieve a connection. Try to .open() a pool.")
            return

        try:
            if key in self.pool._used:
                logger.warning(f"Pool connection [{key}] is already in "
                               f"use by another client. Try a different "
                               f"key.")
            else:
                # Check the connection out; the pool itself tracks it
                # under `key`.
                self.pool.getconn(key)
                logger.info(f"Connection retrieved successfully: pool "
                            f"connection [{key}] now in use.")
                # perform handshake
                self.on_connection(key)
        except PoolError as error:
            logger.error(f"Error while retrieving connection from "
                         f"pool:\t{error}")
            sys.exit()

    def on_connection(self, key: int = 1) -> None:
        """Client-database handshaking script to perform on retrieval of a
        PostgreSQL connection from the pool.

        Args:
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        # return database information
        info = self.connection_info(key=key)
        logger.info(f"You are connected to - {info}")

    def put_back_connection(self, key: int = 1) -> None:
        """Puts back a connection in the connection pool.

        Args:
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        if key in self.pool._used:
            conn = self.pool._used[key]
            # Reset connection to neutral state
            conn.reset()
            # Put back connection in the pool
            self.pool.putconn(conn, key)
            logger.info(f"Connection returned successfully: pool connection "
                        f"[{key}] now available again.")
        else:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"cannot put it back in the pool.")

    def send(self, query: Union[str, sql.Composed],
             subs: Optional[Tuple[str, ...]] = None, cur_method: int = 0,
             file: Optional[IO] = None, fetch_method: int = 2,
             key: int = 1) -> Union[DictRow, None]:
        """Sends an arbitrary PostgreSQL query to the PostgreSQL database.
        Transactions are auto-committed on execution.

        Example:

            .. code-block:: python

                # A simple query with no substitutions
                query = 'SELECT * FROM table_name;'

                results = self.send(query)

                # A more complex query with dynamic substitutions
                query = 'INSERT INTO table_name (column_name,
                another_column_name) VALUES (%s, %s);'
                subs = (value, another_value)

                results = self.send(query, subs)

        Args:
            query: PostgreSQL command string (can be template with psycopg2
                %s fields).
            subs (optional): tuple of values to substitute in SQL query
                template (cf. psycopg2 %s formatting)
            cur_method (optional): code to select which psycopg2 cursor
                execution method to use for the SQL query:
                0: cursor.execute()
                1: cursor.copy_expert()
            file (optional): file-like object to read or write to (only
                relevant if cur_method:1).
            fetch_method (optional): code to select which psycopg2 result
                retrieval method to use (fetch*()):
                0: cur.fetchone()
                2: cur.fetchall()
                Any other value fetches nothing and returns an empty list.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].

        Returns:
            list of query results (if any). Can be accessed as dictionaries.
            None when the connection is not open or the query failed and was
            rolled back.
        """
        if key not in self.pool._used:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"not available for transactions.")
            return None

        conn = self.pool._used[key]
        try:
            with conn.cursor(cursor_factory=DictCursor) as cur:
                # Bind arguments to query string (if present)
                if subs is not None:
                    query = cur.mogrify(query, subs)
                else:
                    query = cur.mogrify(query)

                # Execute query
                if cur_method == 0:
                    cur.execute(query)
                elif cur_method == 1:
                    cur.copy_expert(sql=query, file=file)

                # Start from "no rows" so an unsupported fetch_method cannot
                # leave `records` unbound after the query has been committed.
                records = []
                try:
                    if fetch_method == 0:
                        records = cur.fetchone()
                    elif fetch_method == 2:
                        records = cur.fetchall()
                except ProgrammingError:
                    # Queries that don't return any results
                    # (INSERT, UPDATE, etc...)
                    records = []

                # Commit transaction
                conn.commit()

                # Display success message (and shorten query if too long)
                s_query = query
                if len(query) > 78:
                    s_query = (str(query[:75]) + '...')
                success_msg = f"Successfully sent: {s_query} "
                if cur.rowcount >= 0:
                    success_msg += f": {cur.rowcount} rows affected."
                logger.info(success_msg)

                return records  # dictionaries
        except Exception as e:
            # Rollback transaction if any problem
            conn.rollback()
            logger.error(f"Error while sending query {query}:{e}. "
                         f"Transaction rolled-back.")

    def listen_on_channel(self, channel_name: str, key: int = 1) -> None:
        """Subscribes to a PostgreSQL notification channel by listening for
        NOTIFYs.

            .. code-block:: postgresql

                -- Command executed:
                LISTEN channel_name;

        The channel name is concatenated into the statement as-is, so it must
        come from a trusted source.

        Args:
            channel_name: channel on which to LISTEN. PostgreSQL database
                should be configured to send NOTIFYs on this channel.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        query = "LISTEN " + channel_name + ";"
        self.send(query, key=key)

    def connection_info(self, key: int = 1) -> DictRow:
        """Fetches PostgreSQL database version.

            .. code-block:: postgresql

                -- Command executed:
                SELECT version();

        Args:
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].

        Returns:
            query result. Contains PostgreSQL database version information.
        """
        query = "SELECT version();"
        return self.send(query, fetch_method=0, key=key)  # fetchone()

    def copy_to_table(self, query: sql.Composed, file: IO, db_table: str,
                      schema: str = None, replace: bool = True,
                      key: int = 1) -> None:
        """Utility wrapper to send a SQL query to copy data to database
        table. Allows to replace table if it already exists in the database.

            .. code-block:: postgresql

                -- Command type expected:
                COPY table_name [ ( column_name [, ...] ) ]
                FROM STDIN
                [ [ WITH ] ( option [, ...] ) ]

                -- Ancillary command pre-executed:
                TRUNCATE table_name;

        Example:

            .. code-block:: python

                # Copy csv data from file to a table in the database
                query = "COPY table_name FROM STDIN WITH CSV DELIMITER '\\t'"
                results = self.copy_to_table(query, file="C:\\data.csv",
                                             db_table='table_name')

        Args:
            query: PostgreSQL COPY command string.
            file: absolute path to file-like object to read data from.
            db_table: the name (not schema-qualified) of an existing database
                table.
            schema (optional): schema to which ``db_table`` belongs. If
                ``None``, use default schema.
            replace (optional): replaces table contents if True. Appends data
                to table contents otherwise.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        if replace:
            # schema-qualify table name if needed
            if schema is not None:
                identifier = sql.Identifier(schema, db_table)
            else:
                identifier = sql.Identifier(db_table)
            # Empty the existing table before loading the new contents
            query_tmp = sql.SQL("TRUNCATE {};").format(identifier)
            self.send(query_tmp, key=key)

        # Copy the table from file (cur_method:1 = cur.copy_expert)
        self.send(query, cur_method=1, file=file, key=key)

    def copy_df(self, df: pd.DataFrame, db_table: str, schema: str = None,
                replace: bool = False, key: int = 1) -> None:
        """Utility wrapper to efficiently copy a pandas.DataFrame to a
        PostgreSQL database table.

        This method is faster than panda's native *.to_sql()* method and
        exploits the PostgreSQL COPY command. Provides a useful mean of
        saving results from a pandas-centred data analysis pipeline directly
        to the database.

        Args:
            df: dataframe to be copied.
            db_table: the name (not schema-qualified) of the table to write
                to.
            schema (optional): schema to which ``db_table`` belongs. If
                ``None``, use default schema.
            replace (optional): replaces table contents if True. Appends data
                to table contents otherwise.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        if key not in self.pool._used:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"cannot use it to copy Dataframe to database.")
            return

        try:
            # Create headless csv from pandas dataframe
            io_file = io.StringIO()
            df.to_csv(io_file, sep='\t', header=False, index=False)
            io_file.seek(0)

            # Quickly create a table with correct number of columns / data
            # types. Unfortunately we need to build a sqlalchemy engine on
            # top of the pooled connection for this hack to work.
            replacement_method = 'replace' if replace else 'append'
            engine = create_engine('postgresql+psycopg2://',
                                   creator=lambda: self.pool._used[key])
            df.head(0).to_sql(name=db_table, con=engine, schema=schema,
                              if_exists=replacement_method, index=False)

            # But then exploit postgreSQL COPY command instead of slow
            # pandas .to_sql(). Note that replace is set to false in
            # copy_to_table as we want to preserve the header table created
            # above.

            # schema-qualify table name if needed
            if schema is not None:
                identifier = sql.Identifier(schema, db_table)
            else:
                identifier = sql.Identifier(db_table)

            sql_copy_expert = sql.SQL(
                "COPY {} FROM STDIN WITH CSV DELIMITER '\t'").format(
                identifier)
            self.copy_to_table(sql_copy_expert, file=io_file,
                               db_table=db_table, schema=schema,
                               replace=False, key=key)
            logger.info(f"DataFrame copied successfully to PostgreSQL "
                        f"table.")
        except (Exception, DatabaseError) as error:
            logger.error(f"Error while copying DataFrame to PostgreSQL "
                         f"table: {error}")
class BatchCopy(object):
    """Buffer events in per-table files and bulk-load them with COPY.

    For every table name given at construction time the instance keeps a
    spool file ``<event_path>/<worker_id>/<table>.udw``, an in-memory list of
    pending rows and the ``column_name -> ordinal_position`` mapping read from
    ``information_schema.columns``. Events are serialised to separator-joined
    rows (``json2copy``), appended to the spool file (``writelines``) and
    finally loaded into the database with ``cursor.copy_from`` (``copy_sink``).
    """

    def __init__(self, dw_conf, event_path, worker_id, table_names, redis_conf,
                 logger, sep='\x02'):
        """Open the connection pool and prepare one spool file per table.

        Args:
            dw_conf: mapping providing ``host``, ``port``, ``database``,
                ``user`` and ``password`` for the database connection.
            event_path: base directory for the spool files.
            worker_id: sub-directory of ``event_path`` used by this worker.
            table_names: iterable of target table names.
            redis_conf: configuration handed to ``RedisConnector``.
            logger: logger used for all diagnostics.
            sep: column separator used in spool files and in COPY.
        """
        self._tables = {}
        self._pool = ThreadedConnectionPool(1, 2,
                                            host=dw_conf.get("host"),
                                            port=dw_conf.get("port"),
                                            database=dw_conf.get("database"),
                                            user=dw_conf.get("user"),
                                            password=dw_conf.get("password"))
        self._redis = redis_conf
        self._logger = logger
        self._sep = sep

        worker_dir = "%s/%s" % (event_path, worker_id)
        # The pool holds at most two connections, so the one used to read the
        # column metadata must be handed back once we are done with it.
        conn = self._pool.getconn()
        try:
            cur = conn.cursor(cursor_factory=RealDictCursor)
            try:
                for table_name in table_names:
                    if not os.path.exists(worker_dir):
                        os.makedirs(worker_dir)
                    file_path = "%s/%s.udw" % (worker_dir, table_name)
                    self._tables[table_name] = {
                        "columns": {},
                        "path": file_path,
                        "file": open(file_path, 'a+'),
                        "data": []
                    }
                    # Bind the table name as a parameter instead of
                    # interpolating it into the SQL text.
                    cur.execute(
                        "SELECT ordinal_position, column_name "
                        "FROM information_schema.columns "
                        "WHERE table_name = %s", (table_name,))
                    columns = cur.fetchall()
                    if len(columns) == 0:
                        # Unknown table: keep the (empty) entry, report it.
                        self._logger.error(table_name)
                        continue
                    for column in columns:
                        self._tables[table_name]["columns"][
                            column["column_name"]] = column["ordinal_position"]
            finally:
                cur.close()
        finally:
            self._pool.putconn(conn)

    def json2copy(self, event):
        """Serialise one event into a separator-joined row.

        ``event.value`` maps column names to values and ``event.key`` names
        the target table. Only values whose exact type is ``int``, ``str``,
        ``list`` or ``dict`` are kept (lists/dicts are stringified and wrapped
        in single quotes); values of any other type are ignored.

        Returns:
            The kept values ordered by ordinal position and joined by the
            configured separator.

        Raises:
            KeyError: if the table or a column is unknown, or if the kept
                positions are not exactly ``1..len(kept)``.
        """
        hash_values = {}
        for k, v in event.value.items():
            kind = type(v)
            if kind is int:
                text = str(v)
            elif kind is str:
                text = v
            elif kind is list or kind is dict:
                text = "'%s'" % str(v)
            else:
                continue
            hash_values[self._tables[event.key]["columns"][k]] = text
        return self._sep.join(
            hash_values[i] for i in range(1, len(hash_values) + 1))

    def writelines(self, events):
        """Serialise ``events`` and append the rows to the spool files.

        Events that cannot be serialised are logged and skipped.
        """
        for event in events:
            try:
                row = self.json2copy(event)
            except Exception as e:
                self._logger.error(e)
                # ``.get`` so that an unknown table does not raise a second
                # KeyError from inside the error handler.
                self._logger.error(self._tables.get(event.key))
            else:
                self._tables[event.key]["data"].append(row + '\n')
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].writelines(v["data"])
            del v["data"][:]

    def redis_sieve(self, db, values):
        """Pass the rows of spool file ``values`` to ``RedisConnector.multi_add``.

        Args:
            db: second argument handed to ``RedisConnector`` (the table name
                when called from ``copy_sink``).
            values: open spool file; it is read from the start and split into
                lines, dropping the piece after the last newline.

        Returns:
            Whatever ``multi_add`` returns.
        """
        rc = RedisConnector(self._redis, db)
        values.seek(0)
        rows = values.read().split('\n')[:-1]
        return rc.multi_add(rows)

    def _dump_unresolved(self, table_name, data_file):
        """Save the rows of a failed batch under the unresolved-data folder."""
        data_file.seek(0)
        dump_path = ('/data/log/unresolved_data/cleaned#' + table_name + '#' +
                     datetime.now().strftime("%Y%m%d%H%M%S"))
        with open(dump_path, 'w') as dump:
            dump.write(data_file.read())

    def copy_sink(self):
        """COPY every spool file into its table, then reset all spool files.

        A batch that fails with a database error is saved by
        ``_dump_unresolved`` instead of being loaded.
        """
        for k, v in self._tables.items():
            result = self.redis_sieve(k, v['file'])
            if isinstance(result, list):
                # presumably the rows that survive de-duplication -- TODO
                # confirm against RedisConnector.multi_add. The spool file is
                # rewritten so that it holds exactly these rows.
                result = '\n'.join(result)
                self._logger.info('duplicate data')
                v['file'].close()
                os.remove(v['path'])
                v['file'] = open(v['path'], 'a+')
                v['file'].write(result)
            # When there were no duplicates the file is loaded as it is.
            v["file"].seek(0)
            conn = None
            cur = None
            try:
                conn = self._pool.getconn()
                cur = conn.cursor(cursor_factory=RealDictCursor)
                cur.copy_from(v["file"], k, sep=self._sep)
                conn.commit()
            except psycopg2.DataError as e:
                # Raised on a wrong data type or a missing field.
                self._logger.error(e)
                self._dump_unresolved(k, v["file"])
            except psycopg2.DatabaseError as e:
                # The database answered too slowly or failed: keep the data
                # aside and back off before the next table.
                self._logger.error(e)
                self._dump_unresolved(k, v["file"])
                sleep(300)
            finally:
                # Either may still be None if getconn()/cursor() raised.
                if cur is not None:
                    cur.close()
                if conn is not None:
                    self._pool.putconn(conn)
        self.flush_copy()

    def flush_copy(self):
        """Truncate every spool file by deleting and re-creating it."""
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].close()
            os.remove(v["path"])
            v["file"] = open(v["path"], 'a+')

    def close(self):
        """Delete all spool files and close the connection pool."""
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].close()
            os.remove(v["path"])
        if not self._pool.closed:
            self._pool.closeall()
class ExampleSearch(object):
    """Class for performing the Example search.

    Example search is the neighborhood search that demonstrates how to
    construct and query a spatial database based on a URL search string,
    extract geometries from the result, associate various styles with them
    and return the response back to the client.

    Valid inputs are:
      q=pacific heights
      neighborhood=pacific heights
    """

    def __init__(self):
        """Inits ExampleSearch.

        Reads "ExampleSearch.conf", takes the logger and the KML/JSON
        templates from ``utils.SearchUtils``, creates the database connection
        pool and pre-renders the style block used in responses.
        """
        self.utils = utils.SearchUtils()
        constants = geconstants.Constants()

        configs = self.utils.GetConfigs(
            os.path.join(geconstants.SEARCH_CONFIGS_DIR, "ExampleSearch.conf"))

        style_template = self.utils.style_template
        self._jsonp_call = self.utils.jsonp_functioncall
        # Placemark bodies; the trailing backslash joins the closing line so
        # that the text does not end with a newline.
        self._geom = """
            <name>%s</name>
            <styleUrl>%s</styleUrl>
            <Snippet>%s</Snippet>
            <description>%s</description>
            %s\
        """
        self._json_geom = """
            {
              "name": "%s",
              "Snippet": "%s",
              "description": "%s",
              %s
            }\
        """

        self._placemark_template = self.utils.placemark_template
        self._kml_template = self.utils.kml_template
        self._json_template = self.utils.json_template
        self._json_placemark_template = self.utils.json_placemark_template
        self._example_query_template = Template(constants.example_query)

        self.logger = self.utils.logger

        self._user = configs.get("user")
        self._hostname = configs.get("host")
        self._port = configs.get("port")
        self._database = configs.get("databasename")
        if not self._database:
            self._database = constants.defaults.get("example.database")

        self._pool = ThreadedConnectionPool(
            int(configs.get("minimumconnectionpoolsize")),
            int(configs.get("maximumconnectionpoolsize")),
            database=self._database,
            user=self._user,
            host=self._hostname,
            port=int(self._port))

        self._style = style_template.substitute(
            balloonBgColor=configs.get("balloonstyle.bgcolor"),
            balloonTextColor=configs.get("balloonstyle.textcolor"),
            balloonText=configs.get("balloonstyle.text"),
            iconStyleScale=configs.get("iconstyle.scale"),
            iconStyleHref=configs.get("iconstyle.href"),
            lineStyleColor=configs.get("linestyle.color"),
            lineStyleWidth=configs.get("linestyle.width"),
            polyStyleColor=configs.get("polystyle.color"),
            polyStyleColorMode=configs.get("polystyle.colormode"),
            polyStyleFill=configs.get("polystyle.fill"),
            polyStyleOutline=configs.get("polystyle.outline"),
            listStyleHref=configs.get("iconstyle.href"))

    def RunPGSQLQuery(self, query, params):
        """Submits the query to the database and returns its rows.

        Note: variables placeholder must always be %s in query.
        Warning: NEVER use Python string concatenation (+) or string
        parameters interpolation (%) to pass variables to a SQL query string.
        e.g.
          SELECT vs_url FROM vs_table WHERE vs_name = 'default_ge';
          query = "SELECT vs_url FROM vs_table WHERE vs_name = %s"
          parameters = ["default_ge"]

        Args:
          query: SQL SELECT statement.
          params: sequence of parameters to populate into placeholders.

        Returns:
          Tuple (query_status, query_results). query_status is True only if
          the query ran and all rows were read. query_results is a list with
          one entry per row: the bare value for single-column rows, the row
          tuple otherwise.

        Raises:
          exceptions.PoolConnectionException: no connection could be obtained
            from the pool. Any other psycopg2.Error is logged and reported
            through query_status instead of being raised.
        """
        con = None
        cursor = None
        query_results = []
        query_status = False

        self.logger.debug(
            "Querying the database %s, at port %s, as user %s on "
            "hostname %s",
            self._database, self._port, self._user, self._hostname)

        try:
            con = self._pool.getconn()
            if con:
                cursor = con.cursor()
                cursor.execute(query, params)
                for row in cursor:
                    if len(row) == 1:
                        query_results.append(row[0])
                    else:
                        query_results.append(row)
                query_status = True
        except psycopg2.pool.PoolError as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
            raise exceptions.PoolConnectionException(
                "Pool Error - Unable to get a connection from the pool.")
        except psycopg2.Error as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
        finally:
            # Close the cursor before the connection goes back to the pool.
            if cursor is not None:
                cursor.close()
            if con:
                self._pool.putconn(con)

        return query_status, query_results

    def RunExampleSearch(self, search_query, response_type):
        """Performs a query search on the 'san_francisco_neighborhoods' table.

        Args:
          search_query: the query to be searched, in smallcase.
          response_type: Response type can be KML or JSONP, depending on the
            client.

        Returns:
          Tuple (total_example_results, example_results):
            total_example_results: number of rows read from the database.
            example_results: list with one dict per row (keys "name",
              "snippet", "styleurl", "description", "geom", "geom_type");
              empty if the query failed.
        """
        example_results = []

        # Every comma-separated term becomes a "%term%" pattern parameter.
        params = ["%" + entry + "%" for entry in search_query.split(",")]

        accum_func = self.utils.GetAccumFunc(response_type)
        example_query = self._example_query_template.substitute(
            FUNC=accum_func)
        query_status, query_results = self.RunPGSQLQuery(example_query, params)

        total_example_results = len(query_results)

        if query_status:
            for row in query_results:
                # Columns read below: 0 geometry, 1 area, 2 perimeter,
                # 3 snippet, 4 name, 5 geometry type.
                name = row[4]
                description = ("The total area in decimal degrees of " +
                               name + " is: " + str(row[1]) +
                               "<![CDATA[<br/>]]>")
                description += ("The total perimeter in decimal degrees of " +
                                name + " is: " + str(row[2]))
                example_results.append({
                    "name": name,
                    "snippet": row[3],
                    "styleurl": "#placemark_label",
                    "description": description,
                    "geom": str(row[0]),
                    "geom_type": str(row[5]),
                })

        return total_example_results, example_results

    def ConstructKMLResponse(self, search_results, original_query):
        """Prepares KML response.

        KML response has the below format:
          <kml>
           <Folder>
           <name/>
           <StyleURL>
             ---
           </StyleURL>
           <Point>
              <coordinates/>
           </Point>
           </Folder>
          </kml>

        Args:
          search_results: Query results from the searchexample database.
          original_query: Search query as entered by the user.

        Returns:
          kml_response: KML formatted response.
        """
        lookat_info = ""
        set_first_element_lookat = True

        # folder name should include the query parameter(q) if 'displayKeys'
        # is present in the URL otherwise not.
        if self.display_keys_string:
            folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                           % (original_query, str(len(search_results))))
        else:
            folder_name = ("Grouped results:<![CDATA[<br/>]]> (%s)"
                           % (str(len(search_results))))

        fly_to_first_element = (
            str(self.fly_to_first_element).lower() == "true")

        placemarks = []
        for result in search_results:
            geom = self._geom % (
                result["name"],
                result["styleurl"],
                result["snippet"],
                result["description"],
                result["geom"])

            # Add <LookAt> for POINT geometric types only.
            # TODO: Check if <LookAt> can be added for
            # LINESTRING and POLYGON types.
            if result["geom_type"] != "POINT":
                set_first_element_lookat = False

            # The flag is cleared after the first result either way, so
            # <LookAt> is only ever produced for a leading POINT.
            if fly_to_first_element and set_first_element_lookat:
                lookat_info = self.utils.GetLookAtInfo(result["geom"])
                set_first_element_lookat = False

            placemarks.append(self._placemark_template.substitute(geom=geom))

        kml_response = self._kml_template.substitute(
            foldername=folder_name,
            style=self._style,
            lookat=lookat_info,
            placemark="".join(placemarks))

        self.logger.info("KML response successfully formatted")

        return kml_response

    def ConstructJSONPResponse(self, search_results, original_query):
        """Prepares JSONP response.

          {
           "Folder": {
             "name": "Latitude X Longitude Y",
             "Placemark": {
                "Point": {
                  "coordinates": "X,Y" } }
             }
           }

        Args:
          search_results: Query results from the searchexample table.
          original_query: Search query as entered by the user.

        Returns:
          jsonp_response: JSONP formatted response.
        """
        folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                       % (original_query, str(len(search_results))))

        # result["geom"][1:-1] drops the first and last character of the
        # geometry text before it is embedded in the placemark object.
        geoms = ",".join(
            self._json_geom % (
                result["name"],
                result["snippet"],
                result["description"],
                result["geom"][1:-1])
            for result in search_results)

        # A single placemark is emitted as an object, several as an array.
        if len(search_results) == 1:
            search_geoms = geoms
        else:
            search_geoms = "[" + geoms + "]"

        search_placemarks = self._json_placemark_template.substitute(
            geom=search_geoms)

        json_response = self._json_template.substitute(
            foldername=folder_name, json_placemark=search_placemarks)

        # Escape single quotes from json_response.
        json_response = json_response.replace("'", "\\'")

        jsonp_response = self._jsonp_call % (self.f_callback, json_response)

        self.logger.info("JSONP response successfully formatted")

        return jsonp_response

    def HandleSearchRequest(self, environ):
        """Fetches the search tokens from form and performs the example search.

        Args:
          environ: A list of environment variables as supplied by the
            WSGI interface to the example search application interface.

        Returns:
          Tuple (search_results, response_type) where search_results is a
          KML/JSONP formatted string which contains search results.
        """
        search_results = ""
        search_status = False

        # Fetch all the attributes provided by the user.
        parameters = self.utils.GetParameters(environ)
        response_type = self.utils.GetResponseType(environ)

        # Retrieve the function call back name for JSONP response.
        self.f_callback = self.utils.GetCallback(parameters)

        original_query = self.utils.GetValue(parameters, "q")

        # Fetch additional query parameters 'flyToFirstElement' and
        # 'displayKeys' from URL.
        self.fly_to_first_element = self.utils.GetValue(
            parameters, "flyToFirstElement")
        self.display_keys_string = self.utils.GetValue(
            parameters, "displayKeys")

        if not original_query:
            # Extract 'neighborhood' parameter from URL
            try:
                form = cgi.FieldStorage(fp=environ["wsgi.input"],
                                        environ=environ)
                original_query = form.getvalue("neighborhood")
            except AttributeError as e:
                self.logger.debug("Error in neighborhood query %s" % e)

        if original_query:
            (search_status, search_results) = self.DoSearch(
                original_query, response_type)
        else:
            self.logger.debug("Empty or incorrect search query received")

        if not search_status:
            folder_name = "No results were returned."
            search_results = self.utils.NoSearchResults(
                folder_name, self._style, response_type, self.f_callback)

        return (search_results, response_type)

    def DoSearch(self, original_query, response_type):
        """Performs the example search and returns the results.

        Args:
          original_query: A string containing the search query as
            entered by the user.
          response_type: Response type can be KML or JSONP, depending on the
            client.

        Returns:
          Tuple (search_status, search_results):
            search_status: Whether search could be performed.
            search_results: A KML/JSONP formatted string which contains
              search results.
        """
        search_status = False
        search_results = ""

        search_query = original_query.strip().lower()

        # Only the first two comma-separated terms are searched.
        query_terms = search_query.split(",")
        if len(query_terms) > 2:
            self.logger.warning("Extra search parameters ignored:%s"
                                % (",".join(query_terms[2:])))
            search_query = ",".join(query_terms[:2])
            original_query = ",".join(original_query.split(",")[:2])

        total_results, query_results = self.RunExampleSearch(
            search_query, response_type)

        self.logger.info("example search returned %s results" % total_results)

        if total_results > 0:
            if response_type == "KML":
                search_results = self.ConstructKMLResponse(
                    query_results, original_query)
                search_status = True
            elif response_type == "JSONP":
                search_results = self.ConstructJSONPResponse(
                    query_results, original_query)
                search_status = True
            else:
                # This condition may not occur,
                # as response_type is either KML or JSONP
                self.logger.debug("Invalid response type %s" % response_type)

        return search_status, search_results

    def __del__(self):
        """Closes the connection pool created in __init__, if any."""
        # __init__ may have failed before the pool was created.
        pool = getattr(self, "_pool", None)
        if pool is not None and not getattr(pool, "closed", False):
            pool.closeall()
dbname = cp.get('database', 'dbname') dbuser = cp.get('database', 'user') dbhost = cp.get('database', 'host') dbpass = cp.get('database', 'dbpass') dbpoolSize = cp.get('database', 'dbpoolSize') except ConfigParser.NoOptionError, e: print "TBDB.cfg: missing parameter" exit(1) # Create DB connection pool dbpool = ThreadedConnectionPool( 2, int(dbpoolSize), "dbname='%s' user='******' host='%s' password='******'" % (dbname, dbuser, dbhost, dbpass)) # Starts each channel/thread for line in file(nodesCfgFile): print line if line[:1] != '#' and len(line.strip()) > 0: moteId, local_port, dev_addr, dev_port = line.split() #settings.append((int(moteId),int(local_port),dev_addr,int(dev_port))) forwarder('', int(local_port), dev_addr, int(dev_port), int(moteId)) try: asyncore.loop() except KeyboardInterrupt, e: print e except asyncore.ExitNow, e: print e # close all DB Pool Connection print "Closing DB connection pool" dbpool.closeall()
class Database(DatabaseInterface):
    """PostgreSQL database handle configured from the ``database`` URI.

    One instance is kept per database name in ``_databases``; constructing
    ``Database(name)`` again returns that same instance. Connections come
    from a ``ThreadedConnectionPool`` created by ``connect()``.
    """

    # database name -> instance registry backing __new__/__init__
    _databases = {}
    _connpool = None
    # result of the last ``list()`` call and the time it was taken
    _list_cache = None
    _list_cache_timestamp = None
    # database name -> server version tuple
    _version_cache = {}
    flavor = Flavor(ilike=True)

    def __new__(cls, database_name='template1'):
        """Return the registered instance for ``database_name`` if any."""
        if database_name in cls._databases:
            return cls._databases[database_name]
        return DatabaseInterface.__new__(cls, database_name=database_name)

    def __init__(self, database_name='template1'):
        super(Database, self).__init__(database_name=database_name)
        self._databases.setdefault(database_name, self)

    def connect(self):
        """Create the connection pool if needed and return ``self``."""
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.database_name)
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        # Each DSN part is left empty when the URI does not provide it.
        host = "host=%s" % uri.hostname if uri.hostname else ''
        port = "port=%s" % uri.port if uri.port else ''
        name = "dbname=%s" % self.database_name
        user = "user=%s" % uri.username if uri.username else ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
                    if uri.password else '')
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        dsn = '%s %s %s %s %s' % (host, port, name, user, password)
        self._connpool = ThreadedConnectionPool(minconn, maxconn, dsn)
        return self

    def cursor(self, autocommit=False, readonly=False):
        """Return a ``Cursor`` on a connection taken from the pool.

        Args:
            autocommit: use the autocommit isolation level instead of
                REPEATABLE READ.
            readonly: issue ``SET TRANSACTION READ ONLY`` on the new cursor.
        """
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        cursor = Cursor(self._connpool, conn, self)
        if readonly:
            cursor.execute('SET TRANSACTION READ ONLY')
        return cursor

    def close(self):
        """Close every pooled connection and forget the pool."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, cursor, database_name):
        """Create ``database_name`` from template0 and reset the list cache."""
        # NOTE(review): the name is concatenated into the statement as a
        # quoted identifier, so callers must not pass untrusted input here.
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE template0 ENCODING \'unicode\'')
        cls._list_cache = None

    @classmethod
    def drop(cls, cursor, database_name):
        """Drop ``database_name`` and reset the list cache."""
        # NOTE(review): same identifier concatenation as in ``create``.
        cursor.execute('DROP DATABASE "' + database_name + '"')
        cls._list_cache = None

    def get_version(self, cursor):
        """Return the server version as a tuple of ints, cached per name."""
        if self.database_name not in self._version_cache:
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.database_name] = tuple(
                map(int, RE_VERSION.search(version).groups()))
        return self._version_cache[self.database_name]

    @staticmethod
    def _pg_connection_options():
        """Return ``(args, env)`` with the connection settings for pg_* tools."""
        args = []
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            args.append('--username=' + uri.username)
        if uri.hostname:
            args.append('--host=' + uri.hostname)
        if uri.port:
            args.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        return args, env

    @staticmethod
    def dump(database_name):
        """Return a custom-format ``pg_dump`` of ``database_name``.

        Raises:
            Exception: if ``pg_dump`` exits with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        options, env = Database._pg_connection_options()
        cmd = ['pg_dump', '--format=c', '--no-owner'] + options
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load the dump ``data`` into it.

        Returns:
            True on success.

        Raises:
            Exception: if ``pg_restore`` fails or the restored database does
                not pass ``cursor.test()``.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        cursor = database.cursor(autocommit=True)
        database.create(cursor, database_name)
        cursor.commit()
        cursor.close()
        database.close()

        options, env = Database._pg_connection_options()
        cmd = ['pg_restore', '--no-owner'] + options
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            # On Windows the dump goes through a temporary file instead of
            # stdin. ``.get`` lets the 'C:\\' fallback apply when TMP is unset.
            tmpfile = (os.environ.get('TMP') or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_command_pipe(*args2, env=env)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.cursor()
        if not cursor.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    @staticmethod
    def list(cursor):
        """Return the names of the databases that pass ``cursor.test()``.

        The result is cached for the configured session timeout.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        cursor.execute('SELECT datname FROM pg_database '
                       'WHERE datistemplate = false ORDER BY datname')
        res = []
        for db_name, in cursor.fetchall():
            try:
                database = Database(db_name).connect()
            except Exception:
                # Databases we cannot connect to are left out of the list.
                continue
            cursor2 = database.cursor()
            accepted = cursor2.test()
            if accepted:
                res.append(db_name)
            cursor2.close(close=True)
            if not accepted:
                database.close()

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    @staticmethod
    def init(cursor):
        """Run ``init.sql`` and register the base modules with dependencies."""
        from trytond.modules import get_module_info
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for line in fp.read().split(';'):
                if (len(line) > 0) and (not line.isspace()):
                    cursor.execute(line)

        for module in ('ir', 'res'):
            # Both base modules are always flagged for installation.
            state = 'to install'
            info = get_module_info(module)
            cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
            module_id = cursor.fetchone()[0]
            cursor.execute('INSERT INTO ir_module '
                           '(id, create_uid, create_date, name, state) '
                           'VALUES (%s, %s, now(), %s, %s)',
                           (module_id, 0, module, state))
            for dependency in info.get('depends', []):
                cursor.execute('INSERT INTO ir_module_dependency '
                               '(create_uid, create_date, module, name) '
                               'VALUES (%s, now(), %s, %s)',
                               (0, module_id, dependency))
class Database(DatabaseInterface):
    """PostgreSQL database handle configured from the ``CONFIG`` mapping.

    One instance is kept per database name in ``_databases``; constructing
    ``Database(name)`` again returns that same instance. Connections come
    from a ``ThreadedConnectionPool`` created by ``connect()``.
    """

    # database name -> instance registry backing __new__/__init__
    _databases = {}
    _connpool = None
    # result of the last ``list()`` call and the time it was taken
    _list_cache = None
    _list_cache_timestamp = None
    # database name -> server version tuple
    _version_cache = {}

    def __new__(cls, database_name='template1'):
        """Return the registered instance for ``database_name`` if any."""
        if database_name in cls._databases:
            return cls._databases[database_name]
        return DatabaseInterface.__new__(cls, database_name=database_name)

    def __init__(self, database_name='template1'):
        super(Database, self).__init__(database_name=database_name)
        self._databases.setdefault(database_name, self)

    def connect(self):
        """Create the connection pool if needed and return ``self``."""
        if self._connpool is not None:
            return self
        logger = logging.getLogger('database')
        logger.info('connect to "%s"' % self.database_name)
        # Each DSN part is left empty when CONFIG does not provide it.
        host = "host=%s" % CONFIG['db_host'] if CONFIG['db_host'] else ''
        port = "port=%s" % CONFIG['db_port'] if CONFIG['db_port'] else ''
        name = "dbname=%s" % self.database_name
        user = "user=%s" % CONFIG['db_user'] if CONFIG['db_user'] else ''
        password = ("password=%s" % CONFIG['db_password']
                    if CONFIG['db_password'] else '')
        # A configured value of 0 falls back to the default pool bounds.
        minconn = int(CONFIG['db_minconn']) or 1
        maxconn = int(CONFIG['db_maxconn']) or 64
        dsn = '%s %s %s %s %s' % (host, port, name, user, password)
        self._connpool = ThreadedConnectionPool(minconn, maxconn, dsn)
        return self

    def cursor(self, autocommit=False, readonly=False):
        """Return a ``Cursor`` on a connection taken from the pool.

        Args:
            autocommit: use the autocommit isolation level instead of
                REPEATABLE READ.
            readonly: issue ``SET TRANSACTION READ ONLY`` on the new cursor.
        """
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        cursor = Cursor(self._connpool, conn, self)
        # TODO change for set_session
        if readonly:
            cursor.execute('SET TRANSACTION READ ONLY')
        return cursor

    def close(self):
        """Close every pooled connection and forget the pool."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    def create(self, cursor, database_name):
        """Create ``database_name`` from template0 and reset the list cache."""
        # NOTE(review): the name is concatenated into the statement as a
        # quoted identifier, so callers must not pass untrusted input here.
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE template0 ENCODING \'unicode\'')
        Database._list_cache = None

    def drop(self, cursor, database_name):
        """Drop ``database_name`` and reset the list cache."""
        # NOTE(review): same identifier concatenation as in ``create``.
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, cursor):
        """Return the server version as a tuple of ints, cached per name."""
        if self.database_name not in self._version_cache:
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.database_name] = tuple(
                map(int, RE_VERSION.search(version).groups()))
        return self._version_cache[self.database_name]

    @staticmethod
    def _pg_connection_options():
        """Return the connection flags shared by pg_dump and pg_restore."""
        options = []
        if CONFIG['db_user']:
            options.append('--username=' + CONFIG['db_user'])
        if CONFIG['db_host']:
            options.append('--host=' + CONFIG['db_host'])
        if CONFIG['db_port']:
            # str() keeps the concatenation valid for a numeric port too.
            options.append('--port=' + str(CONFIG['db_port']))
        return options

    @staticmethod
    def dump(database_name):
        """Return a custom-format ``pg_dump`` of ``database_name``.

        Raises:
            Exception: if ``pg_dump`` exits with a non-zero status.
        """
        from trytond.tools import exec_pg_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        cmd.extend(Database._pg_connection_options())
        cmd.append(database_name)

        pipe = exec_pg_command_pipe(*tuple(cmd))
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load the dump ``data`` into it.

        Returns:
            True on success.

        Raises:
            Exception: if ``pg_restore`` fails or the restored database does
                not pass ``cursor.test()``.
        """
        from trytond.tools import exec_pg_command_pipe

        database = Database().connect()
        cursor = database.cursor(autocommit=True)
        database.create(cursor, database_name)
        cursor.commit()
        cursor.close()

        cmd = ['pg_restore', '--no-owner']
        cmd.extend(Database._pg_connection_options())
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            # On Windows the dump goes through a temporary file instead of
            # stdin. ``.get`` lets the 'C:\\' fallback apply when TMP is unset.
            tmpfile = (os.environ.get('TMP') or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_pg_command_pipe(*args2)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.cursor()
        if not cursor.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    @staticmethod
    def list(cursor):
        """Return the names of the databases that pass ``cursor.test()``.

        When a database user can be determined (from CONFIG, from the
        current posix user, or from the owner of ``CONFIG["db_name"]``) only
        the databases owned by that user are considered. The result is
        cached for ``CONFIG['session_timeout']`` seconds.
        """
        now = time.time()
        timeout = int(CONFIG['session_timeout'])
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        db_user = CONFIG['db_user']
        if not db_user and os.name == 'posix':
            db_user = pwd.getpwuid(os.getuid())[0]
        if not db_user:
            # Fall back on the owner of the configured database.
            cursor.execute("SELECT usename "
                           "FROM pg_user "
                           "WHERE usesysid = ("
                           "SELECT datdba "
                           "FROM pg_database "
                           "WHERE datname = %s)", (CONFIG["db_name"],))
            res = cursor.fetchone()
            db_user = res and res[0]
        if db_user:
            cursor.execute("SELECT datname "
                           "FROM pg_database "
                           "WHERE datdba = ("
                           "SELECT usesysid "
                           "FROM pg_user "
                           "WHERE usename=%s) "
                           "AND datname not in "
                           "('template0', 'template1', 'postgres') "
                           "ORDER BY datname", (db_user,))
        else:
            cursor.execute("SELECT datname "
                           "FROM pg_database "
                           "WHERE datname not in "
                           "('template0', 'template1','postgres') "
                           "ORDER BY datname")

        res = []
        for db_name, in cursor.fetchall():
            db_name = db_name.encode('utf-8')
            try:
                database = Database(db_name).connect()
            except Exception:
                # Databases we cannot connect to are left out of the list.
                continue
            cursor2 = database.cursor()
            accepted = cursor2.test()
            if accepted:
                res.append(db_name)
            cursor2.close(close=True)
            if not accepted:
                database.close()

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    @staticmethod
    def init(cursor):
        """Run ``init.sql`` and register the base modules with dependencies."""
        from trytond.tools import safe_eval
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for line in fp.read().split(';'):
                if (len(line) > 0) and (not line.isspace()):
                    cursor.execute(line)

        root_path = os.path.join(os.path.dirname(__file__), '..', '..')
        for i in ('ir', 'res', 'webdav'):
            # Module metadata comes from the module's __tryton__.py file.
            tryton_file = os.path.join(root_path, i, '__tryton__.py')
            with open(tryton_file) as fp:
                info = safe_eval(fp.read())
            active = info.get('active', False)
            state = 'to install' if active else 'uninstalled'
            cursor.execute('SELECT NEXTVAL(\'ir_module_module_id_seq\')')
            module_id = cursor.fetchone()[0]
            cursor.execute('INSERT INTO ir_module_module '
                           '(id, create_uid, create_date, author, website, name, '
                           'shortdesc, description, state) '
                           'VALUES (%s, %s, now(), %s, %s, %s, %s, %s, %s)',
                           (module_id, 0, info.get('author', ''),
                            info.get('website', ''), i,
                            info.get('name', False),
                            info.get('description', ''), state))
            dependencies = info.get('depends', [])
            for dependency in dependencies:
                cursor.execute('INSERT INTO ir_module_module_dependency '
                               '(create_uid, create_date, module, name) '
                               'VALUES (%s, now(), %s, %s)',
                               (0, module_id, dependency))
class PgPool:
    """Thin helper around a psycopg2 ``ThreadedConnectionPool``.

    Every method catches its own errors, prints them and returns a failure
    value (``False`` or ``(None, None)``) instead of raising.
    """

    def __init__(self):
        self.host, self.port = None, None
        self.database, self.pool = None, None
        self.user, self.passwd = None, None

    def create_pool(self, conn_dict):
        """
        Create a connection pool

        :param conn_dict: connection params dictionary with the keys
            "Host", "Port", "Database", "User" and "Password"; a ``None``
            host or port falls back to 'localhost' / '5432'
        :type conn_dict: dict
        """
        self.host = 'localhost' if conn_dict["Host"] is None else conn_dict["Host"]
        self.port = '5432' if conn_dict["Port"] is None else conn_dict["Port"]
        self.database = conn_dict["Database"]
        self.user = conn_dict["User"]
        self.passwd = conn_dict["Password"]

        # The user and password placeholders must be real format fields,
        # otherwise the credentials never reach the connection string.
        conn_params = ("host='{host}' dbname='{db}' user='{user}' "
                       "password='{passwd}' port='{port}'").format(
            host=self.host,
            db=self.database,
            user=self.user,
            passwd=self.passwd,
            port=self.port)

        try:
            self.pool = ThreadedConnectionPool(1, 50, conn_params)
        except Exception as e:
            # Exceptions have no ``message`` attribute on Python 3.
            print(e)

    def get_conn(self):
        """
        Get a connection from pool and return connection and cursor

        :return: conn, cursor (``None, None`` on failure)
        """
        try:
            conn = self.pool.getconn()
            cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            return conn, cursor
        except Exception as e:
            print(e)
            return None, None

    @staticmethod
    def execute_query(cursor, query, params):
        """
        Execute a query on database

        :param cursor: cursor object
        :param query: database query
        :type query: str
        :param params: query parameters
        :type params: tuple
        :return: all rows for a query starting with ``select``, the result
            of ``cursor.execute`` for any other query, ``False`` on error
        """
        try:
            if query.split()[0].lower() == 'select':
                cursor.execute(query, params)
                return cursor.fetchall()
            return cursor.execute(query, params)
        except Exception as e:
            print(e)
            return False

    @staticmethod
    def commit_changes(conn):
        """
        Commit changes to the database permanently

        :param conn: connection object
        :return: result of ``conn.commit()``, ``False`` on error
        """
        try:
            return conn.commit()
        except Exception as e:
            print(e)
            return False

    def put_conn(self, conn):
        """
        Put connection back to the pool

        :param conn: connection object
        :return: result of ``pool.putconn``, ``False`` on error
        """
        try:
            return self.pool.putconn(conn)
        except Exception as e:
            print(e)
            return False

    def close_pool(self):
        """
        Closes connection pool

        :return: result of ``pool.closeall``, ``False`` on error
        """
        try:
            return self.pool.closeall()
        except Exception as e:
            print(e)
            return False
def handle(self, *args, **options):
    """Run one mturk crawl and record it as a ``Crawl`` object.

    Reads credentials from ``settings`` (overridable through ``options``),
    creates a ``Crawl`` row, processes hit groups in packs with one gevent
    job per group, then updates the row and logs a summary plus warnings
    when the crawl was slow or incomplete.

    Recognised ``options`` keys: ``mturk_email``, ``mturk_password``,
    ``logconf``, ``debug`` and ``workers`` (required).

    Exits the process via ``sys.exit`` when more than 9 workers are
    requested.
    """
    self.mturk_email = getattr(settings, 'MTURK_AUTH_EMAIL', None)
    self.mturk_password = getattr(settings, 'MTURK_AUTH_PASSWORD', None)

    _start_time = time.time()
    pid = Pid('mturk_crawler', True)
    log.info('crawler started: %s;;%s', args, options)

    if options.get('mturk_email'):
        self.mturk_email = options['mturk_email']
    if options.get('mturk_password'):
        self.mturk_password = options['mturk_password']

    if options.get('logconf', None):
        self.setup_logging(options['logconf'])

    if options.get('debug', False):
        self.setup_debug()
        # Parenthesised single-argument form prints the same text under
        # both the Python 2 print statement and the print function.
        print('Current proccess pid: %s' % pid.actual_pid)
        print(('To debug, type: python -c "import os,signal; '
               'os.kill(%s, signal.SIGUSR1)"\n') % pid.actual_pid)

    self.maxworkers = options['workers']
    if self.maxworkers > 9:
        # If you want to remote this limit, don't forget to change dbpool
        # object maximum number of connections. Each worker should fetch
        # 10 hitgroups and spawn single task for every one of them, that
        # will get private connection instance. So for 9 workers it's
        # already 9x10 = 90 connections required
        #
        # Also, for too many workers, amazon isn't returning valid data
        # and retrying takes much longer than using smaller amount of
        # workers
        # NOTE(review): the Pid object created above is not removed on this
        # path - confirm whether that is intended.
        sys.exit('Too many workers (more than 9). Quit.')

    start_time = datetime.datetime.now()

    hits_available = tasks.hits_mainpage_total()
    groups_available = tasks.hits_groups_total()

    # create crawl object that will be filled with data later
    crawl = Crawl.objects.create(
        start_time=start_time,
        end_time=start_time,
        success=True,
        hits_available=hits_available,
        hits_downloaded=0,
        groups_available=groups_available,
        groups_downloaded=groups_available)
    log.debug('fresh crawl object created: %s', crawl.id)

    # fetch those requester profiles so we could decide if their hitgroups
    # are public or not
    requesters = RequesterProfile.objects.all_as_dict()

    dbpool = ThreadedConnectionPool(10, 90,
        'dbname=%s user=%s password=%s' % (
            settings.DATABASES['default']['NAME'],
            settings.DATABASES['default']['USER'],
            settings.DATABASES['default']['PASSWORD']))

    # collection of group_ids that were already processed - this should
    # protect us from duplicating data
    processed_groups = set()
    total_reward = 0
    try:
        for hg_pack in self.hits_iter():
            jobs = []
            for hg in hg_pack:
                if hg['group_id'] in processed_groups:
                    log.debug('Group already in processed_groups, skipping.')
                    continue
                processed_groups.add(hg['group_id'])

                jobs.append(gevent.spawn(
                    tasks.process_group, hg, crawl.id, requesters,
                    processed_groups, dbpool))
                total_reward += hg['reward'] * hg['hits_available']

            log.debug('processing pack of hitgroups objects')
            gevent.joinall(
                jobs, timeout=settings.CRAWLER_GROUP_PROCESSING_TIMEOUT)
            # check if all jobs ended successfully
            for job in jobs:
                if not job.ready():
                    log.error('Killing job: %s', job)
                    job.kill()

            if len(processed_groups) >= groups_available:
                log.info('Skipping empty groups.')
                # there's no need to iterate over empty groups.. break
                break

            # amazon does not like too many requests at once, so give them a
            # quick rest...
            gevent.sleep(1)
    finally:
        # Release the database connections even when the crawl loop raises.
        dbpool.closeall()

    # update crawler object
    crawl.groups_downloaded = len(processed_groups)
    crawl.end_time = datetime.datetime.now()
    crawl.save()

    work_time = time.time() - _start_time
    log.info("""Crawl finished:
        created crawl id: {crawl_id})
        total reward value: {total_reward}
        hits groups downloaded: {processed_groups}
        hits groups available: {groups_available}
        work time: {work_time:.2f} seconds
        """.format(crawl_id=crawl.id, total_reward=total_reward,
            processed_groups=len(processed_groups),
            groups_available=groups_available,
            work_time=work_time))

    crawl_downloaded_pc = settings.INCOMPLETE_CRAWL_THRESHOLD
    crawl_warning_pc = settings.INCOMPLETE_CRAWL_WARNING_THRESHOLD
    crawl_time_warning = settings.CRAWLER_TIME_WARNING
    # With zero groups reported as available the ratio is undefined; treat
    # it as 0 so the "incomplete crawl" warnings below fire instead of the
    # command dying with ZeroDivisionError after the crawl was saved.
    if groups_available:
        downloaded_pc = float(crawl.groups_downloaded) / groups_available
    else:
        downloaded_pc = 0.0
    if work_time > crawl_time_warning:
        log.warning(("Crawl took {0}s which seems a bit too long (more "
            "than {1}s), you might consider checking if correct mturk "
            "account is used, ignore this if high number of groups is "
            "experienced.").format(work_time, crawl_time_warning))
    if downloaded_pc < crawl_warning_pc:
        log.warning(('Only {0}% of hit groups were downloaded, below '
            '({1}% warning threshold) please check mturk account '
            'configuration and/or if there are any network-related '
            'problems.').format(downloaded_pc, crawl_warning_pc))
    if downloaded_pc < crawl_downloaded_pc:
        log.warning("This crawl contains far too few groups downloaded to "
            "available: {0}% < {1}% downloaded threshold and will be "
            "considered as erroneous ({2}/{3} groups).".format(
                downloaded_pc, crawl_downloaded_pc,
                crawl.groups_downloaded, groups_available))

    pid.remove_pid()
dbname = cp.get('database','dbname') dbuser = cp.get('database','user') dbhost = cp.get('database','host') dbpass = cp.get('database','dbpass') dbpoolSize = cp.get('database','dbpoolSize') except ConfigParser.NoOptionError, e: print "TBDB.cfg: missing parameter" exit(1) # Create DB connection pool dbpool = ThreadedConnectionPool(2, int(dbpoolSize), "dbname='%s' user='******' host='%s' password='******'"%(dbname,dbuser,dbhost,dbpass)) # Starts each channel/thread for line in file(nodesCfgFile): print line if line[:1] != '#' and len(line.strip()) > 0: moteId,local_port,dev_addr,dev_port = line.split() #settings.append((int(moteId),int(local_port),dev_addr,int(dev_port))) forwarder('',int(local_port),dev_addr,int(dev_port), int(moteId)) try: asyncore.loop() except KeyboardInterrupt, e: print e except asyncore.ExitNow, e: print e # close all DB Pool Connection print "Closing DB connection pool" dbpool.closeall()
class POISearch(object):
  """Class for performing the POI search.

  POI search queries the "gesearch" database for the sql queries which were
  created when the user configures the POI search tab and push/publishes
  from Fusion. These sql queries are executed on the "gepoi" database to
  retrieve the search results.

  Input parameters depend on what search fields have been chosen while
  configuring the POI search using the Search Tab Manager in Fusion.

  For example, if latitude and longitude have been chosen as part of
  POI search fields, then

  Valid Inputs are :
  q=<latitude_value or longitude_value>&DbId=<db_id> or
  searchTerm=<latitude_value or longitude_value>&DbId=<db_id>
  """

  def __init__(self):
    """Inits POISearch.

    Initializes the logger "ge_search".
    Initializes templates for kml, placemark templates for the KML output.
    Initializes parameters for establishing a connection to the database.
    """
    self.utils = utils.SearchUtils()
    constants = geconstants.Constants()

    configs = self.utils.GetConfigs(
        os.path.join(geconstants.SEARCH_CONFIGS_DIR, "PoiSearch.conf"))

    style_template = self.utils.style_template
    self._jsonp_call = self.utils.jsonp_functioncall
    self._geom = """
            <name>%s</name>,
            <styleUrl>%s</styleUrl>
            <description>%s</description>,
            %s\
    """
    self._json_geom = """
         {
            "name" : "%s",
            "description" : "%s",
            %s
          }
    """

    self._placemark_template = self.utils.placemark_template
    self._kml_template = self.utils.kml_template
    self._json_template = self.utils.json_template
    self._json_placemark_template = self.utils.json_placemark_template

    self._host_db_name_by_target_query = constants.host_db_name_by_target_query
    self._poi_info_by_host_db_name_query = (
        constants.poi_info_by_host_db_name_query)

    self.logger = logging.getLogger("ge_search")

    poisearch_database = configs.get("searchdatabasename")
    if not poisearch_database:
      poisearch_database = constants.defaults.get("poisearch.database")

    gestream_database = configs.get("streamdatabasename")
    if not gestream_database:
      gestream_database = constants.defaults.get("gestream.database")

    poiquery_database = configs.get("poidatabasename")
    if not poiquery_database:
      poiquery_database = constants.defaults.get("poiquery.database")

    self._search_pool = ThreadedConnectionPool(
        int(configs.get("minimumconnectionpoolsize")),
        int(configs.get("maximumconnectionpoolsize")),
        database=poisearch_database,
        user=configs.get("user"),
        host=configs.get("host"),
        port=int(configs.get("port")))

    self._poi_pool = ThreadedConnectionPool(
        int(configs.get("minimumconnectionpoolsize")),
        int(configs.get("maximumconnectionpoolsize")),
        database=poiquery_database,
        user=configs.get("user"),
        host=configs.get("host"),
        port=int(configs.get("port")))

    self._stream_pool = ThreadedConnectionPool(
        int(configs.get("minimumconnectionpoolsize")),
        int(configs.get("maximumconnectionpoolsize")),
        database=gestream_database,
        user=configs.get("user"),
        host=configs.get("host"),
        port=int(configs.get("port")))

    self._style = style_template.substitute(
        balloonBgColor=configs.get("balloonstyle.bgcolor"),
        balloonTextColor=configs.get("balloonstyle.textcolor"),
        balloonText=configs.get("balloonstyle.text"),
        iconStyleScale=configs.get("iconstyle.scale"),
        iconStyleHref=configs.get("iconstyle.href"),
        lineStyleColor=configs.get("linestyle.color"),
        lineStyleWidth=configs.get("linestyle.width"),
        polyStyleColor=configs.get("polystyle.color"),
        polyStyleColorMode=configs.get("polystyle.colormode"),
        polyStyleFill=configs.get("polystyle.fill"),
        polyStyleOutline=configs.get("polystyle.outline"),
        listStyleHref=configs.get("iconstyle.href"))

    # Parameters for calculating the bounding box.
    self.latitude_center = constants.latitude_center
    self.longitude_center = constants.longitude_center
    self.latitude_span = constants.latitude_span
    self.longitude_span = constants.longitude_span
    self.srid = constants.srid
    # A missing config entry is treated as "false" instead of raising
    # AttributeError on None.
    self.usebbox = (configs.get("usebbox") or "").lower() == "true"
    self.expandbbox = (configs.get("expandbbox") or "").lower() == "true"

    # Calculate default bounding box with latitude span = 180(degrees) and
    # longitude span = 360(degrees).
    self.world_bounds = self.__CreateBboxFromParameters(
        self.latitude_center, self.longitude_center,
        self.latitude_span, self.longitude_span)

    # Create federated search handler object for
    # performing super federated search.
    try:
      self._fed_search = federated_search_handler.FederatedSearch()
    except Exception as e:
      self.logger.warning("Federated search is not available due to an "
                          "unexpected error, %s.", e)
      self._fed_search = None

  def __RunPGSQLQuery(self, pool, query, params):
    """Submits the query to the database and returns tuples.

    Note: variables placeholder must always be %s in query.
    Warning: NEVER use Python string concatenation (+) or string parameters
    interpolation (%) to pass variables to a SQL query string.
    e.g.
      SELECT vs_url FROM vs_table WHERE vs_name = 'default_ge';
      query = "SELECT vs_url FROM vs_table WHERE vs_name = %s"
      parameters = ["default_ge"]

    Args:
      pool: Pool of database connection's on which the query should be
        executed.
      query: SQL SELECT statement.
      params: sequence of parameters to populate into placeholders.
    Returns:
      query_status: True if at least one row was returned.
      query_results: Results as a list; single-column rows are unwrapped to
        the bare value, other rows are kept as tuples.
    Raises:
      exceptions.PoolConnectionException: if no connection could be obtained
        from the pool. Other psycopg2 errors are logged and result in
        (False, partial results).
    """
    # Bound before the try so the finally clause can test it even when
    # getconn() raises; otherwise an UnboundLocalError would mask the
    # PoolConnectionException raised below.
    con = None
    cursor = None
    query_results = []
    query_status = False

    try:
      con = pool.getconn()
      if con:
        cursor = con.cursor()

        # The DSN is parsed for debug logging only, so tolerate entries
        # without "=" and missing keys rather than failing the query.
        dsn_info = {}
        for entry in con.dsn.split(" "):
          key, sep, value = entry.partition("=")
          if sep:
            dsn_info[key] = value

        self.logger.debug("Querying the database %s, at port %s, as user %s on "
                          "hostname %s.", dsn_info.get("dbname"),
                          dsn_info.get("port"), dsn_info.get("user"),
                          dsn_info.get("host"))
        self.logger.debug("Query: %s", query)
        self.logger.debug("Params: %s", params)
        cursor.execute(query, params)

        for row in cursor:
          if len(row) == 1:
            query_results.append(row[0])
          else:
            query_results.append(row)
          query_status = True
    except psycopg2.pool.PoolError as e:
      self.logger.error("Exception while querying the database, %s", e)
      raise exceptions.PoolConnectionException(
          "Pool Error - Unable to get a connection from the pool.")
    except psycopg2.Error as e:
      self.logger.error("Exception while querying the database, %s", e)
    finally:
      if con:
        pool.putconn(con)

    return query_status, query_results

  def __QueryPOIInfo(self, target_path):
    """Gets POI info data for specified target.

    Queries gestream database to get database info (host_name, db_name) by
    target path, then queries gesearch database to get poi info data from
    poi_table by database info.

    Args:
      target_path: Published target path.
    Returns:
      query_status: True if POI info data has been fetched.
      poi_info_data: poi info data as a list of tuples
        (query_str, num_fields) or empty list.
    Raises:
      exceptions.PoolConnectionException in case of error while getting a
      connection from the pool.
    """
    query_status = False
    poi_info_data = []

    # Get host_name and db_name from gestream database for a target path.
    query_status, query_data = self.__RunPGSQLQuery(
        self._stream_pool, self._host_db_name_by_target_query, (target_path,))

    if not query_status:
      self.logger.debug("Target path %s not published yet", target_path)
    else:
      # query_data is of type [(host_name, db_name)].
      # Sample result is [('my_host',
      # '/gevol/assets/Databases/my_db.kdatabase/gedb.kda/ver002/gedb')]
      # Fetch (query_str, num_fields) for POISearch from gesearch database
      # with host_name and db_name as inputs.
      query_status, poi_info_data = self.__RunPGSQLQuery(
          self._search_pool, self._poi_info_by_host_db_name_query,
          query_data[0])

    return query_status, poi_info_data

  def __ConstructResponse(self, poi_results, response_type, original_query):
    """Prepares response based on response type.

    Args:
      poi_results: Query results from the gepoi database.
      response_type: KML or JSONP.
      original_query: Search query as entered by the user.
    Returns:
      format_status: If response has been properly formatted(True/False).
      format_response: Response in KML or JSON.
    """
    format_status = False
    format_response = ""

    # Remove quotes so they don't interfere with kml or json.
    safe_query = original_query.replace('"', "").replace("'", "")

    if response_type == "KML":
      format_response = self.__ConstructKMLResponse(poi_results, safe_query)
      format_status = True
    elif response_type == "JSONP":
      format_response = self.__ConstructJSONPResponse(poi_results, safe_query)
      format_status = True
    else:
      # This condition may not occur,
      # as response_type is either KML or JSONP.
      self.logger.error("Invalid response type %s", response_type)

    return format_status, format_response

  def __ConstructKMLResponse(self, search_results, original_query):
    """Prepares KML response.

    KML response has the below format:
      <kml>
       <Folder>
       <name/>
       <StyleURL>
             ---
       </StyleURL>
       <Point>
              <coordinates/>
       </Point>
       </Folder>
      </kml>

    Args:
      search_results: Query results from the gepoi database.
      original_query: Search query as entered by the user.
    Returns:
      kml_response: KML formatted response.
    """
    search_placemarks = ""
    kml_response = ""
    lookat_info = ""
    set_first_element_lookat = True

    # folder name should include the query parameter(q) if 'displayKeys'
    # is present in the URL otherwise not.
    if self.display_keys_string:
      folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                     % (original_query, str(len(search_results))))
    else:
      folder_name = ("Grouped results:<![CDATA[<br/>]]> (%s)"
                     % (str(len(search_results))))

    fly_to_first_element = str(self.fly_to_first_element).lower() == "true"

    # Loop-invariant values, built once.
    styleurl = "#placemark_label"
    description_template = Template("${NAME} = ${VALUE}\n")

    for result in search_results:
      name, description, geom_data = self.__GetPOIAttributes(
          original_query, result, description_template)
      geom = self._geom % (cgi.escape(name), styleurl,
                           cgi.escape(description), geom_data)

      # Only the first placemark provides the LookAt.
      if fly_to_first_element and set_first_element_lookat:
        lookat_info = self.utils.GetLookAtInfo(geom_data)
        set_first_element_lookat = False

      placemark = self._placemark_template.substitute(geom=geom)
      search_placemarks += placemark

    kml_response = self._kml_template.substitute(
        foldername=folder_name,
        style=self._style,
        lookat=lookat_info,
        placemark=search_placemarks)

    self.logger.debug("KML response successfully formatted")

    return kml_response

  def __ConstructJSONPResponse(self, search_results, original_query):
    """Prepares JSONP response.

      {
               "Folder": {
                 "name": "Latitude X Longitude Y",
                 "Placemark": {
                    "Point": {
                      "coordinates": "X,Y" } }
                 }
       }

    Args:
      search_results: Query results from the gepoi table.
      original_query: Search query as entered by the user.
    Returns:
      jsonp_response: JSONP formatted response.
    """
    search_placemarks = ""
    search_geoms = ""
    geoms = ""
    json_response = ""
    jsonp_response = ""

    folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                   % (original_query, str(len(search_results))))

    # If end user submits multiple queries in one POI Search, like x,y
    # then take x only for POI Search.
    search_tokens = self.utils.SearchTokensFromString(original_query)

    description_template = Template("${NAME}: ${VALUE}<br>")

    for count, result in enumerate(search_results):
      name, description, geom_data = self.__GetPOIAttributes(
          search_tokens[0], result, description_template)

      # Remove the outer braces for the json geometry.
      geom_data = geom_data.strip().lstrip("{").rstrip("}")
      geom = self._json_geom % (name, description, geom_data)

      geoms += geom
      if count < (len(search_results) - 1):
        geoms += ","

    if len(search_results) == 1:
      search_geoms = geoms
    else:
      search_geoms = "[" + geoms + "]"

    search_placemarks = self._json_placemark_template.substitute(
        geom=search_geoms)

    json_response = self._json_template.substitute(
        foldername=folder_name, json_placemark=search_placemarks)

    # Escape single quotes from json_response.
    json_response = json_response.replace("'", "\\'")

    jsonp_response = self._jsonp_call % (self.f_callback, json_response)
    self.logger.debug("JSONP response successfully formatted")

    return jsonp_response

  def __GetPOIAttributes(self, original_query, search_result, desc_template):
    """Fetch POI attributes like name, description and geometry.

    Args:
      original_query: Search query as entered by the user.
      search_result: Query results from the gepoi table.
      desc_template: Description template based on the response type.
    Returns:
      Name, description and geometry. Geometry is an empty string when
      search_result carries no "geom" entry.
    """
    name = ""
    description = ""
    # Bound up front so a result without a "geom" entry yields an empty
    # geometry instead of UnboundLocalError at the return statement.
    geom_data = ""

    # A sample map_entry would be as below.
    # {
    #  'field_value': 'US Route 395',
    #  'field_name': 'name',
    #  'is_search_display': True,
    #  'is_searchable': True,
    #  'is_displayable': True
    # }

    for map_entry in search_result:
      if map_entry["field_name"] != "geom":
        if map_entry["is_search_display"]:
          # Find the matched string by checking if the search query
          # is a sub-string of the POI field values.
          # This condition is for POI fields which are both
          # searchable and displayable.
          if original_query.lower() in map_entry["field_value"].lower():
            name = map_entry["field_value"]
        else:
          # This condition is for POI fields which are
          # searchable but not displayable.
          # Assign "name" attribute to the first POI field value.
          if not name:
            name = search_result[1]["field_value"]
        # Assign all the displayable POI field values
        # to "description" attribute.
        if map_entry["is_displayable"]:
          description += desc_template.substitute(
              NAME=map_entry["field_name"], VALUE=map_entry["field_value"])
      else:
        geom_data = map_entry["field_value"]

    return name, description, geom_data

  def HandleSearchRequest(self, environ):
    """Fetches the search tokens from form and performs the POI search.

    Args:
     environ: A list of environment variables as supplied by the
      WSGI interface to the POI search application interface.
    Returns:
     search_results: A KML/JSONP formatted string which contains search
       results.
     response_type: KML or JSON, depending on the end client.
    """
    search_results = ""
    search_status = False
    is_super_federated = False

    # Fetch all the attributes provided by the user.
    parameters = self.utils.GetParameters(environ)
    response_type = self.utils.GetResponseType(environ)

    # Retrieve the function call back name for JSONP response.
    self.f_callback = self.utils.GetCallback(parameters)

    original_query = self.utils.GetValue(parameters, "q")

    # Fetch additional query parameters 'flyToFirstElement' and
    # 'displayKeys' from URL.
    self.fly_to_first_element = self.utils.GetValue(
        parameters, "flyToFirstElement")
    self.display_keys_string = self.utils.GetValue(
        parameters, "displayKeys")

    # Process bounding box flag and parameters.
    # Get useBBox and expandBBox from query.
    # useBBox and expandBBox can be set in the additional config parameters
    # section while creating the search tabs in server admin UI.
    # Any value specified here should override the value
    # provided in the PoiSearch.conf config file.
    # NOTE(review): the override is stored on self, so it also applies to
    # later requests that do not pass useBBox/expandBBox - confirm this is
    # intended.
    usebbox = self.utils.GetValue(parameters, "useBBox")
    if usebbox:
      self.usebbox = usebbox.lower() == "true"

    # Get expandBBox only when useBBox is True.
    if self.usebbox:
      expandbbox = self.utils.GetValue(parameters, "expandBBox")
      if expandbbox:
        self.expandbbox = expandbbox.lower() == "true"

    # Default bbox is world_bounds.
    bbox = self.world_bounds
    lat_lng = None
    spans = None
    # True only when both "ll" and "spn" were supplied and parsed; the
    # bbox expansion below needs numeric values to work with.
    has_bbox_params = False
    if self.usebbox:
      # Get ll and spn params from the query.
      lat_lng = self.utils.GetValue(parameters, "ll")
      spans = self.utils.GetValue(parameters, "spn")
      if lat_lng and spans:
        lat_lng = [float(coord) for coord in lat_lng.split(",")]
        spans = [float(span) for span in spans.split(",")]
        has_bbox_params = True
        # Override default bounding box setting.
        bbox = self.__CreateBboxFromParameters(
            lat_lng[0], lat_lng[1], spans[0], spans[1])

    if not original_query:
      # If "q" not available, extract 'searchTerm' parameter value from URL.
      original_query = self.utils.GetValue(parameters, "searchTerm")

    # Extract target path from 'SCRIPT_URL'.
    # SAMPLE SCRIPT_URL IS '/sf2d/POISearch'.
    # When the URL does not match there is no target to look up; the
    # request then falls through to the "incorrect search query" branch
    # instead of failing with a NameError.
    poi_info_status = False
    poi_info_data = []
    parse_res = urlparse.urlparse(environ["SCRIPT_URL"])
    match_tp = re.match(r"(.*)/POISearch", parse_res.path)
    if match_tp:
      target_path = match_tp.groups()[0]
      poi_info_status, poi_info_data = self.__QueryPOIInfo(target_path)

    if original_query and poi_info_status:
      end_search = False
      while True:
        (search_status, search_results) = self.DoSearch(
            original_query, poi_info_data, response_type, bbox)

        # If there are no search results and expandbbox is true,
        # expand the search region by providing a new bounding box
        # whose latitude and longitude values are double that of the
        # original. Continue this process until atleast a single search
        # result is found.
        # Without ll/spn the search already ran on world bounds, so there
        # is nothing left to expand.
        if (not search_status and self.usebbox and self.expandbbox
            and has_bbox_params and not end_search):
          # provide new bbox with latitude and longitude span values multiplied
          # by a factor of 2.
          spans = [span*2 for span in spans]
          if self.__IsBBoxValid(lat_lng[0], lat_lng[1], spans[0], spans[1]):
            bbox = self.__CreateBboxFromParameters(
                lat_lng[0], lat_lng[1], spans[0], spans[1])
          else:
            # Try search with world bounds.
            bbox = self.world_bounds
            end_search = True
        else:
          break
    else:
      self.logger.error("Empty or incorrect search query received.")

    if search_status:
      (search_status, search_results) = self.__ConstructResponse(
          search_results, response_type, original_query)
    else:
      # If no results from POI search, then perform Federated search.
      if original_query and self._fed_search:
        # Get "PoiFederated" parameter value.
        # "PoiFederated" can take values 0 or 1.
        super_federated_value = self.utils.GetValue(
            parameters, "PoiFederated")
        if super_federated_value:
          # For any value not greater than 0,is_super_federated is 'False'
          is_super_federated = int(super_federated_value) > 0

        # Perform Federated search if the POI search results in
        # zero(0) results and 'is_super_federated' value is 'True'.
        if is_super_federated:
          self._fed_search.f_callback = self.f_callback
          self._fed_search.fly_to_first_element = self.fly_to_first_element
          self._fed_search.display_keys_string = self.display_keys_string

          (search_status, search_results) = self._fed_search.DoSearch(
              original_query, response_type)

    # if no results from POI search or Federated search.
    if not search_status:
      folder_name = "No results were returned."
      search_results = self.utils.NoSearchResults(
          folder_name, self._style, response_type, self.f_callback)

    return (search_results, response_type)

  def DoSearch(self, search_expression, poi_info_data, response_type, bbox):
    """Performs the poi search and returns the results.

    Args:
      search_expression: Search query as entered by the user.
      poi_info_data: POI info as a list of tuples (query_str, num_fields).
      response_type: Response type can be KML or JSONP, depending on the
        client.
      bbox: The bounding box string.
    Returns:
      sql_query_status: Whether data could be queried from "gepoi" database.
      poi_results: Query results as a list.
    """
    sql_query_status = False
    poi_data = []

    normalised_search_expression = search_expression.lower()

    for sql_query in poi_info_data:
      cur_status, cur_data = self.__ParseSQLResults(
          normalised_search_expression, sql_query, response_type, bbox)
      if cur_status:
        sql_query_status = cur_status
        poi_data += cur_data

    self.logger.debug("poi search returned %s results", len(poi_data))
    self.logger.debug("results: %s", poi_data)

    return sql_query_status, poi_data

  def __ParseSQLResults(self, search_expression, poi_info, response_type,
                        bbox):
    """Performs the poi search and returns the results.

    Args:
      search_expression: Normalized search expression.
      poi_info: Tuple containing sql_query, number of columns, style id
        and poi id as extracted from "poi_table" table of "gesearch"
        database.
      response_type: Either KML/JSONP, depending on the client.
      bbox: The bounding box string.
    Returns:
      poi_query_status: Whether data could be queried from "gepoi" database.
      poi_query_results: Query results as a list.
    Raises:
      exceptions.BadQueryException: if the search term cannot be parsed or
        is empty.
      exceptions.PoolConnectionException in case of error while getting a
        connection from the pool.
    """
    poi_query_status = False
    poi_query_results = []
    query_results = []
    poi_query = ""

    sql_query = poi_info[0]
    num_of_input_fields = poi_info[1]

    try:
      search_tokens = self.utils.SearchTokensFromString(search_expression)
    except Exception as e:
      raise exceptions.BadQueryException(
          "Couldn't parse search term. Error: %s" % e)

    self.logger.debug("Parsed search tokens: %s", search_tokens)
    params = ["%" + entry + "%" for entry in search_tokens]

    num_params = len(params)
    if num_params == 0:
      raise exceptions.BadQueryException("No search term.")

    # Make the number of parameters match the number of placeholders:
    # truncate extras, repeat a single term for every field, or pad a
    # multi-term query with a value that is not expected to match.
    if num_params > num_of_input_fields:
      params = params[:num_of_input_fields]
    elif num_params < num_of_input_fields:
      if num_params == 1:
        params.extend([params[0]] * (num_of_input_fields - num_params))
      else:
        params.extend(["^--IGNORE--"] * (num_of_input_fields - num_params))

    accum_func = self.utils.GetAccumFunc(response_type)

    # sql queries retrieved from "gesearch" database has "?" for
    # input arguments, but postgresql supports "%s".
    # Hence, replacing "?" with "%s".
    sql_stmt = sql_query.replace("?", "%s")

    # Extract the displayable and searchable POI fields from the POI
    # queries retrieved from "poi_table".
    matched = re.match(r"(\w*)\s*(Encode.*geom)(.*)(FROM.*)(WHERE )(\(.*\))",
                       sql_stmt)

    if matched:
      (sub_query1, unused_sub_query2, sub_query3,
       sub_query4, sub_query5, sub_query6) = matched.groups()
      # sub_query2 need to be replaced with a new query
      # as per the response_type.
      # PYLINT throws warning that sub_query2 has not been used, it's
      # not required and can be ignored.
      geom_stmt = "%s(the_geom) AS the_geom" % (accum_func)
      poi_query = "%s %s%s%s%s%s" % (sub_query1, geom_stmt, sub_query3,
                                     sub_query4, sub_query5, sub_query6)
      poi_query += " AND ( the_geom && %s )" % bbox

      # Displayable POI fields appear between SELECT and FROM
      # clause of the POI queries, as in below example.
      # SELECT ST_AsKML(the_geom) AS the_geom, "rpoly_", "fnode_"
      # FROM gepoi_7 WHERE ( lower("rpoly_") LIKE %s OR lower("name") LIKE %s ).
      # "rpoly_" and "fnode_" are display fields in the above example.
      display_fields = [field.replace("\"", "").strip()
                        for field in filter(len, sub_query3.split(","))]
      # Insert geom at the 0th index for easy retrieval.
      display_fields.insert(0, "geom")

      # Searchable POI fields appear after the WHERE
      # clause of the POI queries, as in below example.
      # SELECT ST_AsKML(the_geom) AS the_geom, "rpoly_", "fnode_"
      # FROM gepoi_7 WHERE ( lower("rpoly_") LIKE %s OR lower("name") LIKE %s ).
      # "rpoly_" and "name" are searchable fields in the above example.
      searchable_fields = [
          entry.strip().strip("OR").strip().strip("lower").strip("\\(\\)\"")
          for entry in filter(len, sub_query6.strip("\\(\\) ").split("LIKE %s"))
      ]
      # Insert geom at the 0th index for easy retrieval.
      searchable_fields.insert(0, "geom")

    if poi_query:
      poi_query_status, query_results = self.__RunPGSQLQuery(
          self._poi_pool, poi_query, params)

      # Create a map which will have the POI fields values
      # retrieved from the gepoi_X tables and
      # other information like displayable or searchable or both
      # based on the display and search labels retrieved above.
      # Some sample maps are as below.
      # 1) {
      #    'field_value': 'State Route 55',
      #    'field_name': 'name',
      #    'is_search_display': True,
      #    'is_searchable': True,
      #    'is_displayable': True
      #    }
      # 2) {'field_value': '0',
      #     'field_name': 'rpoly_',
      #     'is_search_display': True,
      #     'is_searchable': True,
      #     'is_displayable': True}.
      # 3) {'field_value': '22395',
      #     'field_name': 'fnode_',
      #     'is_search_display': True,
      #     'is_searchable': True,
      #     'is_displayable': True}.
      # These maps would be used when creating the KML and JSONP responses
      # The different flags (is_displayable,is_searchable etc) allow for
      # easier retrieval of data based on our requirements.
      for entry in query_results:
        field_name_value = []

        for field_name, field_value in zip(display_fields, entry):
          temp = {}
          temp["field_name"] = field_name
          temp["field_value"] = field_value
          temp["is_searchable"] = (field_name in searchable_fields)
          # "is_displayable" is always True as we are iterating over
          # the display(only) POI fields.
          temp["is_displayable"] = True
          temp["is_search_display"] = (
              temp["is_displayable"] and temp["is_searchable"])
          field_name_value.append(temp)

        for field_name in searchable_fields:
          temp = {}
          if field_name not in display_fields:
            temp["field_name"] = field_name
            temp["field_value"] = ""
            # "is_searchable" is always True as we are iterating over
            # the search(only) POI fields.
            temp["is_searchable"] = True
            temp["is_displayable"] = (field_name in display_fields)
            temp["is_search_display"] = (
                temp["is_displayable"] and temp["is_searchable"])
            field_name_value.append(temp)

        poi_query_results.append(field_name_value)

    return poi_query_status, poi_query_results

  def __CreateBboxFromParameters(self, latcenter, loncenter,
                                 latspan, lonspan):
    """Create a bounding box string for bounding box queries.

    Args:
      latcenter: latitude centre in degrees.
      loncenter: longitude centre in degrees.
      latspan: full latitude span in degrees.
      lonspan: full longitude span in degrees.
    Returns:
      The bounding box string.
    """
    (xmin, xmax, ymin, ymax) = self.__GetBBoxBounds(
        latcenter, loncenter, latspan, lonspan)
    bbox = "ST_SetSRID('BOX3D(%s %s,%s %s)'::box3d,%s)" % (
        xmin, ymin, xmax, ymax, self.srid)

    return bbox

  def __GetBBoxBounds(self, latcenter, loncenter, latspan, lonspan):
    """Get bounding box coordinates.

    Args:
      latcenter: latitude centre in degrees.
      loncenter: longitude centre in degrees.
      latspan: full latitude span in degrees.
      lonspan: full longitude span in degrees.
    Returns:
      The bounding box coordinates as (xmin, xmax, ymin, ymax), each
      rounded to one decimal place.
    """
    # Divide by 2.0 so integer spans are not truncated under Python 2.
    half_lat = latspan / 2.0
    half_lon = lonspan / 2.0

    ymin = "%.1f" % (latcenter - half_lat)
    ymax = "%.1f" % (latcenter + half_lat)
    xmin = "%.1f" % (loncenter - half_lon)
    xmax = "%.1f" % (loncenter + half_lon)

    return (float(xmin), float(xmax), float(ymin), float(ymax))

  def __IsBBoxValid(self, latcenter, loncenter, latspan, lonspan):
    """Check if the bounding box is valid.

    Args:
      latcenter: latitude centre in degrees.
      loncenter: longitude centre in degrees.
      latspan: full latitude span in degrees.
      lonspan: full longitude span in degrees.
    Returns:
      Validity of the bounding box.
    """
    (xmin, xmax, ymin, ymax) = self.__GetBBoxBounds(
        latcenter, loncenter, latspan, lonspan)

    return bool(
        xmin >= -180.0 and xmax <= 180.0 and ymin >= -90.0 and ymax <= 90.0)

  def __del__(self):
    """Closes the connection pools created in __init__."""
    # getattr: __init__ may have failed before a pool was created, in which
    # case the attribute does not exist.
    for pool_attr in ("_poi_pool", "_search_pool", "_stream_pool"):
      pool = getattr(self, pool_attr, None)
      if pool:
        pool.closeall()
class PostgresqlUtil(DBInterface):
    """PostgreSQL implementation of ``DBInterface`` backed by a pooled
    connection.

    Each public operation builds its SQL with ``generate_sql_string``, runs
    it on a connection borrowed from the pool, commits on success and rolls
    back on failure. The connection is always returned to the pool.
    """

    def __init__(self):
        super(PostgresqlUtil, self).__init__()
        self.conn_pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=32,
            database=config.POSTGRESQL_DATABASE,
            user=config.POSTGRESQL_USER_NAME,
            password=config.POSTGRESQL_PASSWORD,
            host=config.POSTGRESQL_HOST,
            port=config.POSTGRESQL_PORT)

    def get_conn(self):
        """Borrow a connection from the pool."""
        return self.conn_pool.getconn()

    def put_conn(self, conn):
        """Return a borrowed connection to the pool."""
        self.conn_pool.putconn(conn)

    def close(self):
        """Close every connection held by the pool."""
        self.conn_pool.closeall()

    def _run(self, sql, fetch=False):
        """Execute one statement in its own transaction.

        :param sql: complete SQL statement to execute
        :param fetch: when True, return ``cursor.fetchall()``
        :return: fetched rows if ``fetch`` is True, otherwise None
        :raises: whatever the driver raises; the transaction is rolled back
            first and the original traceback is preserved
        """
        conn = self.get_conn()
        try:
            cur = conn.cursor()
            try:
                cur.execute(sql)
                rows = cur.fetchall() if fetch else None
            finally:
                cur.close()
            # Commit only after a successful execute.
            conn.commit()
            return rows
        except Exception:
            conn.rollback()
            # Bare raise keeps the original traceback.
            raise
        finally:
            # Runs on every path, so a failing cursor/commit cannot leak the
            # connection out of the pool.
            self.put_conn(conn)

    # Insert
    def insert(self, table, data, **kwargs):
        """Insert ``data`` into ``table``."""
        self._run(generate_sql_string("INSERT", table, data))

    # Select
    def select(self, table, condition, **kwargs):
        """Return all rows of ``table`` matching ``condition``."""
        select_sql = generate_sql_string(
            "SELECT", table, dict_condition=condition)
        return self._run(select_sql, fetch=True)

    # Delete
    def delete(self, table, condition, **kwargs):
        """Delete the rows of ``table`` matching ``condition``."""
        delete_sql = generate_sql_string(
            "DELETE", table, dict_condition=condition)
        self._run(delete_sql)

    # Update
    def update(self, table, condition, data, upsert=False, **kwargs):
        """Update the rows of ``table`` matching ``condition`` with ``data``.

        When no row matches, ``data`` is inserted if ``upsert`` is true;
        otherwise ``ValueError`` is raised.

        NOTE(review): the existence check and the write run in separate
        transactions, so a concurrent writer can change the outcome in
        between - confirm whether callers rely on atomicity.
        """
        existing = self.select(table, condition)
        if not existing:
            if not upsert:
                raise ValueError(
                    "no row matches the condition and upsert is disabled")
            self.insert(table, data)
            return

        self._run(generate_sql_string("UPDATE", table, data, condition))
class Database(DatabaseInterface):
    """PostgreSQL database handle, one shared instance per database name.

    Instances are cached in ``_databases``; each one lazily creates its own
    ``ThreadedConnectionPool`` in ``connect``.
    """

    _databases = {}  # database name -> Database instance
    _connpool = None
    _list_cache = None  # result of the last list() call
    _list_cache_timestamp = None  # time.time() of the last list() call
    _version_cache = {}  # database name -> server version tuple

    def __new__(cls, database_name="template1"):
        if database_name in cls._databases:
            return cls._databases[database_name]
        return DatabaseInterface.__new__(cls, database_name=database_name)

    def __init__(self, database_name="template1"):
        super(Database, self).__init__(database_name=database_name)
        self._databases.setdefault(database_name, self)

    def connect(self):
        """Create the connection pool if needed and return ``self``."""
        if self._connpool is not None:
            return self
        logger = logging.getLogger("database")
        logger.info('connect to "%s"' % self.database_name)
        host = "host=%s" % CONFIG["db_host"] if CONFIG["db_host"] else ""
        port = "port=%s" % CONFIG["db_port"] if CONFIG["db_port"] else ""
        name = "dbname=%s" % self.database_name
        user = "user=%s" % CONFIG["db_user"] if CONFIG["db_user"] else ""
        password = ("password=%s" % CONFIG["db_password"]
                    if CONFIG["db_password"] else "")
        # A configured value of 0 falls back to the default.
        minconn = int(CONFIG["db_minconn"]) or 1
        maxconn = int(CONFIG["db_maxconn"]) or 64
        dsn = "%s %s %s %s %s" % (host, port, name, user, password)
        self._connpool = ThreadedConnectionPool(minconn, maxconn, dsn)
        return self

    def cursor(self, autocommit=False, readonly=False):
        """Return a ``Cursor`` wrapping a connection taken from the pool."""
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        cursor = Cursor(self._connpool, conn, self)
        if readonly:
            cursor.execute("SET TRANSACTION READ ONLY")
        return cursor

    def close(self):
        """Close the pool, if any, so the next use reconnects."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    def create(self, cursor, database_name):
        """Create ``database_name`` and invalidate the list cache."""
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       "TEMPLATE template0 ENCODING 'unicode'")
        Database._list_cache = None

    def drop(self, cursor, database_name):
        """Drop ``database_name`` and invalidate the list cache."""
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, cursor):
        """Return the server version as a tuple of ints, cached per name."""
        if self.database_name not in self._version_cache:
            cursor.execute("SELECT version()")
            version, = cursor.fetchone()
            self._version_cache[self.database_name] = tuple(
                map(int, RE_VERSION.search(version).groups()))
        return self._version_cache[self.database_name]

    @staticmethod
    def dump(database_name):
        """Return the ``pg_dump`` custom-format output for ``database_name``.

        Raises ``Exception`` when ``pg_dump`` exits with a non-zero status.
        """
        from trytond.tools import exec_pg_command_pipe

        cmd = ["pg_dump", "--format=c", "--no-owner"]
        if CONFIG["db_user"]:
            cmd.append("--username=" + CONFIG["db_user"])
        if CONFIG["db_host"]:
            cmd.append("--host=" + CONFIG["db_host"])
        if CONFIG["db_port"]:
            cmd.append("--port=" + CONFIG["db_port"])
        cmd.append(database_name)

        pipe = exec_pg_command_pipe(*tuple(cmd))
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception("Couldn't dump database!")
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load ``data`` with ``pg_restore``.

        Returns ``True`` on success; raises ``Exception`` when
        ``pg_restore`` fails or the restored database does not pass
        ``cursor.test()``.
        """
        from trytond.tools import exec_pg_command_pipe

        database = Database().connect()
        cursor = database.cursor(autocommit=True)
        database.create(cursor, database_name)
        cursor.commit()
        cursor.close()

        cmd = ["pg_restore", "--no-owner"]
        if CONFIG["db_user"]:
            cmd.append("--username=" + CONFIG["db_user"])
        if CONFIG["db_host"]:
            cmd.append("--host=" + CONFIG["db_host"])
        if CONFIG["db_port"]:
            cmd.append("--port=" + CONFIG["db_port"])
        cmd.append("--dbname=" + database_name)
        args2 = tuple(cmd)

        on_windows = os.name == "nt"
        if on_windows:
            # On Windows the dump is handed over as a temporary file
            # argument rather than through stdin.
            tmpfile = (os.environ["TMP"] or "C:\\") + os.tmpnam()
            with open(tmpfile, "wb") as fp:
                fp.write(data)
            args2 = args2 + (" " + tmpfile,)

        pipe = exec_pg_command_pipe(*args2)
        if not on_windows:
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception("Couldn't restore database")

        database = Database(database_name).connect()
        cursor = database.cursor()
        if not cursor.test():
            cursor.close()
            database.close()
            raise Exception("Couldn't restore database!")
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    @staticmethod
    def list(cursor):
        """Return the names of the databases that pass ``cursor.test()``.

        The result is cached and reused while it is younger than the
        configured ``session_timeout``.
        """
        now = time.time()
        timeout = int(CONFIG["session_timeout"])
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        db_user = CONFIG["db_user"]
        if not db_user and os.name == "posix":
            db_user = pwd.getpwuid(os.getuid())[0]
        if not db_user:
            # Fall back to the owner of the configured database.
            cursor.execute(
                "SELECT usename "
                "FROM pg_user "
                "WHERE usesysid = ("
                "SELECT datdba "
                "FROM pg_database "
                "WHERE datname = %s)",
                (CONFIG["db_name"],),
            )
            row = cursor.fetchone()
            db_user = row and row[0]

        if db_user:
            cursor.execute(
                "SELECT datname "
                "FROM pg_database "
                "WHERE datdba = ("
                "SELECT usesysid "
                "FROM pg_user "
                "WHERE usename=%s) "
                "AND datname not in "
                "('template0', 'template1', 'postgres') "
                "ORDER BY datname",
                (db_user,),
            )
        else:
            cursor.execute(
                "SELECT datname "
                "FROM pg_database "
                "WHERE datname not in "
                "('template0', 'template1','postgres') "
                "ORDER BY datname"
            )

        res = []
        for (db_name,) in cursor.fetchall():
            db_name = db_name.encode("utf-8")
            try:
                database = Database(db_name).connect()
            except Exception:
                # Databases we cannot connect to are left out of the list.
                continue
            cursor2 = database.cursor()
            is_valid = cursor2.test()
            cursor2.close(close=True)
            if is_valid:
                res.append(db_name)
            else:
                database.close()

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    @staticmethod
    def init(cursor):
        """Run ``init.sql`` and register the base modules and dependencies."""
        from trytond.modules import get_module_info

        sql_file = os.path.join(os.path.dirname(__file__), "init.sql")
        with open(sql_file) as fp:
            for line in fp.read().split(";"):
                if (len(line) > 0) and (not line.isspace()):
                    cursor.execute(line)

        for module in ("ir", "res", "webdav"):
            state = "uninstalled"
            if module in ("ir", "res"):
                state = "to install"
            info = get_module_info(module)
            cursor.execute("SELECT NEXTVAL('ir_module_module_id_seq')")
            module_id = cursor.fetchone()[0]
            cursor.execute(
                "INSERT INTO ir_module_module "
                "(id, create_uid, create_date, name, state) "
                "VALUES (%s, %s, now(), %s, %s)",
                (module_id, 0, module, state),
            )
            for dependency in info.get("depends", []):
                cursor.execute(
                    "INSERT INTO ir_module_module_dependency "
                    "(create_uid, create_date, module, name) "
                    "VALUES (%s, now(), %s, %s)",
                    (0, module_id, dependency),
                )
class Database(rigor.database.Database): """ Container for a database connection pool """ def __init__(self, database): super(Database, self).__init__(database) register_type(psycopg2.extensions.UNICODE) register_uuid() dsn = Database.build_dsn(database) self._pool = ThreadedConnectionPool(config.get('database', 'min_database_connections'), config.get('database', 'max_database_connections'), dsn) @staticmethod def build_dsn(database): """ Builds the database connection string from config values """ dsn = "dbname='{0}' host='{1}'".format(database, config.get('database', 'host')) try: ssl = config.getboolean('database', 'ssl') if ssl: dsn += " sslmode='require'" except ConfigParser.Error: pass try: username = config.get('database', 'username') dsn += " user='******'".format(username) except ConfigParser.Error: pass try: password = config.get('database', 'password') dsn += " password='******'".format(password) except ConfigParser.Error: pass return dsn @staticmethod @template def create(name): """ Creates a new database with the given name """ return "CREATE DATABASE {0};".format(name) @staticmethod @template def drop(name): """ Drops the database with the given name """ return "DROP DATABASE {0};".format(name) @staticmethod @template def clone(source, destination): """ Copies the source database to a new destination database. This may fail if the source database is in active use. 
""" return "CREATE DATABASE {0} WITH TEMPLATE {1};".format(destination, source) @contextmanager def get_cursor(self, commit=True): """ Gets a cursor from a connection in the pool """ connection = self._pool.getconn() cursor = connection.cursor(cursor_factory=RigorCursor) try: yield cursor except psycopg2.IntegrityError as error: exc_info = sys.exc_info() self.rollback(cursor) raise rigor.database.IntegrityError, exc_info[1], exc_info[2] except psycopg2.DatabaseError as error: exc_info = sys.exc_info() self.rollback(cursor) raise rigor.database.DatabaseError, exc_info[1], exc_info[2] except: exc_info = sys.exc_info() self.rollback(cursor) raise exc_info[0], exc_info[1], exc_info[2] else: if commit: self.commit(cursor) else: self.rollback(cursor) def _close_cursor(self, cursor): """ Closes a cursor and releases the connection to the pool """ cursor.close() self._pool.putconn(cursor.connection) def commit(self, cursor): """ Commits the transaction, then closes the cursor """ cursor.connection.commit() self._close_cursor(cursor) def rollback(self, cursor): """ Rolls back the transaction, then closes the cursor """ cursor.connection.rollback() self._close_cursor(cursor) def __del__(self): self._pool.closeall()
class connection(object):
    """Pooled PostgreSQL access configured from a ``db_info`` mapping.

    Attributes that are not defined on this class are looked up on a fresh
    pooled cursor and called inside a ``with`` block (see ``__getattr__``).
    """

    def __init__(self, db_info=None, hstore=False, log=None, logf=None,
                 min=1, max=100, default_cursor=DictCursor):
        """Open the pool.

        ``db_info`` must provide the keys ``database``, ``user``,
        ``password``, ``host`` and ``port``; ``ValueError`` is raised when
        it is ``None``.
        """
        if db_info is None:
            raise ValueError("Invalid connection Params")
        self.pool = ThreadedConnectionPool(
            min,
            max,
            database=db_info['database'],
            user=db_info['user'],
            password=db_info['password'],
            host=db_info['host'],
            port=db_info['port'],
        )
        self.hstore = hstore
        self.log = log
        self.logf = logf or (lambda cursor: cursor.query.decode())
        self.default_cursor = default_cursor
        self.prepared_statement_id = 0

    def prepare(self, statement, params=None, name=None, call_type=None):
        """
        >>> db = connection()
        >>> p1 = db.prepare('SELECT name FROM doctest_t1 WHERE id = $1')
        >>> p2 = db.prepare('UPDATE doctest_t1 set name = $2 WHERE id = $1',('int','text'))
        >>> db.execute(p2,(1,'xxxxx'))
        1
        >>> db.query_one(p1,(1,))
        ['xxxxx']
        >>> db.execute(p2,(1,'aaaaa'))
        1
        >>> db.query_one(p1,(1,))
        ['aaaaa']
        """
        if not name:
            # Generate a sequential name when the caller gives none.
            self.prepared_statement_id += 1
            name = '_pstmt_%03.3d' % self.prepared_statement_id
        if params:
            params = '(' + ','.join(params) + ')'
        else:
            params = ''
        with self.cursor() as c:
            c.execute('PREPARE %s %s AS %s' % (name, params, statement))
        if call_type is None:
            if statement.lower().startswith('select'):
                call_type = 'query'
            else:
                call_type = 'execute'
        return PreparedStatement(self, name, call_type)

    def shutdown(self):
        """Close every pooled connection; safe to call more than once."""
        # Read through __dict__: when __init__ failed before setting
        # ``pool``, a normal attribute lookup would be answered by
        # __getattr__ with a wrapper function instead of a pool.
        pool = self.__dict__.get('pool')
        if pool:
            pool.closeall()
            self.pool = None

    def cursor(self, cursor_factory=None):
        """Return a pooled cursor wrapper.

        An explicit ``cursor_factory`` takes precedence over the
        ``default_cursor`` given to the constructor.
        """
        return cursor(self.pool, cursor_factory or self.default_cursor,
                      self.hstore, self.log, self.logf)

    def __del__(self):
        self.shutdown()

    def __getattr__(self, name):
        # Called only for attributes not found normally: forward the call
        # to the method of the same name on a fresh cursor.
        def _wrapper(*args, **kwargs):
            with self.cursor() as c:
                return getattr(c, name)(*args, **kwargs)
        return _wrapper
class connection(object):
    """Pooled PostgreSQL access configured from a ``postgres://`` URL.

    Attributes that are not defined on this class are looked up on a fresh
    pooled cursor and called inside a ``with`` block (see ``__getattr__``).
    """

    def __init__(self, url=None, hstore=False, log=None, logf=None,
                 min=1, max=5, default_cursor=DictCursor):
        """Open the pool.

        The URL comes from ``url``, else the ``DATABASE_URL`` environment
        variable, else ``postgres://localhost/``. Each connection setting is
        taken from the URL itself or, failing that, from the query-string
        option of the same name (``dbname``, ``user``, ``password``,
        ``host``, ``port``). Raises ``ValueError`` for any scheme other than
        ``postgres``.
        """
        params = urlparse(
            url or os.environ.get('DATABASE_URL') or 'postgres://localhost/')
        if params.scheme != 'postgres':
            raise ValueError(
                "Invalid connection string "
                "(postgres://user:pass@host/db?param=value)"
            )
        query = parse_qs(params.query)

        def option(key):
            # parse_qs maps each key to a list of values; use the first.
            values = query.get(key)
            return values[0] if values else None

        self.pool = ThreadedConnectionPool(
            min,
            max,
            database=params.path[1:] or option('dbname'),
            user=params.username or option('user'),
            password=params.password or option('password'),
            host=params.hostname or option('host'),
            port=params.port or option('port'),
        )
        self.hstore = hstore
        self.log = log
        self.logf = logf or (lambda cursor: cursor.query.decode())
        self.default_cursor = default_cursor
        self.prepared_statement_id = 0

    def prepare(self, statement, params=None, name=None, call_type=None):
        """
        >>> db = connection()
        >>> p1 = db.prepare('SELECT name FROM doctest_t1 WHERE id = $1')
        >>> p2 = db.prepare('UPDATE doctest_t1 set name = $2 WHERE id = $1',('int','text'))
        >>> db.execute(p2,(1,'xxxxx'))
        1
        >>> db.query_one(p1,(1,))
        ['xxxxx']
        >>> db.execute(p2,(1,'aaaaa'))
        1
        >>> db.query_one(p1,(1,))
        ['aaaaa']
        """
        if not name:
            # Generate a sequential name when the caller gives none.
            self.prepared_statement_id += 1
            name = '_pstmt_%03.3d' % self.prepared_statement_id
        if params:
            params = '(' + ','.join(params) + ')'
        else:
            params = ''
        with self.cursor() as c:
            c.execute('PREPARE %s %s AS %s' % (name, params, statement))
        if call_type is None:
            if statement.lower().startswith('select'):
                call_type = 'query'
            else:
                call_type = 'execute'
        return PreparedStatement(self, name, call_type)

    def shutdown(self):
        """Close every pooled connection; safe to call more than once."""
        # Read through __dict__: when __init__ raised before setting
        # ``pool``, a normal attribute lookup would be answered by
        # __getattr__ with a wrapper function instead of a pool.
        pool = self.__dict__.get('pool')
        if pool:
            pool.closeall()
            self.pool = None

    def cursor(self, cursor_factory=None):
        """Return a pooled cursor wrapper.

        An explicit ``cursor_factory`` takes precedence over the
        ``default_cursor`` given to the constructor.
        """
        return cursor(self.pool, cursor_factory or self.default_cursor,
                      self.hstore, self.log, self.logf)

    def __del__(self):
        self.shutdown()

    def __getattr__(self, name):
        # Called only for attributes not found normally: forward the call
        # to the method of the same name on a fresh cursor.
        def _wrapper(*args, **kwargs):
            with self.cursor() as c:
                return getattr(c, name)(*args, **kwargs)
        return _wrapper
class Database(DatabaseInterface):
    """PostgreSQL database handle, one shared instance per database name.

    Instances are cached in ``_databases``; each one lazily creates its own
    ``ThreadedConnectionPool`` in ``connect``.
    """

    _databases = {}  # database name -> Database instance
    _connpool = None
    _list_cache = None  # result of the last list() call
    _list_cache_timestamp = None  # time.time() of the last list() call
    _version_cache = {}  # database name -> server version tuple
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        if name in cls._databases:
            return cls._databases[name]
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name='template1'):
        super(Database, self).__init__(name=name)
        self._databases.setdefault(name, self)
        self._search_path = None
        self._current_user = None

    @classmethod
    def dsn(cls, name):
        """Build the libpq connection string for database ``name``."""
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = "host=%s" % uri.hostname if uri.hostname else ''
        port = "port=%s" % uri.port if uri.port else ''
        name = "dbname=%s" % name
        user = "user=%s" % uri.username if uri.username else ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
                    if uri.password else '')
        return '%s %s %s %s %s' % (host, port, name, user, password)

    def connect(self):
        """Create the connection pool if needed and return ``self``."""
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.name)
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        self._connpool = ThreadedConnectionPool(
            minconn, maxconn, self.dsn(self.name))
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool and set its isolation level.

        When the pool raises ``PoolError`` the attempt is repeated after a
        one second pause, up to the configured ``retry`` count, unless the
        pool is closed; after that the ``PoolError`` propagates.
        """
        if self._connpool is None:
            self.connect()
        for count in range(config.getint('database', 'retry'), -1, -1):
            try:
                conn = self._connpool.getconn()
                break
            except PoolError:
                if count and not self._connpool.closed:
                    logger.info('waiting a connection')
                    time.sleep(1)
                    continue
                raise
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        if readonly:
            cursor = conn.cursor()
            cursor.execute('SET TRANSACTION READ ONLY')
        return conn

    def put_connection(self, connection, close=False):
        """Give ``connection`` back to the pool."""
        self._connpool.putconn(connection, close=close)

    def close(self):
        """Close the pool, if any, so the next use reconnects."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, connection, database_name):
        """Create ``database_name`` and invalidate the list cache."""
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE template0 ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        """Drop ``database_name`` and invalidate the list cache."""
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, connection):
        """Return the server version as a tuple of ints, cached per name."""
        if self.name not in self._version_cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.name] = tuple(
                map(int, RE_VERSION.search(version).groups()))
        return self._version_cache[self.name]

    @staticmethod
    def dump(database_name):
        """Return the ``pg_dump`` custom-format output for ``database_name``.

        Raises ``Exception`` when ``pg_dump`` exits with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load ``data`` with ``pg_restore``.

        Returns ``True`` on success; raises ``Exception`` when
        ``pg_restore`` fails or the restored database does not pass
        ``test()``.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        connection = database.get_connection(autocommit=True)
        database.create(connection, database_name)
        database.close()

        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        on_windows = os.name == "nt"
        if on_windows:
            # On Windows the dump is handed over as a temporary file
            # argument rather than through stdin.
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = args2 + (' ' + tmpfile,)

        pipe = exec_command_pipe(*args2, env=env)
        if not on_windows:
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.get_connection().cursor()
        if not database.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    def list(self):
        """Return the names of non-template databases that pass ``_test``.

        The result is cached and reused while it is younger than the
        configured session timeout.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                           'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    with connect(self.dsn(db_name)) as conn:
                        if self._test(conn):
                            res.append(db_name)
                except Exception:
                    # Databases we cannot connect to or query are left
                    # out of the list.
                    continue
        finally:
            # Return the connection even when the catalog query fails.
            self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        """Run ``init.sql`` and register the base modules and dependencies."""
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if (len(line) > 0) and (not line.isspace()):
                        cursor.execute(line)

            for module in ('ir', 'res'):
                state = 'uninstalled'
                if module in ('ir', 'res'):
                    state = 'to install'
                info = get_module_info(module)
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute(
                    'INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute(
                        'INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            # Return the connection even when a statement fails.
            self.put_connection(connection)

    def test(self):
        """Return whether this database contains any of the core tables."""
        connection = self.get_connection()
        try:
            return self._test(connection)
        finally:
            self.put_connection(connection)

    @classmethod
    def _test(cls, connection):
        """Return True if any of the listed core tables exists."""
        cursor = connection.cursor()
        cursor.execute(
            'SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
              'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
              'ir_translation', 'ir_lang'), ))
        return len(cursor.fetchall()) != 0

    def nextid(self, connection, table):
        """Advance the id sequence of ``table`` and return the new value."""
        cursor = connection.cursor()
        cursor.execute("SELECT NEXTVAL('" + table + "_id_seq')")
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        """Set the id sequence of ``table`` to ``value``."""
        cursor = connection.cursor()
        cursor.execute("SELECT SETVAL('" + table + "_id_seq', %d)" % value)

    def currid(self, connection, table):
        """Return ``last_value`` of the id sequence of ``table``."""
        cursor = connection.cursor()
        cursor.execute('SELECT last_value FROM "' + table + '_id_seq"')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        """Lock ``table`` in exclusive mode without waiting."""
        cursor = connection.cursor()
        cursor.execute('LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table)

    def has_constraint(self):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of the search path containing the table.

        Returns ``None`` when no schema of the search path contains it.
        """
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute(
                'SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema

    @property
    def current_user(self):
        """Database user of the pooled connections, fetched once."""
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """Schemas of ``SHOW search_path`` as a list, fetched once."""
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(
                        replace_special_values(p.strip(), **special_values))
                    for p in path.split(',')
                ]
            finally:
                self.put_connection(connection)
        return self._search_path
class Database(DatabaseInterface):
    """PostgreSQL database handle, one shared instance per database name.

    Instances are cached in ``_databases``; each one lazily creates its own
    ``ThreadedConnectionPool`` in ``connect``. Connections handed out by
    ``get_connection`` use ``PerfCursor`` as their cursor factory.
    """

    _databases = {}  # database name -> Database instance
    _connpool = None
    _list_cache = None  # result of the last list() call
    _list_cache_timestamp = None  # time.time() of the last list() call
    _version_cache = {}  # database name -> server version tuple
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        if name in cls._databases:
            return cls._databases[name]
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name='template1'):
        super(Database, self).__init__(name=name)
        self._databases.setdefault(name, self)
        self._search_path = None
        self._current_user = None

    @classmethod
    def dsn(cls, name):
        """Build the libpq connection string for database ``name``."""
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = "host=%s" % uri.hostname if uri.hostname else ''
        port = "port=%s" % uri.port if uri.port else ''
        name = "dbname=%s" % name
        user = "user=%s" % uri.username if uri.username else ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
                    if uri.password else '')
        return '%s %s %s %s %s' % (host, port, name, user, password)

    def connect(self):
        """Create the connection pool if needed and return ``self``."""
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.name)
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        self._connpool = ThreadedConnectionPool(
            minconn, maxconn, self.dsn(self.name))
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool and configure it.

        Sets the isolation level, optionally marks the transaction read
        only, then installs ``PerfCursor`` as the cursor factory.
        """
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        if readonly:
            cursor = conn.cursor()
            cursor.execute('SET TRANSACTION READ ONLY')
        conn.cursor_factory = PerfCursor
        return conn

    def put_connection(self, connection, close=False):
        """Give ``connection`` back to the pool."""
        self._connpool.putconn(connection, close=close)

    def close(self):
        """Close the pool, if any, so the next use reconnects."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, connection, database_name):
        """Create ``database_name`` and invalidate the list cache."""
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE template0 ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        """Drop ``database_name`` and invalidate the list cache."""
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, connection):
        """Return the server version as a tuple of ints, cached per name."""
        if self.name not in self._version_cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.name] = tuple(
                map(int, RE_VERSION.search(version).groups()))
        return self._version_cache[self.name]

    @staticmethod
    def dump(database_name):
        """Return the ``pg_dump`` custom-format output for ``database_name``.

        Raises ``Exception`` when ``pg_dump`` exits with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load ``data`` with ``pg_restore``.

        Returns ``True`` on success; raises ``Exception`` when
        ``pg_restore`` fails or the restored database does not pass
        ``test()``.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        connection = database.get_connection(autocommit=True)
        database.create(connection, database_name)
        database.close()

        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        on_windows = os.name == "nt"
        if on_windows:
            # On Windows the dump is handed over as a temporary file
            # argument rather than through stdin.
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = args2 + (' ' + tmpfile,)

        pipe = exec_command_pipe(*args2, env=env)
        if not on_windows:
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.get_connection().cursor()
        if not database.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    def list(self):
        """Return the names of non-template databases that pass ``_test``.

        The result is cached and reused while it is younger than the
        configured session timeout.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                           'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    with connect(self.dsn(db_name)) as conn:
                        if self._test(conn):
                            res.append(db_name)
                except Exception:
                    # Databases we cannot connect to or query are left
                    # out of the list.
                    continue
        finally:
            # Return the connection even when the catalog query fails.
            self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        """Run ``init.sql`` and register the base modules and dependencies."""
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if (len(line) > 0) and (not line.isspace()):
                        cursor.execute(line)

            for module in ('ir', 'res'):
                state = 'uninstalled'
                if module in ('ir', 'res'):
                    state = 'to install'
                info = get_module_info(module)
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute(
                    'INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute(
                        'INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            # Return the connection even when a statement fails.
            self.put_connection(connection)

    def test(self):
        """Return whether this database contains any of the core tables."""
        connection = self.get_connection()
        try:
            return self._test(connection)
        finally:
            self.put_connection(connection)

    @classmethod
    def _test(cls, connection):
        """Return True if any of the listed core tables exists."""
        cursor = connection.cursor()
        cursor.execute(
            'SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
              'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
              'ir_translation', 'ir_lang'),))
        return len(cursor.fetchall()) != 0

    def nextid(self, connection, table):
        """Advance the id sequence of ``table`` and return the new value."""
        cursor = connection.cursor()
        cursor.execute("SELECT NEXTVAL('" + table + "_id_seq')")
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        """Set the id sequence of ``table`` to ``value``."""
        cursor = connection.cursor()
        cursor.execute("SELECT SETVAL('" + table + "_id_seq', %d)" % value)

    def currid(self, connection, table):
        """Return ``last_value`` of the id sequence of ``table``."""
        cursor = connection.cursor()
        cursor.execute('SELECT last_value FROM "' + table + '_id_seq"')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        """Lock ``table`` in exclusive mode without waiting."""
        cursor = connection.cursor()
        cursor.execute('LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table)

    def has_constraint(self):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of the search path containing the table.

        Returns ``None`` when no schema of the search path contains it.
        """
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute(
                'SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema

    @property
    def current_user(self):
        """Database user of the pooled connections, fetched once."""
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """Schemas of ``SHOW search_path`` as a list, fetched once."""
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(replace_special_values(
                        p.strip(), **special_values))
                    for p in path.split(',')]
            finally:
                self.put_connection(connection)
        return self._search_path