def putconn(self, conn):
    """Hand ``conn`` back to the pool.

    Delegates directly to ``ThreadedConnectionPool.putconn``.
    """
    # Disabled debugging aid: logs which caller is returning the connection.
    # calledBy = traceback.extract_stack()[-2]
    # logging.info("PUTCONN - FILE: " + calledBy[0] + ", LINE: " + str(calledBy[1]) + ", METHOD: " + calledBy[2])
    ThreadedConnectionPool.putconn(self, conn)
class Database:
    """Small facade over a ``ThreadedConnectionPool``."""

    def __init__(self, connect_param):
        """Create the pool (0 to 10 connections) from ``connect_param``."""
        self.__connect_param = connect_param
        self.__pool = ThreadedConnectionPool(0, 10, self.__connect_param)
        # Disabled start-up probe, kept for reference:
        # get cursor and test it
        # cur = self.cursor()
        # cur.execute('SHOW transaction_read_only')
        # standby = cur.fetchone()
        # cur.close()

    def get_connection(self):
        """Borrow a connection from the pool."""
        pool = self.__pool
        return pool.getconn()

    def put_connection(self, connection):
        """Give a previously borrowed connection back to the pool."""
        pool = self.__pool
        pool.putconn(connection)
class Database():
    """Holds a psycopg2 connection pool configured from a settings mapping.

    ``config`` must provide the keys ``DB_DATABASE``, ``DB_USER``,
    ``DB_PASSWORD`` and ``DB_HOST``.
    """

    def __init__(self, config):
        """Configure root logging and open a pool of 1 to 10 connections."""
        logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                            level=logging.INFO)
        # ``async`` is a reserved word since Python 3.7, so the original
        # ``async=False`` keyword no longer parses. psycopg2 >= 2.7 accepts
        # the ``async_`` alias with the same meaning.
        self._pool = ThreadedConnectionPool(1, 10,
                                            database=config['DB_DATABASE'],
                                            user=config['DB_USER'],
                                            password=config['DB_PASSWORD'],
                                            host=config['DB_HOST'],
                                            async_=False)

    def get_connection(self):
        """Borrow a connection from the pool."""
        return self._pool.getconn()

    def put_away_connection(self, con):
        """Return ``con`` to the pool."""
        self._pool.putconn(con)
class DB:
    """Entry point to the database: lazily created pool plus table helpers.

    Positional and keyword arguments are stored and forwarded unchanged to
    ``ThreadedConnectionPool`` the first time a connection is needed.
    """

    def __init__(self, *args, **kwargs):
        self.pool_params = (args, kwargs)
        self.pool = None  # created on first use by _initialize_pool()

        self.campaigns = Campaigns(self)
        self.worksets = Worksets(self)
        self.tasks = Tasks(self)
        self.labels = Labels(self)

        self.logger = logging.getLogger(__name__)

    def _initialize_pool(self):
        """Create the connection pool if it does not exist yet."""
        if self.pool is None:
            # Use the instance logger created in __init__ rather than relying
            # on a module-level ``logger`` global being defined.
            self.logger.info("Initializing connection pool.")
            args, kwargs = self.pool_params
            self.pool = ThreadedConnectionPool(
                *args, cursor_factory=RealDictCursor, **kwargs)

    def execute(self, sql):
        """Run ``sql`` in its own transaction and return the cursor.

        Note that the transaction is committed and the connection handed back
        to the pool before the cursor is returned to the caller.
        """
        with self.transaction() as transactor:
            cursor = transactor.cursor()
            cursor.execute(sql)
            return cursor

    @contextmanager
    def transaction(self):
        """Provides a transactional scope around a series of operations.

        Yields a pooled connection. Commits when the block finishes normally,
        rolls back and re-raises on any exception, and always returns the
        connection to the pool.
        """
        self._initialize_pool()
        conn = self.pool.getconn()
        try:
            yield conn
            conn.commit()
        except BaseException:
            # Deliberately broad (same reach as the original bare ``except``):
            # roll back whatever interrupted the block, then propagate it.
            conn.rollback()
            raise
        finally:
            self.pool.putconn(conn)

    @classmethod
    def from_config(cls, config):
        """Build a ``DB`` from ``config['database']``.

        ``minconn`` and ``maxconn`` default to 1 and 5 when absent.
        """
        # Copy so the caller's configuration is not modified.
        params = dict(config['database'].items())
        params.setdefault('minconn', 1)
        params.setdefault('maxconn', 5)
        return cls(**params)
class ConnectionPool:  # no test coverage
    """Connection pool that rebuilds itself when the process id changes.

    Based on https://gist.github.com/jeorgen/4eea9b9211bafeb18ada
    """

    is_setup = False

    def setup(self):
        """Record the current process id and create the underlying pool."""
        self.last_seen_process_id = os.getpid()
        self._init()
        self.is_setup = True

    def _init(self):
        """(Re)create the underlying ``ThreadedConnectionPool``."""
        self._pool = ThreadedConnectionPool(1, 10, database=config.db_name, **credentials)

    def _getconn(self) -> connection:
        """Return a pooled connection, recreating the pool if the pid changed."""
        current_pid = os.getpid()
        if not (current_pid == self.last_seen_process_id):
            self._init()
            log.debug(
                f"New id is {current_pid}, old id was {self.last_seen_process_id}"
            )
            self.last_seen_process_id = current_pid
        conn = self._pool.getconn()
        return conn

    def _putconn(self, conn: connection):
        """Return ``conn`` to the underlying pool."""
        return self._pool.putconn(conn)

    def closeall(self):
        """Close every connection handled by the underlying pool."""
        self._pool.closeall()

    @contextmanager
    def get_connection(self) -> t.Generator[connection, None, None]:
        """Yield a pooled connection and always give it back afterwards."""
        # Acquire before entering ``try``: if acquisition fails there is
        # nothing to return, and the original error must not be masked by an
        # UnboundLocalError raised from the ``finally`` clause.
        conn = self._getconn()
        try:
            yield conn
        finally:
            self._putconn(conn)

    @contextmanager
    def get_cursor(self, commit=False) -> t.Generator[LoggingCursor, None, None]:
        """Yield a ``LoggingCursor`` on a pooled connection.

        The connection is committed after the block only when ``commit`` is
        true. The cursor is always closed.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor(cursor_factory=LoggingCursor)
            try:
                yield cursor
                if commit:
                    conn.commit()
            finally:
                cursor.close()
class ConnectionPool(object):
    """Lazily initialised psycopg2 pool exposing a transactional cursor.

    ``initialize()`` must be called before ``cursor()`` is used.
    """

    def __init__(self, conn_params, minconn=5, maxconn=5):
        """Store pool parameters; ``conn_params`` itself is left unmodified."""
        self._conn_params = conn_params.copy()
        self._conn_params['minconn'] = minconn
        self._conn_params['maxconn'] = maxconn
        self._conn_pool = None

    def initialize(self):
        """Create the underlying ``ThreadedConnectionPool``."""
        self._conn_pool = ThreadedConnectionPool(**self._conn_params)

    @contextmanager
    def cursor(self):
        """Yield a cursor inside a transaction.

        Commits when the block finishes normally, rolls back and re-raises on
        ``Exception``, and always returns the connection to the pool.
        """
        conn = self._conn_pool.getconn()
        try:
            # Create the cursor inside ``try`` so the connection still goes
            # back to the pool if cursor creation itself fails.
            cursor = conn.cursor()
            yield cursor
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self._conn_pool.putconn(conn)
class PgConnectionPool:
    """Wrapper around ``ThreadedConnectionPool`` with acquire/release naming."""

    def __init__(self, *args, min_conns=1, keep_conns=10, max_conns=10, **kwargs):
        """Create the pool; extra arguments are forwarded to psycopg2."""
        self._pool = ThreadedConnectionPool(min_conns, max_conns, *args, **kwargs)
        self._keep_conns = keep_conns

    def acquire(self):
        """Take a connection from the pool and re-tune its ``minconn``.

        After every acquisition ``minconn`` is set to the number of
        connections currently in use, capped at ``keep_conns``.
        """
        conn = self._pool.getconn()
        in_use = len(self._pool._used)
        self._pool.minconn = min(self._keep_conns, in_use)
        return conn

    def release(self, conn):
        """Give ``conn`` back to the pool."""
        self._pool.putconn(conn)

    def close(self):
        """Close all pooled connections, if the pool was ever created."""
        # ``_pool`` is missing when __init__ failed before assigning it.
        if not hasattr(self, '_pool'):
            return
        self._pool.closeall()

    # Also close the pool when the wrapper is garbage collected.
    __del__ = close
class PoolWrapper:
    """Exists to provide an acquire method for easy usage.

    pool = PoolWrapper(...)
    with pool.acquire() as connection:
        connection.execute(...)
    """

    def __init__(self, max_pool_size: int, *, dsn):
        """Open a pool of 1 to ``max_pool_size`` connections for ``dsn``."""
        self._pool = ThreadedConnectionPool(
            1,
            max_pool_size,
            dsn=dsn,
            cursor_factory=RealDictCursor,
        )

    @contextmanager
    def acquire(self):
        """Yield a pooled connection and always hand it back afterwards."""
        # Acquire outside ``try``: when getconn() fails there is no
        # connection to return, and the ``finally`` clause must not replace
        # the real error with an UnboundLocalError.
        connection = self._pool.getconn()
        try:
            yield connection
        finally:
            self._pool.putconn(connection)
def from_qid_author(qid: int, author: str, pool: ThreadedConnectionPool = None):
    """
    Retrieve data from database and construct ``Answer`` using it.

    Looks up the answer written by ``author`` for question ``qid`` and
    returns an ``Answer`` built from the stored body and vote counts.

    ``pool`` is required in practice despite its ``None`` default: a
    connection is taken from it for the query. Unpacking the fetched row
    fails with ``TypeError`` when no matching row exists.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as curs:
            curs.execute(
                "SELECT body, upvotes, downvotes "
                "FROM answers "
                "WHERE author=%(username)s AND qid=%(qid)s",
                {
                    'username': author,
                    'qid': qid
                })
            body, upvotes, downvotes = curs.fetchone()
    finally:
        # Return the connection even when the query or the unpacking fails,
        # otherwise each failure permanently removes one pooled connection.
        pool.putconn(conn)
    return Answer(author=author,
                  body=body,
                  upvotes=upvotes,
                  downvotes=downvotes,
                  qid=qid,
                  pool=pool)
class ProcessSafePoolManager:
    """Pool front-end that builds a fresh pool when the process id changes."""

    def __init__(self, *args, **kwargs):
        """Remember the pool arguments and create the first pool."""
        self.last_seen_process_id = os.getpid()
        self.args = args
        self.kwargs = kwargs
        self._init()

    def _init(self):
        """(Re)create the pool from the stored constructor arguments."""
        self._pool = ThreadedConnectionPool(*self.args, **self.kwargs)

    def getconn(self):
        """Return a connection from the pool belonging to this process."""
        pid_now = os.getpid()
        if pid_now != self.last_seen_process_id:
            # Running under a different pid than the one that built the
            # current pool: replace the pool before handing anything out.
            self._init()
            print("New id is %s, old id was %s" % (pid_now, self.last_seen_process_id))
            self.last_seen_process_id = pid_now
        return self._pool.getconn()

    def putconn(self, conn):
        """Return ``conn`` to the current pool."""
        return self._pool.putconn(conn)
class PostgresThreadPool:
    """Component exposing a psycopg2 threaded pool built from JSON settings."""

    provides = ['db_connection_pool', 'postgres']
    requires_configured = ['json_settings']

    def __init__(self, settings):
        """Open the pool described by ``settings['database']``.

        Reads ``conn_pool_size``, ``name``, ``username``, ``password`` and
        ``host``; ``port`` is optional.
        """
        from psycopg2.pool import ThreadedConnectionPool

        dbsettings = settings['database']
        pool_options = {
            'minconn': 1,
            'maxconn': settings['database']['conn_pool_size'],
            'database': dbsettings['name'],
            'user': dbsettings['username'],
            'password': dbsettings['password'],
            'host': dbsettings['host'],
            'port': dbsettings.get('port'),
        }
        self.pool = ThreadedConnectionPool(**pool_options)

    def getconn(self):
        """Borrow a connection from the pool."""
        borrowed = self.pool.getconn()
        return borrowed

    def putconn(self, connection):
        """Return ``connection`` to the pool."""
        outcome = self.pool.putconn(connection)
        return outcome
class Pool:
    """Flask-integrated psycopg2 pool with SQL-returning decorators.

    One connection per application context is cached on ``flask.g`` and
    returned to the pool when the context is torn down.
    """

    def __init__(self):
        self.pool = None  # created by init_app()

    def init_app(self, app):
        """Create the pool from ``app.config["DATABASE_URL"]`` and register teardown."""
        self.pool = ThreadedConnectionPool(minconn=1,
                                           maxconn=10,
                                           dsn=app.config["DATABASE_URL"])
        app.teardown_appcontext(self.return_conn)

    def get_conn(self):
        """Return the connection bound to the current application context."""
        if 'db_conn' not in g:
            g.db_conn = self.pool.getconn()
        return g.db_conn

    def return_conn(self, x):
        """Teardown hook: give the context's connection back to the pool.

        ``x`` is the argument Flask passes to teardown callbacks; it is unused.
        """
        db_conn = g.pop('db_conn', None)
        if db_conn is not None:
            self.pool.putconn(db_conn)

    def unwrap_pg_statement(self, func, *args, **kwargs):
        """Call ``func`` and normalise its result to ``(sql, values)``.

        ``func`` may return a SQL string, a 1-tuple ``(sql,)`` or a tuple
        whose first two items are ``(sql, values)``.

        Raises:
            TypeError: if ``func`` returns anything else. Previously this
                case returned ``None`` and failed later with an unrelated
                unpacking error.
        """
        items = func(*args, **kwargs)
        if isinstance(items, str):
            return items, []
        if isinstance(items, tuple):
            if len(items) == 1:
                return items[0], []
            return items[0], items[1]
        raise TypeError(
            "SQL function must return a str or a tuple, got %s"
            % type(items).__name__)

    def execute(self, cursor_method="fetchall"):
        """Decorator factory: run the SQL returned by the wrapped function.

        For statements starting with ``SELECT`` the wrapper returns the rows
        (``cursor_method="fetchall"``) or a single row (``"fetchone"``);
        for every other statement it returns ``True``.

        Raises:
            ValueError: when a SELECT is executed with an unsupported
                ``cursor_method``.
        """
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with self.get_conn() as conn:
                    with conn.cursor(cursor_factory=RealDictCursor) as c:
                        sql, values = self.unwrap_pg_statement(
                            func, *args, **kwargs)
                        c.execute(sql, values)
                        if not sql.strip().upper().startswith("SELECT"):
                            return True
                        if cursor_method == "fetchall":
                            return list(c.fetchall())
                        if cursor_method == "fetchone":
                            return c.fetchone()
                        raise ValueError(
                            "Unsupported cursor_method: %r" % (cursor_method,))

            return wrapper

        return decorator

    def executemany(self, func):
        """Decorator: run the ``(sql, values)`` returned by ``func`` via ``executemany``."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self.get_conn() as conn:
                with conn.cursor() as c:
                    sql, values = self.unwrap_pg_statement(
                        func, *args, **kwargs)
                    c.executemany(sql, values)
            return True

        return wrapper
class Database:
    """Bulk insert/upsert/update helpers on top of a psycopg2 connection pool.

    Values are always bound through ``cursor.mogrify``. Table and column
    names are interpolated directly into the SQL text, so they must come
    from trusted code, never from user input.
    """

    def __init__(self, db_config, table_raw=None, max_connections=10):
        """Open the pool; the process exits with status 1 if that fails.

        ``db_config`` must contain ``db_name``, ``db_user``, ``db_host`` and
        ``db_pass``.
        """
        from psycopg2.pool import ThreadedConnectionPool

        self.table_raw = table_raw
        try:
            self.pool = ThreadedConnectionPool(
                minconn=1,
                maxconn=max_connections,
                dsn="dbname={db_name} user={db_user} host={db_host} password={db_pass}"
                .format(**db_config))
        except Exception:
            logger.exception("Error in db connection")
            sys.exit(1)

        logger.debug(
            "Connected to database: {host}".format(host=db_config['db_host']))

    @staticmethod
    def _as_list(value):
        """Return ``value`` unchanged if it is a list, else wrap it in one."""
        return value if isinstance(value, list) else [value]

    @contextmanager
    def getcursor(self, **kwargs):
        """Yield a cursor (``kwargs`` go to ``conn.cursor``) in a transaction.

        Commits on success, rolls back and re-raises on ``Exception``, and
        always returns the connection to the pool.
        """
        conn = self.pool.getconn()
        try:
            yield conn.cursor(**kwargs)
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.pool.putconn(conn)

    def insert(self, table, data_list, return_cols='id'):
        """
        TODO: rename `id_col` -> `return_col`

        Create a bulk insert statement which is much faster (~2x in tests with 10k & 100k rows and n cols)
        for inserting data then executemany()

        ``data_list`` is a dict or a list of dicts; the keys of the first
        dict define the inserted columns. Returns the rows produced by
        ``RETURNING return_cols``, ``[]`` for empty input, or ``None`` when
        ``return_cols`` is empty.

        TODO: Is there a limit of length the query can be? If so handle it.
        """
        data_list = self._as_list(data_list)
        if not data_list:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        return_cols = self._as_list(return_cols)
        if not return_cols or return_cols[0] is None:
            logger.critical("`return_cols` cannot be None/empty")
            # TODO: raise some error here rather then returning None
            return None

        try:
            with self.getcursor() as cur:
                query = "INSERT INTO {table} ({fields}) VALUES {values} RETURNING {return_cols}"\
                    .format(table=table,
                            fields='"{0}"'.format('", "'.join(data_list[0].keys())),
                            values=','.join(['%s'] * len(data_list)),
                            return_cols=', '.join(return_cols),
                            )
                query = cur.mogrify(query, [tuple(v.values()) for v in data_list])
                cur.execute(query)

                return cur.fetchall()

        except Exception:
            logger.exception("Error inserting data")
            logger.debug("Error inserting data: {data}".format(data=data_list))
            raise

    def upsert(self, table, data_list, on_conflict_fields, on_conflict_action='update',
               update_fields=None, return_cols='id'):
        """
        Create a bulk upsert statement which is much faster (~6x in tests with 10k & 100k rows and n cols)
        for upserting data then executemany()

        ``on_conflict_fields`` names the conflict target. With
        ``on_conflict_action='update'`` only ``update_fields`` are
        overwritten on conflict; when omitted they default to every column
        that is not a conflict field. Any other action means
        ``DO NOTHING``.

        Returns the ``RETURNING`` rows, ``[]`` for empty input, or ``None``
        when the arguments are invalid (a critical message is logged).

        TODO: Is there a limit of length the query can be? If so handle it.
        """
        data_list = self._as_list(data_list)
        if not data_list:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # TODO: raise some error here rather then returning None
            return None

        on_conflict_fields = self._as_list(on_conflict_fields)
        if not on_conflict_fields or on_conflict_fields[0] is None:
            logger.critical("Must pass in `on_conflict_fields` argument")
            # TODO: raise some error here rather then returning None
            return None

        return_cols = self._as_list(return_cols)
        if not return_cols or return_cols[0] is None:
            logger.critical("`return_cols` cannot be None/empty")
            # TODO: raise some error here rather then returning None
            return None

        if on_conflict_action == 'update':
            update_fields = self._as_list(update_fields)
            # If nothing is passed in, update every column that is not a
            # conflict field (kept in the data's own column order).
            if not update_fields or update_fields[0] is None:
                update_fields = [key for key in data_list[0]
                                 if key not in on_conflict_fields]

            # Empty here can only mean all fields are conflict fields
            if not update_fields:
                logger.critical(
                    "Not all the fields can be `on_conflict_fields` when doing an update"
                )
                # TODO: raise some error here rather then returning None
                return None

            # Build the SET clause from `update_fields`. The previous code
            # iterated over every data key, silently ignoring the argument.
            conflict_action_sql = 'UPDATE SET {update_fields}'\
                .format(update_fields=', '.join(
                    '"{0}"="excluded"."{0}"'.format(field)
                    for field in update_fields))
        else:
            # Do nothing on conflict
            conflict_action_sql = 'NOTHING'

        try:
            with self.getcursor() as cur:
                query = """INSERT INTO {table} ({insert_fields})
                           SELECT {values}
                           ON CONFLICT ({on_conflict_fields}) DO
                           {conflict_action_sql}
                           RETURNING {return_cols}
                        """.format(
                    table=table,
                    insert_fields='"{0}"'.format('", "'.join(
                        data_list[0].keys())),
                    values=','.join(['unnest(%s)'] * len(data_list[0])),
                    on_conflict_fields=', '.join(on_conflict_fields),
                    conflict_action_sql=conflict_action_sql,
                    return_cols=', '.join(return_cols),
                )
                # Get all the values for each row and create a lists of lists
                values = [list(v.values()) for v in data_list]
                # Transpose list of lists: one array per column for unnest()
                values = list(map(list, zip(*values)))

                query = cur.mogrify(query, values)
                cur.execute(query)

                return cur.fetchall()

        except Exception:
            logger.exception("Error inserting data")
            logger.debug("Error inserting data: {data}".format(data=data_list))
            raise

    def update(self, table, data_list, matched_field=None, return_cols='id'):
        """
        Create a bulk update: one UPDATE statement per row, sent to the
        server in a single round trip.

        Each row is matched on ``matched_field`` (default ``id``); rows
        lacking that field are logged and skipped. The caller's dicts are
        not modified. Returns the list of matched values of the rows that
        were sent, or ``[]`` when there was nothing to update.

        TODO: Is there a limit of length the query can be? If so handle it.
        """
        if matched_field is None:
            # Assume the id field
            logger.info("Matched field not defined, assuming the `id` field")
            matched_field = 'id'

        data_list = self._as_list(data_list)
        if not data_list:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        # Accept a single column name or a list of them
        returning = ', '.join(self._as_list(return_cols))

        try:
            with self.getcursor() as cur:
                query_list = []
                # TODO: change to return data from the database, not just what you passed in
                return_list = []
                for row in data_list:
                    matched_value = row.get(matched_field)
                    if matched_value is None:
                        logger.debug(
                            "Cannot update row. Missing field {field} in data {data}"
                            .format(field=matched_field, data=row))
                        logger.error(
                            "Cannot update row. Missing field {field} in data".
                            format(field=matched_field))
                        continue

                    # Work on a filtered copy so the caller's row keeps its
                    # matched field.
                    changes = {key: value for key, value in row.items()
                               if key != matched_field}

                    query = "UPDATE {table} SET {data} WHERE {matched_field}=%s RETURNING {return_cols}"\
                        .format(table=table,
                                data=','.join("%s=%%s" % u for u in changes.keys()),
                                matched_field=matched_field,
                                return_cols=returning
                                )
                    values = list(changes.values())
                    values.append(matched_value)

                    query_list.append(cur.mogrify(query, values))
                    return_list.append(matched_value)

                # Every row was skipped: executing an empty statement would
                # raise, so there is simply nothing to do.
                if not query_list:
                    return return_list

                finial_query = b';'.join(query_list)
                cur.execute(finial_query)

                return return_list

        except Exception:
            logger.exception("Error updating data")
            logger.debug("Error updating data: {data}".format(data=data_list))
            raise
class MT_Terrain():
    """Multi-process driver that evaluates links between pairs of buildings.

    Python 2 code. The constructor does all the work: it loads the
    buildings, samples them, starts "query" worker processes that fetch
    terrain profiles for building pairs and "calc" worker processes that
    consume those profiles from a shared queue, then waits for them.
    """

    def __init__(self, n_querier, n_calc, working_area, dataset, sample_ratio, plot):
        """Run the whole simulation.

        n_querier    -- number of processes running ``queryWorker``
        n_calc       -- number of processes running ``calcWorker``
        working_area -- forwarded to ``terrain_RF``
        dataset      -- dataset name, forwarded to ``terrain_RF`` and used in
                        output file names
        sample_ratio -- divisor applied to the number of buildings to decide
                        how many are kept
        plot         -- stored on the instance; not used in the visible code
        """
        # Queue carrying profile results from query workers to calc workers.
        self.workers_query_result_q = mp.Queue()
        DSN = "postgresql://*****:*****@192.168.184.102/postgres"
        self.tcp = ThreadedConnectionPool(1, 100, DSN)
        conn = self.tcp.getconn()
        cur = conn.cursor()
        self.querier = []
        self.calc = []
        # Shared flag polled by calcWorker and monitor; cleared once all
        # query workers have finished.
        self.go = mp.Value('b', True)
        self.dataset = dataset
        self.sample_ratio = sample_ratio
        self.working_area = working_area
        self.plot = plot
        self.processed_link = mp.Value('i', 0)
        self._counter_lock = mp.Lock()
        tf = terrain_RF(cur=cur, dataset=self.dataset, working_area=self.working_area)
        buildings = tf.get_buildings()
        # NOTE(review): relies on Python 2 integer division when both
        # operands are ints; the result is used as a slice bound below.
        self.n_buildings = len(buildings) / sample_ratio
        self.sim_name = "%s_%d_%d_" % (
            (self.dataset, self.n_buildings, time.time()))
        # Random sample: shuffle in place, keep the first n_buildings.
        shuffle(buildings)
        buildings = buildings[:self.n_buildings]
        self.write_node_latlon(buildings)
        gid, h, centroid_x, centroid_y = zip(*buildings)
        # Building id -> h (second field of each building row).
        self.dict_h = dict(zip(gid, h))
        self.tcp.putconn(conn, close=True)
        # Every unordered pair of sampled buildings is a candidate link.
        buildings_pair = set(itertools.combinations(gid, 2))
        self.link_filename = "../data/%s_links.csv" % (self.sim_name)
        self.tot_link = len(buildings_pair)
        print "%d links left to estimate" % len(buildings_pair)
        # NOTE(review): chunks[i] is indexed for i < n_querier below, so this
        # presumably expects at least n_querier chunks -- confirm the
        # behaviour of `chunked` when tot_link is small.
        chunk_size = self.tot_link / n_querier
        chunks = list(chunked(buildings_pair, chunk_size))
        # Write the CSV header of the links file.
        with open(self.link_filename + "_0", 'a') as fl:
            print >> fl, "b1,b2,status,loss,status_downscale,loss_downscale, status_srtm, loss_srtm"
        self.start_time = time.time()
        for i in range(n_querier):
            t = mp.Process(target=self.queryWorker, args=[i, chunks[i]])
            self.querier.append(t)
            t.daemon = True
            t.start()
        for i in range(n_calc):
            t = mp.Process(target=self.calcWorker, args=[i])
            self.calc.append(t)
            t.daemon = True
            t.start()
        t = mp.Process(target=self.monitor)
        t.daemon = True
        t.start()
        # Wait for all query workers, then tell calc workers to stop.
        [self.querier[i].join() for i in range(n_querier)]
        print "Finished Query"
        self.go.value = False
        print "Set false go"
        [self.calc[i].join() for i in range(n_calc)]

    def write_node_latlon(self, buildings):
        """Write the id and centroid coordinates of ``buildings`` to a CSV file.

        The file is ``../data/<sim_name>_latlong.csv``.
        """
        gid, h, centroid_x, centroid_y = zip(*buildings)
        coords = zip(gid, centroid_x, centroid_y)
        latlong_filename = "../data/%s_latlong.csv" % (self.sim_name)
        with open(latlong_filename, 'w') as f:
            # NOTE(review): header says "lat,lon" but the columns written are
            # centroid_x, centroid_y in that order -- confirm axis order.
            print >> f, "id,lat,lon"
            for l in coords:
                print >> f, "%s,%f,%f" % (l)

    def monitor(self):
        """Show a progress bar driven by ``processed_link`` until ``go`` is cleared."""
        time.sleep(1)
        widgets = [
            'Test: ',
            progressbar.Percentage(),
            ' ',
            progressbar.Bar(marker='0', left='[', right=']'),
            ' ',
            progressbar.ETA(),
            ' ',
            progressbar.FileTransferSpeed(unit="Link"),
        ]
        bar = progressbar.ProgressBar(widgets=widgets,
                                      max_value=self.tot_link).start()
        while (self.go.value):
            bar.update(self.processed_link.value, force=True)
            time.sleep(1)
        bar.finish()

    def queryWorker(self, worker_id, buildings_pairs):
        """Fetch the profiles for each pair and push them onto the result queue.

        For every pair both the dataset profile and the ``<dataset>_srtm``
        profile are requested; a ``ProfileException`` yields ``None`` for that
        profile.
        """
        # NOTE(review): this runs in a child process but takes its connection
        # from the pool object created in the parent -- confirm that this is
        # safe for the way processes are started here.
        conn = self.tcp.getconn()
        cur = conn.cursor()
        tf = terrain_RF(cur=cur,
                        dataset=self.dataset,
                        working_area=self.working_area)
        tf_srtm = terrain_RF(cur=cur,
                             dataset=self.dataset + "_srtm",
                             working_area=self.working_area)
        for buildings_pair in buildings_pairs:
            id1 = buildings_pair[0]
            id2 = buildings_pair[1]
            try:
                profile = tf.profile_osm(id1, id2)
            except ProfileException:
                profile = None
            try:
                profile_srtm = tf_srtm.profile_osm(id1, id2)
            except ProfileException:
                profile_srtm = None
            result = {
                "id1": id1,
                "id2": id2,
                "profile": profile,
                "profile_srtm": profile_srtm,
                "p1": (0.0, self.dict_h[id1]),
                "p2": (tf.distance(id1, id2), self.dict_h[id2])
            }
            self.workers_query_result_q.put(result)
        self.tcp.putconn(conn, close=True)

    def calcWorker(self, worker_id):
        """Consume queued profiles and compute the loss/status of each link.

        Three computations are made per item: the normal profile, the same
        profile with ``downscale=3``, and the SRTM profile.

        NOTE(review): the computed values are not used in the visible code;
        the method may continue beyond what is shown here.
        """
        while (self.go.value):
            # Take ORDER
            if self.workers_query_result_q.qsize() > 3:
                print "Warning: the queue is containing %d elements" % (
                    self.workers_query_result_q.qsize())
            # NOTE(review): get() is called without a timeout, so a worker
            # presumably blocks here if the queue stays empty after `go` is
            # cleared -- confirm.
            order = self.workers_query_result_q.get()
            profile = order["profile"]
            profile_srtm = order["profile_srtm"]
            id1 = order["id1"]
            id2 = order["id2"]
            # Normal profile
            try:
                link = Link(profile)
                loss, status = link.loss_calculator()
            except (ZeroDivisionError, ProfileException), e:
                loss = 0
                status = -1
            # Downscaled profile
            try:
                link_ds = Link(profile)
                loss_ds, status_ds = link_ds.loss_calculator(downscale=3)
            except (ZeroDivisionError, ProfileException), e:
                loss_ds = 0
                status_ds = -1
            # SRTM Profile
            try:
                link_srtm = Link(profile_srtm)
                loss_srtm, status_srtm = link_srtm.loss_calculator()
            except ProfileException, e:
                # NOTE(review): failure status is 1 here but -1 in the two
                # branches above -- confirm whether that is intended.
                status_srtm = 1
                loss_srtm = 0
class BatchCopy(object):
    """Buffers events in per-table files and bulk-loads them with COPY.

    Each table gets a spool file ``<event_path>/<worker_id>/<table>.udw``.
    ``writelines`` appends serialised events to the files and ``copy_sink``
    filters the content through Redis and loads it into the database.
    """

    def __init__(self, dw_conf, event_path, worker_id, table_names, redis_conf, logger, sep='\x02'):
        """Open the pool, create the spool files and read the column layout.

        For every table the ``information_schema.columns`` ordinal positions
        are loaded so events can be serialised in column order.
        """
        self._tables = {}
        self._pool = ThreadedConnectionPool(1, 2,
                                            host=dw_conf.get("host"),
                                            port=dw_conf.get("port"),
                                            database=dw_conf.get("database"),
                                            user=dw_conf.get("user"),
                                            password=dw_conf.get("password"))
        self._redis = redis_conf
        self._logger = logger
        self._sep = sep

        # Borrow one connection for the metadata queries and give it back
        # afterwards; the pool only holds two connections.
        conn = self._pool.getconn()
        try:
            cur = conn.cursor(cursor_factory=RealDictCursor)
            try:
                for table_name in table_names:
                    worker_dir = "%s/%s" % (event_path, worker_id)
                    if not os.path.exists(worker_dir):
                        os.makedirs(worker_dir)
                    file_path = "%s/%s.udw" % (worker_dir, table_name)
                    self._tables[table_name] = {
                        "columns": {},
                        "path": file_path,
                        "file": open(file_path, 'a+'),
                        "data": []
                    }
                    # Bind the table name as a parameter instead of
                    # formatting it into the SQL text.
                    sql = "SELECT ordinal_position, column_name FROM information_schema.columns " \
                          "WHERE table_name = %s"
                    cur.execute(sql, (table_name,))
                    columns = cur.fetchall()
                    if len(columns) == 0:
                        # Unknown table: logged, entry kept with no columns.
                        self._logger.error(table_name)
                        continue
                    for column in columns:
                        self._tables[table_name]["columns"][
                            column["column_name"]] = column["ordinal_position"]
            finally:
                cur.close()
        finally:
            self._pool.putconn(conn)

    def json2copy(self, event):
        """Serialise ``event.value`` into one COPY line for table ``event.key``.

        Only ``int``, ``str``, ``list`` and ``dict`` values are emitted;
        lists and dicts are wrapped in single quotes. Fields are ordered by
        column ordinal position and joined with the separator. A gap in the
        resulting positions raises ``KeyError``.
        """
        hash_values = {}
        for k, v in event.value.items():
            kind = type(v)
            if kind is int:
                text = str(v)
            elif kind is str:
                text = v
            elif kind is list or kind is dict:
                text = "'%s'" % str(v)
            else:
                continue
            hash_values[self._tables[event.key]["columns"][k]] = text
        # Ordinal positions are 1-based and expected to be contiguous.
        values = [hash_values[i] for i in range(1, len(hash_values) + 1)]
        return self._sep.join(values)

    def writelines(self, events):
        """Serialise ``events`` and append them to their table spool files.

        Events that cannot be serialised are logged and skipped.
        """
        for event in events:
            try:
                row = self.json2copy(event)
            except Exception as e:
                self._logger.error(e)
                # .get(): an unknown table must not abort the whole batch
                # from inside the error handler.
                self._logger.error(self._tables.get(event.key))
            else:
                self._tables[event.key]["data"].append(row + '\n')
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].writelines(v["data"])
            del v["data"][:]

    def redis_sieve(self, db, values):
        """Pass the lines of file ``values`` through ``RedisConnector.multi_add``."""
        rc = RedisConnector(self._redis, db)
        values.seek(0)
        values = values.read().split('\n')[:-1]
        return rc.multi_add(values)

    def _dump_unresolved(self, table_name, table_file):
        """Copy the spool file content to the unresolved-data directory."""
        table_file.seek(0)
        dump_path = ('/data/log/unresolved_data/cleaned#' + table_name + '#' +
                     datetime.now().strftime("%Y%m%d%H%M%S"))
        with open(dump_path, 'w') as dump:
            dump.write(table_file.read())

    def copy_sink(self):
        """Load every table's spool file into the database, then reset the files.

        When ``redis_sieve`` returns a list, the spool file is rewritten
        with those lines before loading. If the load fails the file content
        is dumped under ``/data/log/unresolved_data``.
        """
        for k, v in self._tables.items():
            result = self.redis_sieve(k, v['file'])
            if isinstance(result, list):
                result = '\n'.join(result)
                self._logger.info('duplicate data')
                v['file'].close()
                os.remove(v['path'])
                v['file'] = open(v['path'], 'a+')
                v['file'].write(result)
            # No duplicates: load the file as it is
            v["file"].seek(0)
            # Initialised up front so the ``finally`` clause is safe even
            # when acquiring the connection or the cursor fails.
            conn = None
            cur = None
            try:
                conn = self._pool.getconn()
                cur = conn.cursor(cursor_factory=RealDictCursor)
                cur.copy_from(v["file"], k, sep=self._sep)
                conn.commit()
            # DataError: wrong data type or missing field
            except psycopg2.DataError as e:
                self._logger.error(e)
                self._dump_unresolved(k, v["file"])
            # Database responding slowly / reporting failure
            except psycopg2.DatabaseError as e:
                self._logger.error(e)
                self._dump_unresolved(k, v["file"])
                sleep(300)
            finally:
                if cur is not None:
                    cur.close()
                if conn is not None:
                    self._pool.putconn(conn)
        self.flush_copy()

    def flush_copy(self):
        """Replace every spool file with a fresh empty one."""
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].close()
            os.remove(v["path"])
            v["file"] = open(v["path"], 'a+')

    def close(self):
        """Close and delete all spool files and close the pool."""
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].close()
            os.remove(v["path"])
        if not self._pool.closed:
            self._pool.closeall()
else: cur.execute(sql, (ids,)) _original_hosts = [] cnx = cnxpool.getconn() with cnx.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(r"""SELECT original_host, display_name as original_host_display_name, region as original_host_region FROM findopendata.original_hosts WHERE enabled ORDER BY display_name""") _original_hosts = [row for row in cursor.fetchall()] cnxpool.putconn(cnx) @app.route('/api/original-hosts', methods=['GET']) def original_hosts(): return jsonify(_original_hosts) _data_formats = [] cnx = cnxpool.getconn() with cnx.cursor(cursor_factory=RealDictCursor) as cursor: cursor.execute(r"""SELECT DISTINCT format FROM findopendata.package_files WHERE format != '' """) format_results = cursor.fetchall() pre_filters = ["XLS", "XLSX", "XSL", "ZIP", "PDF", "ODS", "CSV", "TEXT", "JSON", "XML", "API", "HTML", "TXT", "DOCX", "PPTX", "TAR"]
class Database:
    """Bulk insert/update helpers on top of a psycopg2 connection pool.

    Values are bound through ``cursor.mogrify``. Table and column names are
    interpolated directly into the SQL text, so they must come from trusted
    code.
    """

    def __init__(self, db_config, table_raw=None, max_connections=10):
        """Open the pool; the process exits with status 1 if that fails.

        ``db_config`` must contain ``db_name``, ``db_user``, ``db_host`` and
        ``db_pass``.
        """
        from psycopg2.pool import ThreadedConnectionPool

        self.table_raw = table_raw
        try:
            self.pool = ThreadedConnectionPool(minconn=1,
                                               maxconn=max_connections,
                                               dsn="dbname={db_name} user={db_user} host={db_host} password={db_pass}"
                                                   .format(**db_config))
        except Exception:
            logger.exception("Error in db connection")
            sys.exit(1)

        logger.debug("Connected to database: {host}".format(host=db_config['db_host']))

    @contextmanager
    def getcursor(self):
        """Yield a cursor inside a transaction.

        Commits on success, rolls back and re-raises on ``Exception``, and
        always returns the connection to the pool.
        """
        conn = self.pool.getconn()
        try:
            yield conn.cursor()
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.pool.putconn(conn)

    def insert(self, table, data_list, id_col='id'):
        """
        TODO: rename `id_col` -> `return_col`

        Create a bulk insert statement which is much faster (~2x in tests with 10k & 100k rows and 4 cols)
        for inserting data then executemany()

        ``data_list`` is a dict or a list of dicts; the keys of the first
        dict define the inserted columns. Returns the first column of every
        ``RETURNING id_col`` row, or ``[]`` for empty input.

        TODO: Is there a limit of length the query can be? If so handle it.
        """
        # Make sure that `data_list` is a list
        if not isinstance(data_list, list):
            data_list = [data_list]

        if not data_list:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        try:
            with self.getcursor() as cur:
                query = "INSERT INTO {table} ({fields}) VALUES {values} RETURNING {id_col}"\
                    .format(table=table,
                            fields='"{0}"'.format('", "'.join(data_list[0].keys())),
                            values=','.join(['%s'] * len(data_list)),
                            id_col=id_col
                            )
                query = cur.mogrify(query, [tuple(v.values()) for v in data_list])
                cur.execute(query)

                return [t[0] for t in cur.fetchall()]

        except Exception:
            logger.exception("Error inserting data")
            logger.debug("Error inserting data: {data}".format(data=data_list))
            raise

    def update(self, table, data_list, matched_field=None, return_col='id'):
        """
        Create a bulk update: one UPDATE statement per row, sent to the
        server in a single round trip.

        Each row is matched on ``matched_field`` (default ``id``); rows
        lacking that field are logged and skipped. The caller's dicts are
        not modified. Returns the list of matched values of the rows that
        were sent, or ``[]`` when there was nothing to update.

        TODO: Is there a limit of length the query can be? If so handle it.
        """
        if matched_field is None:
            # Assume the id field
            logger.info("Matched field not defined, assuming the `id` field")
            matched_field = 'id'

        # Make sure that `data_list` is a list
        if not isinstance(data_list, list):
            data_list = [data_list]

        if not data_list:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        try:
            with self.getcursor() as cur:
                query_list = []
                # TODO: change to return data from the database, not just what you passed in
                return_list = []
                for row in data_list:
                    matched_value = row.get(matched_field)
                    if matched_value is None:
                        logger.debug("Cannot update row. Missing field {field} in data {data}"
                                     .format(field=matched_field, data=row))
                        logger.error("Cannot update row. Missing field {field} in data".format(field=matched_field))
                        continue

                    # Work on a filtered copy so the caller's row keeps its
                    # matched field.
                    changes = {key: value for key, value in row.items()
                               if key != matched_field}

                    query = "UPDATE {table} SET {data} WHERE {matched_field}=%s RETURNING {return_col}"\
                        .format(table=table,
                                data=','.join("%s=%%s" % u for u in changes.keys()),
                                matched_field=matched_field,
                                return_col=return_col
                                )
                    values = list(changes.values())
                    values.append(matched_value)

                    query_list.append(cur.mogrify(query, values))
                    return_list.append(matched_value)

                # Every row was skipped: executing an empty statement would
                # raise, so there is simply nothing to do.
                if not query_list:
                    return return_list

                finial_query = b';'.join(query_list)
                cur.execute(finial_query)

                return return_list

        except Exception:
            logger.exception("Error updating data")
            logger.debug("Error updating data: {data}".format(data=data_list))
            raise
class DbPool(object):
    """DB class that makes connection transparently. Thread-safe - every
    thread get its own database connection. Not meant to be used directly,
    there is no reason to have more than one instance - global variable Db
    - in this module."""

    def __init__(self, config):
        """Configures the Db and creates the connection pool.

        @param config: instance of config.NotaryServerConfig."""
        self.host = config.db_host
        self.port = config.db_port
        self.user = config.db_user
        self.password = config.db_password
        self.db_name = config.db_name
        self.min_connections = config.db_min_conn
        self.max_connections = config.db_max_conn

        self.pool = ThreadedConnectionPool(
            minconn = self.min_connections,
            maxconn = self.max_connections,
            host = self.host,
            port = self.port,
            user = self.user,
            password = self.password,
            database = self.db_name)

    def cursor(self, **kwargs):
        """Creates and returns cursor for current thread's connection.
        Cursor is a "dict" cursor, so you can access the columns by
        names (not just indices), e.g.:

        cursor.execute("SELECT id, name FROM ... WHERE ...", sql_args)
        row = cursor.fetchone()
        id = row['id']

        Server-side cursors (named cursors) should be closed explicitly.

        @param kwargs: currently string parameter 'name' is supported.
        Named cursors are for server-side cursors, which
        are useful when fetching result of a large query via fetchmany()
        method. See http://initd.org/psycopg/docs/usage.html#server-side-cursors
        """
        return self.connection().cursor(cursor_factory=DictCursor, **kwargs)

    def connection(self):
        """Return connection for this thread.

        The pool is keyed by ``id()`` of the current thread object, so
        repeated calls from one thread get the same connection."""
        return self.pool.getconn(id(threading.current_thread()))

    def commit(self):
        """Commit all the commands in this transaction in this thread's
        connection. If errors (e.g. duplicate key) arose, this will
        cause transaction rollback.
        """
        self.connection().commit()

    def rollback(self):
        """Rollback last transaction on this thread's connection"""
        self.connection().rollback()

    def putconn(self):
        """Put back connection used by this thread. Necessary upon finishing of
        spawned threads, otherwise new threads won't get connection if the pool
        is depleted."""
        conn = self.connection()
        self.pool.putconn(conn, id(threading.current_thread()))

    def close(self):
        """Close this thread's connection and release its pool slot.

        Closing through the pool (``close=True``) also unregisters the
        connection from the thread's key. Closing the connection object
        directly would leave it registered, and the next ``connection()``
        call from this thread would return the closed connection."""
        conn = self.connection()
        self.pool.putconn(conn, id(threading.current_thread()), close=True)
class GetEachMatchOdds:
    """Downloads the Euro and Asia odds pages of every match and stores the
    parsed odds in the ``euro_odds`` / ``asia_odds`` tables.

    Match rows are read in pages from the ``matchs`` table on the main
    connection and handed through a bounded queue to worker threads, each of
    which borrows a connection from ``conn_pool`` while it processes a page.
    """

    defaultEncoding = 'utf-8'

    # 23 columns: match_id, result, company + 20 odds figures.
    _EURO_INSERT_SQL = (
        'insert into euro_odds (match_id, result, company, priodds3, priodds1,'
        ' priodds0, nowodds3, nowodds1, nowodds0, prichance3, prichance1,'
        ' prichance0, nowchance3, nowchance1, nowchance0, priyrr, nowyrr,'
        ' prikelly3, prikelly1, prikelly0, nowkelly3, nowkelly1, nowkelly0 )'
        ' values ( %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,'
        ' %s, %s, %s, %s, %s, %s, %s, %s ) '
    )

    # 9 columns: match_id, result, company + 6 odds figures.
    _ASIA_INSERT_SQL = (
        'insert into asia_odds (match_id, result, company, priodds3,'
        ' priconcede, priodds0, nowodds3, nowconcede, nowodds0 )'
        ' values ( %s, %s, %s, %s, %s, %s, %s, %s, %s ) '
    )

    _REQUEST_HEADERS = {
        'User-Agent': 'Mozilla/4.0 (compatible; MSIE 5.5; Windows NT)',
        'Referer': 'https://developer.mozilla.org/en-US/docs/Web/JavaScript'
    }

    def __init__(self, loglevel, db_config_file, section_name):
        """Read the DB settings, build the worker pool and open the main
        connection.

        Args:
            loglevel: name passed to ``logging.getLogger``.
            db_config_file: path of the ini file holding the DB settings.
            section_name: section of that file to read host, port, user,
                password and database from.
        """
        self.pool_size = 2
        self.min_db_conn_size = 2
        logging.config.fileConfig('logger_test.conf')
        self.logger = logging.getLogger(loglevel)
        self.config = configparser.ConfigParser()
        self.config.read(db_config_file)
        self.host = self.config[section_name]['host']
        self.port = self.config[section_name]['port']
        self.user = self.config[section_name]['user']
        self.password = self.config[section_name]['password']
        self.database = self.config[section_name]['database']
        # NOTE(review): the password read above is not passed to the pool;
        # presumably the server does not ask for one - confirm, otherwise
        # add password=self.password here.
        self.conn_pool = ThreadedConnectionPool(self.min_db_conn_size,
                                                self.pool_size,
                                                host=self.host,
                                                port=self.port,
                                                user=self.user,
                                                database=self.database)
        self.main_db_conn = db_connecter.DBConnecter(db_config_file,
                                                     section_name)
        self.main_db_conn.connect()

    @staticmethod
    def _absolute_url(url):
        """Return *url* unchanged if it already has an http/https scheme,
        otherwise prefix it with ``http:`` (for protocol-relative URLs)."""
        if re.match(r'https?:', url):
            return url
        return 'http:' + url

    @staticmethod
    def _match_result(home_score, away_score):
        """Encode the outcome: 3 = home win, 1 = draw, 0 = away win."""
        if home_score > away_score:
            return 3
        if home_score == away_score:
            return 1
        return 0

    def _rollback_quietly(self, db_conn):
        """Roll back *db_conn*, logging (not raising) if that fails too."""
        try:
            db_conn.rollback()
        except Exception as err:
            self.logger.error("Rollback failed: {0}".format(err))

    def _store_match_odds(self, db_conn, curs, row, threadname):
        """Fetch, parse and insert the odds of one match row.

        Retries until the row has been committed. Every failed attempt is
        rolled back first, so a retry neither re-inserts rows of the
        previous attempt nor runs inside an aborted transaction.
        """
        while True:
            try:
                self.logger.info(
                    threadname + ":" +
                    str([row[0], row[1], row[2], row[3], row[4]]))
                match_id = row[0]
                result = self._match_result(row[1], row[2])
                asia_data_url = self._absolute_url(row[3])
                euro_data_url = self._absolute_url(row[4])

                euro_r = requests.get(euro_data_url,
                                      headers=self._REQUEST_HEADERS,
                                      timeout=5)
                euro_r.encoding = euro_r.apparent_encoding
                sleep(0.5)
                asia_r = requests.get(asia_data_url,
                                      headers=self._REQUEST_HEADERS,
                                      timeout=5)
                asia_r.encoding = asia_r.apparent_encoding

                # The presence of "pl_table_data" tables on the Euro page
                # selects the new-layout parsers for both pages.
                _euro_html = html.fromstring(euro_r.text)
                if _euro_html.xpath('//table[@class="pl_table_data"]'):
                    all_euro_odds = self.parse_euro_new_html(euro_r)
                    all_asia_odds = self.parse_asia_new_html(asia_r)
                else:
                    all_euro_odds = self.parse_euro_old_html(euro_r)
                    all_asia_odds = self.parse_asia_old_html(asia_r)

                for e_odds in all_euro_odds:
                    curs.execute(self._EURO_INSERT_SQL,
                                 tuple([match_id, result] + e_odds))
                for a_odds in all_asia_odds:
                    curs.execute(self._ASIA_INSERT_SQL,
                                 tuple([match_id, result] + a_odds))
                db_conn.commit()
                return
            except GetHtmlFailed as err:
                self._rollback_quietly(db_conn)
                self.logger.error(
                    "Get Html Failed: {0}. 2 seconds later retry to get...."
                    .format(err))
                sleep(2)
            except Exception as err:
                self._rollback_quietly(db_conn)
                self.logger.error(
                    "Get Html Failed: {0}. 5 seconds later retry to get...."
                    .format(err))
                sleep(5)

    def basic_get_match_info_task(self, queue):
        """Worker loop: take pages of match rows from *queue* forever and
        store the odds of every row.

        Args:
            queue: queue of row lists; each row is (id, home_team_score,
                away_team_score, asia_data_url, euro_data_url).
        """
        threadname = threading.current_thread().name
        while True:
            rows = queue.get()
            try:
                db_conn = self.conn_pool.getconn()
                try:
                    curs = db_conn.cursor()
                    try:
                        for row in rows:
                            self._store_match_odds(db_conn, curs, row,
                                                   threadname)
                    finally:
                        curs.close()
                finally:
                    # Always hand the connection back, even on an
                    # unexpected error, so the pool is not drained.
                    self.conn_pool.putconn(db_conn)
            finally:
                queue.task_done()

    def basic_work(self, group_size, start_offset, end_offset):
        """Page through the ``matchs`` table and feed the worker threads.

        Args:
            group_size: number of matches per page (SQL LIMIT).
            start_offset: offset of the first page.
            end_offset: last offset (inclusive) at which a page is read.
        """
        self.logger.info("Start get matches infomation...")
        start_time = datetime.datetime.now()
        queue = Queue.Queue(maxsize=self.pool_size)
        killer = GracefulKiller()
        for i in range(self.pool_size):
            t = threading.Thread(target=self.basic_get_match_info_task,
                                 name='Task-' + str(i),
                                 args=(queue, ))
            t.daemon = True
            t.start()

        # ORDER BY makes the LIMIT/OFFSET pages stable; without it the
        # server may return rows in any order, so pages could overlap or
        # skip matches.
        sql_str = ('select id, home_team_score, away_team_score, '
                   'asia_data_url, euro_data_url from matchs '
                   'order by id limit %s offset %s')
        current_offset = start_offset
        while current_offset <= end_offset:
            self.main_db_conn.select_record(sql_str,
                                            (group_size, current_offset))
            rows = self.main_db_conn.fetchall()
            queue.put(rows)
            if killer.kill_now:
                self.logger.info(
                    "The programing will stop after all task done, and the current offset is: %s",
                    current_offset + group_size)
                break
            current_offset += group_size

        if not killer.kill_now:
            self.logger.info("Get matches info completed!")
        self.main_db_conn.close()
        queue.join()
        end_time = datetime.datetime.now()
        last_time = end_time - start_time
        self.logger.info(
            "Get matches infomation finished, the program lasted for %s seconds",
            last_time.total_seconds())

    @staticmethod
    def _first_two_rows(tables):
        """For each table return [cell texts of row 0, cell texts of row 1]."""
        pairs = []
        for tb in tables:
            trs = tb.xpath('./tbody/tr')
            pairs.append([trs[0].xpath('./td/text()'),
                          trs[1].xpath('./td/text()')])
        return pairs

    def parse_euro_new_html(self, response):
        """Parse the new-layout Euro odds page.

        Returns:
            One 21-item list per company: encoded company name followed by
            initial/latest odds, chances, return rates and Kelly figures.

        Raises:
            GetHtmlFailed: the page has no "pl_table_data" table.
        """
        _html = html.fromstring(response.text)
        company_names = _html.xpath('//span[@class="quancheng"]/text()')
        tables = _html.xpath('//table[@class="pl_table_data"]')
        if not tables:
            raise GetHtmlFailed("Can't get Euro odds")
        if len(company_names) < 1:
            return []

        # From index 3 on, the tables repeat in groups of four per company:
        # odds, chance, return rate, Kelly.
        odds_arr = self._first_two_rows(tables[3::4])
        chance_arr = self._first_two_rows(tables[4::4])
        yrr_arr = self._first_two_rows(tables[5::4])
        kelly_arr = self._first_two_rows(tables[6::4])

        all_company_odds = []
        for c, odds, chance, yrr, kelly in zip(company_names, odds_arr,
                                               chance_arr, yrr_arr,
                                               kelly_arr):
            all_company_odds.append([
                c.encode(GetEachMatchOdds.defaultEncoding),
                odds[0][0], odds[0][1], odds[0][2],
                odds[1][0], odds[1][1], odds[1][2],
                chance[0][0], chance[0][1], chance[0][2],
                chance[1][0], chance[1][1], chance[1][2],
                yrr[0][0], yrr[1][0],
                kelly[0][0], kelly[0][1], kelly[0][2],
                kelly[1][0], kelly[1][1], kelly[1][2]])
        return all_company_odds

    def parse_euro_old_html(self, response):
        """Parse the old-layout Euro odds page (table id "datatb").

        Returns:
            One 21-item list per company, in the same column order as
            parse_euro_new_html.

        Raises:
            GetHtmlFailed: the page has no "datatb" table.
        """
        _html = html.fromstring(response.text)
        trs = _html.xpath('//table[@id="datatb"]/tr')
        all_company_odds = []
        if len(_html.xpath('//table[@id="datatb"]')) == 0:
            raise GetHtmlFailed("Can't get Euro odds")
        if len(trs) < 4:
            return []
        # Rows come in pairs from index 2 on; trs[2] is appended as padding
        # so the last even row still gets a partner when the count is odd.
        for elem, next_elem in zip(trs[2::2], trs[3::2] + [trs[2]]):
            elem_tds = elem.xpath('./td')
            next_elem_tds = next_elem.xpath('./td')
            if len(elem_tds) > 12:
                link_text = elem_tds[1].xpath('./a/text()')
                if len(link_text) > 0:
                    company = link_text[0].encode(
                        GetEachMatchOdds.defaultEncoding)
                else:
                    company = elem_tds[1].text.encode(
                        GetEachMatchOdds.defaultEncoding)
                all_company_odds.append([
                    company,
                    elem_tds[2].text, elem_tds[3].text, elem_tds[4].text,
                    next_elem_tds[0].text, next_elem_tds[1].text,
                    next_elem_tds[2].text,
                    elem_tds[5].text, elem_tds[6].text, elem_tds[7].text,
                    next_elem_tds[3].text, next_elem_tds[4].text,
                    next_elem_tds[5].text,
                    elem_tds[8].text, next_elem_tds[6].text,
                    elem_tds[9].xpath('./span/text()')[0],
                    elem_tds[10].xpath('./span/text()')[0],
                    elem_tds[11].text,
                    next_elem_tds[7].xpath('./span/text()')[0],
                    next_elem_tds[8].xpath('./span/text()')[0],
                    next_elem_tds[9].text])
        return all_company_odds

    def parse_asia_new_html(self, response):
        """Parse the new-layout Asia odds page.

        Returns:
            One 7-item list per company: encoded company name, the three
            initial figures and the three latest figures (arrow markers and
            padding removed). Companies with fewer than three figures on
            either side are skipped.

        Raises:
            GetHtmlFailed: the page has no "pl_table_data" table.
        """
        enc = GetEachMatchOdds.defaultEncoding
        _html = html.fromstring(response.text)
        company_names = _html.xpath('//span[@class="quancheng"]/text()')
        tables = _html.xpath('//table[@class="pl_table_data"]')
        if not tables:
            raise GetHtmlFailed("Can't get Asia odds")
        if len(company_names) < 1:
            return []

        now_odds_arr = []
        pri_odds_arr = []
        for tb1, tb2 in zip(tables[2::2], tables[3::2]):
            now_odds_arr.append(tb1.xpath('./tbody/tr/td/text()')[0:3])
            pri_odds_arr.append(tb2.xpath('./tbody/tr/td/text()')[0:3])

        all_company_odds = []
        for c, pri_odds, now_odds in zip(company_names, pri_odds_arr,
                                         now_odds_arr):
            if len(pri_odds) < 3 or len(now_odds) < 3:
                continue
            all_company_odds.append([
                c.encode(enc),
                pri_odds[0].encode(enc),
                pri_odds[1].replace(u' \xa0', '').encode(enc),
                pri_odds[2].encode(enc),
                now_odds[0].replace(u'\u2193', '').replace(u'\u2191', '')
                .encode(enc),
                now_odds[1].replace(u' \xa0', '').encode(enc),
                now_odds[2].replace(u'\u2193', '').replace(u'\u2191', '')
                .encode(enc)])
        return all_company_odds

    def parse_asia_old_html(self, response):
        """Parse the old-layout Asia odds page (table id "datatb").

        Rows with a missing or non-numeric odds cell, or without a company /
        concede text, are skipped.

        Returns:
            One 7-item list per company, in the same column order as
            parse_asia_new_html.

        Raises:
            GetHtmlFailed: the page has no "datatb" table.
        """
        enc = GetEachMatchOdds.defaultEncoding
        _html = html.fromstring(response.text)
        trs = _html.xpath('//table[@id="datatb"]//tr')
        if len(_html.xpath('//table[@id="datatb"]')) == 0:
            raise GetHtmlFailed("Can't get Asia odds")
        all_company_odds = []
        number_re = re.compile(r'\d\.\d')
        for tr in trs[1:-3]:
            tds = tr.xpath('./td')
            if len(tds) <= 8:
                continue
            pri_home_odds = tds[6].text
            pri_away_odds = tds[8].text
            new_home_cells = tds[2].xpath('./span/text()')
            new_away_cells = tds[4].xpath('./span/text()')
            if len(new_home_cells) < 1 or len(new_away_cells) < 1:
                continue
            new_home_odds = new_home_cells[0]
            new_away_odds = new_away_cells[0]
            figures = (pri_home_odds, pri_away_odds, new_home_odds,
                       new_away_odds)
            # "or ''" keeps an empty cell (text is None) from raising
            # TypeError in re.search; such a row is skipped instead.
            if not all(number_re.search(value or '') for value in figures):
                continue
            if not tds[1].text or not tds[3].text or not tds[7].text:
                continue
            all_company_odds.append([
                tds[1].text.replace(u'\ufffd', '').encode(enc),
                pri_home_odds,
                tds[7].text.replace(u' \xa0', '').encode(enc),
                pri_away_odds,
                new_home_odds,
                tds[3].text.replace(u' \xa0', '').encode(enc),
                new_away_odds])
        return all_company_odds
class ExampleSearch(object):
    """Class for performing the Example search.

    Example search is the neighborhood search that demonstrates
    how to construct and query a spatial database based on URL
    search string, extract geometries from the result, associate
    various styles with them and return the response back to the client

    Valid Inputs are:
    q=pacific heights
    neighborhood=pacific heights
    """

    def __init__(self):
        """Inits ExampleSearch.

        Initializes the logger "ge_search".
        Initializes templates for kml, json, placemark templates
        for the KML/JSONP output.
        Initializes parameters for establishing a connection to the database.
        """
        self.utils = utils.SearchUtils()
        constants = geconstants.Constants()

        configs = self.utils.GetConfigs(
            os.path.join(geconstants.SEARCH_CONFIGS_DIR, "ExampleSearch.conf"))

        style_template = self.utils.style_template
        self._jsonp_call = self.utils.jsonp_functioncall
        self._geom = """
            <name>%s</name>
            <styleUrl>%s</styleUrl>
            <Snippet>%s</Snippet>
            <description>%s</description>
            %s\
    """
        self._json_geom = """
          {
            "name": "%s",
            "Snippet": "%s",
            "description": "%s",
            %s
          }\
    """

        self._placemark_template = self.utils.placemark_template
        self._kml_template = self.utils.kml_template

        self._json_template = self.utils.json_template
        self._json_placemark_template = self.utils.json_placemark_template

        self._example_query_template = (Template(constants.example_query))

        self.logger = self.utils.logger

        self._user = configs.get("user")
        self._hostname = configs.get("host")
        self._port = configs.get("port")

        self._database = configs.get("databasename")
        if not self._database:
            self._database = constants.defaults.get("example.database")

        self._pool = ThreadedConnectionPool(
            int(configs.get("minimumconnectionpoolsize")),
            int(configs.get("maximumconnectionpoolsize")),
            database=self._database,
            user=self._user,
            host=self._hostname,
            port=int(self._port))

        self._style = style_template.substitute(
            balloonBgColor=configs.get("balloonstyle.bgcolor"),
            balloonTextColor=configs.get("balloonstyle.textcolor"),
            balloonText=configs.get("balloonstyle.text"),
            iconStyleScale=configs.get("iconstyle.scale"),
            iconStyleHref=configs.get("iconstyle.href"),
            lineStyleColor=configs.get("linestyle.color"),
            lineStyleWidth=configs.get("linestyle.width"),
            polyStyleColor=configs.get("polystyle.color"),
            polyStyleColorMode=configs.get("polystyle.colormode"),
            polyStyleFill=configs.get("polystyle.fill"),
            polyStyleOutline=configs.get("polystyle.outline"),
            listStyleHref=configs.get("iconstyle.href"))

    def RunPGSQLQuery(self, query, params):
        """Submits the query to the database and returns tuples.

        Note: variables placeholder must always be %s in query.
        Warning: NEVER use Python string concatenation (+) or string parameters
        interpolation (%) to pass variables to a SQL query string.
        e.g.
          SELECT vs_url FROM vs_table WHERE vs_name = 'default_ge';
          query = "SELECT vs_url FROM vs_table WHERE vs_name = %s"
          parameters = ["default_ge"]

        Args:
          query: SQL SELECT statement.
          params: sequence of parameters to populate into placeholders.
        Returns:
          tuple (query_status, query_results). query_status is True when the
          query ran; query_results is the list of rows, where a row with a
          single column is stored as the bare value instead of a tuple.
          A psycopg2.Error is logged and reported as (False, []).
        Raises:
          exceptions.PoolConnectionException: no connection could be obtained
          from the pool.
        """
        con = None
        cursor = None

        query_results = []
        query_status = False

        self.logger.debug(
            "Querying the database %s, at port %s, as user %s on "
            "hostname %s",
            self._database, self._port, self._user, self._hostname)
        try:
            con = self._pool.getconn()
            if con:
                cursor = con.cursor()
                cursor.execute(query, params)

                for row in cursor:
                    if len(row) == 1:
                        query_results.append(row[0])
                    else:
                        query_results.append(row)
                query_status = True
        except psycopg2.pool.PoolError as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
            raise exceptions.PoolConnectionException(
                "Pool Error - Unable to get a connection from the pool.")
        except psycopg2.Error as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
        finally:
            # Close the cursor, and return the connection even if closing
            # the cursor fails.
            try:
                if cursor is not None:
                    cursor.close()
            finally:
                if con:
                    self._pool.putconn(con)

        return query_status, query_results

    def RunExampleSearch(self, search_query, response_type):
        """Performs a query search on the 'san_francisco_neighborhoods' table.

        Args:
          search_query: the query to be searched, in smallcase.
          response_type: Response type can be KML or JSONP, depending on the client
        Returns:
          tuple containing
          total_example_results: Total number of rows returned from
           querying the database.
          example_results: Query results as a list
        """
        example_results = []

        params = ["%" + entry + "%" for entry in search_query.split(",")]

        accum_func = self.utils.GetAccumFunc(response_type)

        example_query = self._example_query_template.substitute(
            FUNC=accum_func)
        query_status, query_results = self.RunPGSQLQuery(example_query, params)

        total_example_results = len(query_results)

        if query_status:
            # Columns read from each row: 0 geometry, 1 area, 2 perimeter,
            # 3 snippet, 4 name, 5 geometry type.
            for row in query_results:
                name = row[4]
                description = ("The total area in decimal degrees of " +
                               name + " is: " + str(row[1]) +
                               "<![CDATA[<br/>]]>")
                description += ("The total perimeter in decimal degrees of " +
                                name + " is: " + str(row[2]))
                example_results.append({
                    "name": name,
                    "snippet": row[3],
                    "styleurl": "#placemark_label",
                    "description": description,
                    "geom": str(row[0]),
                    "geom_type": str(row[5]),
                })

        return total_example_results, example_results

    def ConstructKMLResponse(self, search_results, original_query):
        """Prepares KML response.

        KML response has the below format:
          <kml>
           <Folder>
           <name/>
           <StyleURL>
                 ---
           </StyleURL>
           <Point>
                  <coordinates/>
           </Point>
           </Folder>
          </kml>

        Args:
          search_results: Query results from the searchexample database
          original_query: Search query as entered by the user
        Returns:
          kml_response: KML formatted response
        """
        lookat_info = ""
        set_first_element_lookat = True

        # folder name should include the query parameter(q) if 'displayKeys'
        # is present in the URL otherwise not.
        if self.display_keys_string:
            folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                           % (original_query, str(len(search_results))))
        else:
            folder_name = ("Grouped results:<![CDATA[<br/>]]> (%s)"
                           % (str(len(search_results))))

        fly_to_first_element = str(self.fly_to_first_element).lower() == "true"

        placemarks = []
        for result in search_results:
            geom = self._geom % (result["name"], result["styleurl"],
                                 result["snippet"], result["description"],
                                 result["geom"])

            # Add <LookAt> for POINT geometric types only.
            # TODO: Check if <LookAt> can be added for
            # LINESTRING and POLYGON types.
            if result["geom_type"] != "POINT":
                set_first_element_lookat = False

            if fly_to_first_element and set_first_element_lookat:
                lookat_info = self.utils.GetLookAtInfo(result["geom"])
                set_first_element_lookat = False

            placemarks.append(self._placemark_template.substitute(geom=geom))

        kml_response = self._kml_template.substitute(
            foldername=folder_name,
            style=self._style,
            lookat=lookat_info,
            placemark="".join(placemarks))

        self.logger.info("KML response successfully formatted")

        return kml_response

    def ConstructJSONPResponse(self, search_results, original_query):
        """Prepares JSONP response.

          {
                   "Folder": {
                     "name": "Latitude X Longitude Y",
                     "Placemark": {
                        "Point": {
                          "coordinates": "X,Y" } }
                     }
           }
        Args:
          search_results: Query results from the searchexample table
          original_query: Search query as entered by the user
        Returns:
          jsonp_response: JSONP formatted response
        """
        folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)"
                       % (original_query, str(len(search_results))))

        # result["geom"][1:-1] drops the first and last character of the
        # geometry string before it is spliced into the placemark object.
        geoms = ",".join(
            self._json_geom % (result["name"], result["snippet"],
                               result["description"], result["geom"][1:-1])
            for result in search_results)

        # A single result is emitted as an object, several as an array.
        if len(search_results) == 1:
            search_geoms = geoms
        else:
            search_geoms = "[" + geoms + "]"

        search_placemarks = self._json_placemark_template.substitute(
            geom=search_geoms)

        json_response = self._json_template.substitute(
            foldername=folder_name, json_placemark=search_placemarks)

        # Escape single quotes from json_response.
        json_response = json_response.replace("'", "\\'")

        jsonp_response = self._jsonp_call % (self.f_callback, json_response)

        self.logger.info("JSONP response successfully formatted")

        return jsonp_response

    def HandleSearchRequest(self, environ):
        """Fetches the search tokens from form and performs the example search.

        Args:
         environ: A list of environment variables as supplied by the
          WSGI interface to the example search application interface.
        Returns:
         tuple (search_results, response_type), where search_results is a
         KML/JSONP formatted string which contains search results.
        """
        search_results = ""
        search_status = False

        # Fetch all the attributes provided by the user.
        parameters = self.utils.GetParameters(environ)
        response_type = self.utils.GetResponseType(environ)

        # Retrieve the function call back name for JSONP response.
        self.f_callback = self.utils.GetCallback(parameters)

        original_query = self.utils.GetValue(parameters, "q")

        # Fetch additional query parameters 'flyToFirstElement' and
        # 'displayKeys' from URL.
        self.fly_to_first_element = self.utils.GetValue(
            parameters, "flyToFirstElement")
        self.display_keys_string = self.utils.GetValue(parameters,
                                                       "displayKeys")

        if not original_query:
            # Extract 'neighborhood' parameter from URL
            try:
                form = cgi.FieldStorage(fp=environ["wsgi.input"],
                                        environ=environ)
                original_query = form.getvalue("neighborhood")
            except AttributeError as e:
                self.logger.debug("Error in neighborhood query %s" % e)

        if original_query:
            (search_status, search_results) = self.DoSearch(original_query,
                                                            response_type)
        else:
            self.logger.debug("Empty or incorrect search query received")

        if not search_status:
            folder_name = "No results were returned."
            search_results = self.utils.NoSearchResults(
                folder_name, self._style, response_type, self.f_callback)

        return (search_results, response_type)

    def DoSearch(self, original_query, response_type):
        """Performs the example search and returns the results.

        Args:
          original_query: A string containing the search query as
           entered by the user.
          response_type: Response type can be KML or JSONP, depending on the client.
        Returns:
          tuple containing
          search_status: Whether search could be performed.
          search_results: A KML/JSONP formatted string which contains search results.
        """
        search_status = False
        search_results = ""

        search_query = original_query.strip().lower()

        # Only the first two comma-separated terms are searched.
        query_terms = search_query.split(",")
        if len(query_terms) > 2:
            self.logger.warning("Extra search parameters ignored:%s"
                                % (",".join(query_terms[2:])))
            search_query = ",".join(query_terms[:2])
            original_query = ",".join(original_query.split(",")[:2])

        total_results, query_results = self.RunExampleSearch(
            search_query, response_type)

        self.logger.info("example search returned %s results" % total_results)

        if total_results > 0:
            if response_type == "KML":
                search_results = self.ConstructKMLResponse(
                    query_results, original_query)
                search_status = True
            elif response_type == "JSONP":
                search_results = self.ConstructJSONPResponse(
                    query_results, original_query)
                search_status = True
            else:
                # This condition may not occur,
                # as response_type is either KML or JSONP
                self.logger.debug("Invalid response type %s" % response_type)

        return search_status, search_results

    def __del__(self):
        """Closes the connection pool created in __init__.
        """
        # __init__ may have failed before the pool was created.
        pool = getattr(self, "_pool", None)
        if pool is not None:
            pool.closeall()
class Database:
    """PostgreSQL access layer: a thread-safe connection pool for the record
    helpers plus one dedicated connection used to (re)create the tables.

    WARNING: the record helpers build their SQL by interpolating the
    caller-supplied fragments (column lists, table names, constraints,
    values) straight into the statement. Never pass untrusted input to them.
    """

    # (table name, column definitions) of every table initialize() recreates.
    _TABLES = (
        ("Users", "user_id text, balance bigint"),
        ("Stock", "stock_id text, user_id text, amount bigint"),
        ("PendingTrans",
         "type text, user_id text, stock_id text, amount bigint, "
         "timestamp bigint"),
        ("Trigger",
         "type text, user_id text, stock_id text, amount bigint, "
         "trigger bigint"),
    )

    def __init__(self, host, port, dbname, dbuser, dbpass, minconn=1, maxconn=1):
        # Thread pool
        self.pool = ThreadedConnectionPool(
            minconn=minconn,
            maxconn=maxconn,
            host=host,
            database=dbname,
            user=dbuser,
            password=dbpass,
            port=port
        )

        # Base connection for initialization
        self.conn = psycopg2.connect(
            host=host,
            database=dbname,
            user=dbuser,
            password=dbpass,
            port=port
        )
        self.curs = self.conn.cursor()

    def initialize(self):
        """Initialize Database, Recreate Tables.

        Every table is dropped if it exists and created empty again, each in
        its own committed transaction on the base connection.
        """
        for name, columns in self._TABLES:
            self.curs.execute("DROP TABLE IF EXISTS %s" % name)
            self.curs.execute("CREATE TABLE %s ( %s)" % (name, columns))
            self.conn.commit()
        print("DB Initialized")

    # Return a Database Connection from the pool
    def get_connection(self):
        connection = self.pool.getconn()
        cursor = connection.cursor()
        return connection, cursor

    def close_connection(self, connection):
        """Hand *connection* back to the pool."""
        self.pool.putconn(connection)

    def _fetch_rows(self, command):
        """Run a SELECT on a pooled connection.

        Returns the fetched rows, or None after printing the error if the
        statement failed. The connection always goes back to the pool.
        """
        connection = self.pool.getconn()
        try:
            cursor = connection.cursor()
            cursor.execute(command)
            rows = cursor.fetchall()
            connection.commit()
            return rows
        except Exception as e:
            print('PG Select error - ' + str(e))
            return None
        finally:
            self.close_connection(connection)

    def _run_statement(self, command, describe_error):
        """Run and commit a data-changing statement on a pooled connection.

        On failure prints ``describe_error(exception)``. The connection
        always goes back to the pool.
        """
        connection = self.pool.getconn()
        try:
            cursor = connection.cursor()
            cursor.execute(command)
            connection.commit()
        except Exception as e:
            print(describe_error(e))
        finally:
            self.close_connection(connection)

    # call like: select_record("user_id,balance", "Users", "user_id='jim' AND balance=200")
    def select_record(self, values, table, constraints):
        """Return the single row matching *constraints*.

        Returns (None, None) when no row or more than one row matches, or
        when the query fails.
        """
        command = """SELECT %s FROM %s WHERE %s""" % (values, table, constraints)
        result = self._fetch_rows(command)
        if result is None:
            return (None, None)

        # Format to always return a tuple of the single record, with each value.
        if len(result) > 1:
            print('PG Select returned more than one value.')
            return (None, None)
        elif len(result) == 0:
            return (None, None)
        else:
            return result[0]

    def filter_records(self, values, table, constraints):
        """Return every row matching *constraints* as an array of tuples;
        empty when nothing matches or the query fails."""
        command = """SELECT %s FROM %s WHERE %s""" % (values, table, constraints)
        result = self._fetch_rows(command)
        if result is None:
            return []
        return result

    def insert_record(self, table, columns, values):
        """Insert one row; *columns* and *values* are SQL fragments."""
        command = """INSERT INTO %s (%s) VALUES (%s)""" % (table, columns, values)
        self._run_statement(
            command, lambda e: 'PG Insert error - ' + str(e))

    def update_record(self, table, values, constraints):
        """Update the rows matching *constraints*; *values* is the SET
        fragment."""
        command = """UPDATE %s SET %s WHERE %s""" % (table, values, constraints)
        self._run_statement(
            command,
            lambda e: 'PG Update error %s \n table=%s values=%s constraints=%s command=%s' % (
                str(e), table, values, constraints, command))

    def delete_record(self, table, constraints):
        """Delete the rows matching *constraints*."""
        command = """DELETE FROM %s WHERE %s""" % (table, constraints)
        self._run_statement(
            command, lambda e: 'PG Delete error - ' + str(e))
class DBUtils: __singleton = None __lock = Lock() QUOTES_UPSERT_SQL = f""" INSERT INTO quotes (NAME, CODE, OPENING_PRICE, LATEST_PRICE, QUOTE_CHANGE, CHANGE, VOLUME, TURNOVER, AMPLITUDE, TURNOVER_RATE, "PE_ratio", VOLUME_RATIO, MAX_PRICE, MIN_PRICE, CLOSING_PRICE, "PB_ratio", MARKET, TIME) VALUES ({','.join(['%s'] * 18)}) ON CONFLICT ON CONSTRAINT quotes_pkey DO UPDATE SET latest_price = excluded.latest_price, change = excluded.change, volume = excluded.volume, turnover = excluded.turnover, amplitude = excluded.amplitude, turnover_rate = excluded.turnover_rate, volume_ratio = excluded.volume_ratio, max_price = excluded.max_price, min_price = excluded.min_price """ COMPANIES_UPSERT_SQL = f""" INSERT INTO companies (code, name, intro, manage, ssrq, clrq, fxl, fxfy, mgfxj, fxzsz, srkpj, srspj, srhsl, srzgj, djzql, wxpszql, mjzjje, zczb) VALUES ({','.join(['%s'] * 18)}) ON CONFLICT ON CONSTRAINT companies_pkey DO UPDATE SET fxl = excluded.fxl, fxfy = excluded.fxfy, mgfxj = excluded.mgfxj, fxzsz = excluded.fxzsz, djzql = excluded.djzql, wxpszql = excluded.wxpszql, zczb = excluded.zczb """ CODES_QUERY_SQL = """ SELECT DISTINCT code FROM quotes """ MANAGE_QUERY_SQL = """ SELECT manage FROM companies order by code """ POS_VEC_UPDATE_SQL = """ UPDATE companies SET pos_vec = %s WHERE code = %s """ CODES_MANAGE_QUERY_SQL = """ select code, manage from companies """ POS_VEC_QUERY_SQL = """ SELECT pos_vec FROM companies WHERE code = %s """ NEED_UPDADE_CODE_SQL = """ SELECT DISTINCT code FROM quotes EXCEPT SELECT code FROM companies """ MAIN_TARGET_INSERT_SQL = f""" INSERT INTO main_target (code, date, jbmgsy, kfmgsy, xsmgsy, mgjzc, mggjj, mgwfply, mgjyxjl, yyzsr, mlr, gsjlr, kfjlr, yyzsrtbzz, gsjlrtbzz, kfjlrtbzz, yyzsrgdhbzz, gsjlrgdhbzz, kfjlrgdhbzz, jqjzcsyl, tbjzcsyl, tbzzcsyl, mll, jll, sjsl, yskyysr, jyxjlyysr, xsxjlyysr, zzczzl, yszkzzts, chzzts, zcfzl, ldzczfz, ldbl, sdbl) VALUES ({','.join(['%s'] * 35)}) ON CONFLICT DO NOTHING """ def __init__(self, database_config): 
self.database_config = database_config pool_config = self.database_config['pool'] conn_config = self.database_config['conn'] self.conn_pool = ThreadedConnectionPool(minconn=pool_config['min'], maxconn=pool_config['max'], **conn_config) def closeall(self): self.conn_pool.closeall() def upsert_quotes(self, quotes): """更新行情""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: try: cur.executemany(self.QUOTES_UPSERT_SQL, quotes) conn.commit() except errors.UniqueViolation as e: print("当前记录已存在") except errors.NotNullViolation as e: print("违反非空约束") finally: self.conn_pool.putconn(conn) def upsert_company(self, company): """添加公司信息""" conn = self.conn_pool.getconn() try: value = [ company['code'], company['name'], company['intro'], company['manage'], company['ssrq'], company['clrq'], company['fxl'], company['fxfy'], company['mgfxj'], company['fxzsz'], company['srkpj'], company['srspj'], company['srhsl'], company['srzgj'], company['djzql'], company['wxpszql'], company['mjzjje'], company['zczb'] ] with conn.cursor() as cur: cur.execute(self.COMPANIES_UPSERT_SQL, value) conn.commit() except errors.UniqueViolation as e: print("当前记录已存在") except errors.NotNullViolation as e: print("违反非空约束") finally: self.conn_pool.putconn(conn) def get_all_codes(self): """获得所有公司代码""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: cur.execute(self.CODES_QUERY_SQL) codes = cur.fetchall() return codes except Exception as e: print(e) finally: self.conn_pool.putconn(conn) def get_all_manage(self): """获得所有公司经营范围""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: cur.execute(self.MANAGE_QUERY_SQL) manage = cur.fetchall() return manage except Exception as e: print(e) finally: self.conn_pool.putconn(conn) def get_all_codes_manage(self): """获取公司所有代码和经营范围""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: cur.execute(self.CODES_MANAGE_QUERY_SQL) result = cur.fetchall() return result except Exception as e: print(e) finally: 
self.conn_pool.putconn(conn) def update_company_pos_vec(self, code, pos_vec): """设置公司在行业语义空间中的位置""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: cur.execute(self.POS_VEC_UPDATE_SQL, (pos_vec, code)) conn.commit() except Exception as e: print(e) finally: self.conn_pool.putconn(conn) def get_pos_vec(self, code): """获得语义空间位置向量""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: cur.execute(self.POS_VEC_QUERY_SQL, (code, )) pos_vec = cur.fetchone() return pos_vec finally: self.conn_pool.putconn(conn) def get_need_update_codes(self): """获得所有需要更新的公司代码""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: cur.execute(self.NEED_UPDADE_CODE_SQL) result = cur.fetchall() return result finally: self.conn_pool.putconn(conn) def insert_main_target(self, targets): """保存主要指标""" conn = self.conn_pool.getconn() try: with conn.cursor() as cur: cur.executemany(self.MAIN_TARGET_INSERT_SQL, targets) conn.commit() except errors.UniqueViolation as e: print("当前记录已存在") except errors.NotNullViolation as e: print("违反非空约束") finally: self.conn_pool.putconn(conn) @staticmethod def init(database_config): DBUtils.__lock.acquire() try: if DBUtils.__singleton is None: DBUtils.__singleton = DBUtils(database_config) except Exception: traceback.print_exc() finally: DBUtils.__lock.release() return DBUtils.__singleton
class Database(rigor.database.Database): """ Container for a database connection pool """ def __init__(self, database): super(Database, self).__init__(database) register_type(psycopg2.extensions.UNICODE) register_uuid() dsn = Database.build_dsn(database) self._pool = ThreadedConnectionPool(config.get('database', 'min_database_connections'), config.get('database', 'max_database_connections'), dsn) @staticmethod def build_dsn(database): """ Builds the database connection string from config values """ dsn = "dbname='{0}' host='{1}'".format(database, config.get('database', 'host')) try: ssl = config.getboolean('database', 'ssl') if ssl: dsn += " sslmode='require'" except ConfigParser.Error: pass try: username = config.get('database', 'username') dsn += " user='******'".format(username) except ConfigParser.Error: pass try: password = config.get('database', 'password') dsn += " password='******'".format(password) except ConfigParser.Error: pass return dsn @staticmethod @template def create(name): """ Creates a new database with the given name """ return "CREATE DATABASE {0};".format(name) @staticmethod @template def drop(name): """ Drops the database with the given name """ return "DROP DATABASE {0};".format(name) @staticmethod @template def clone(source, destination): """ Copies the source database to a new destination database. This may fail if the source database is in active use. 
""" return "CREATE DATABASE {0} WITH TEMPLATE {1};".format(destination, source) @contextmanager def get_cursor(self, commit=True): """ Gets a cursor from a connection in the pool """ connection = self._pool.getconn() cursor = connection.cursor(cursor_factory=RigorCursor) try: yield cursor except psycopg2.IntegrityError as error: exc_info = sys.exc_info() self.rollback(cursor) raise rigor.database.IntegrityError, exc_info[1], exc_info[2] except psycopg2.DatabaseError as error: exc_info = sys.exc_info() self.rollback(cursor) raise rigor.database.DatabaseError, exc_info[1], exc_info[2] except: exc_info = sys.exc_info() self.rollback(cursor) raise exc_info[0], exc_info[1], exc_info[2] else: if commit: self.commit(cursor) else: self.rollback(cursor) def _close_cursor(self, cursor): """ Closes a cursor and releases the connection to the pool """ cursor.close() self._pool.putconn(cursor.connection) def commit(self, cursor): """ Commits the transaction, then closes the cursor """ cursor.connection.commit() self._close_cursor(cursor) def rollback(self, cursor): """ Rolls back the transaction, then closes the cursor """ cursor.connection.rollback() self._close_cursor(cursor) def __del__(self): self._pool.closeall()
class PostgresqlUtil(DBInterface):
    """CRUD helper backed by a psycopg2 ThreadedConnectionPool.

    SQL text is produced by generate_sql_string().
    NOTE(review): that helper returns a finished SQL string which is executed
    without bound parameters -- confirm it escapes values properly.
    """

    def __init__(self):
        super(PostgresqlUtil, self).__init__()
        self.conn_pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=32,
            database=config.POSTGRESQL_DATABASE,
            user=config.POSTGRESQL_USER_NAME,
            password=config.POSTGRESQL_PASSWORD,
            host=config.POSTGRESQL_HOST,
            port=config.POSTGRESQL_PORT)

    def get_conn(self):
        """Borrow a connection from the pool."""
        return self.conn_pool.getconn()

    def put_conn(self, conn):
        """Return a borrowed connection to the pool."""
        self.conn_pool.putconn(conn)

    def close(self):
        """Close every connection held by the pool."""
        self.conn_pool.closeall()

    def _run(self, sql, fetch=False):
        """Execute one statement in its own transaction.

        Commits on success, rolls back and re-raises on failure, and always
        returns the connection to the pool -- including when cursor creation,
        commit or rollback themselves fail.

        Returns the fetched rows when ``fetch`` is true, otherwise None.
        """
        conn = self.get_conn()
        try:
            rows = None
            try:
                cur = conn.cursor()
                try:
                    cur.execute(sql)
                    if fetch:
                        rows = cur.fetchall()
                finally:
                    cur.close()
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            return rows
        finally:
            self.put_conn(conn)

    # insert
    def insert(self, table, data, **kwargs):
        """Insert ``data`` into ``table``."""
        self._run(generate_sql_string("INSERT", table, data))

    # select
    def select(self, table, condition, **kwargs):
        """Return all rows of ``table`` matching ``condition``."""
        select_sql = generate_sql_string("SELECT", table, dict_condition=condition)
        return self._run(select_sql, fetch=True)

    # delete
    def delete(self, table, condition, **kwargs):
        """Delete the rows of ``table`` matching ``condition``."""
        self._run(generate_sql_string("DELETE", table, dict_condition=condition))

    # update
    def update(self, table, condition, data, upsert=False, **kwargs):
        """Update rows matching ``condition`` with ``data``.

        When nothing matches: inserts ``data`` if ``upsert`` is true,
        otherwise raises ValueError.
        NOTE: the existence check and the write run in separate transactions,
        so this is not atomic.
        """
        item = self.select(table, condition)
        if len(item) == 0:
            if upsert is True:
                self.insert(table, data)
                return
            if upsert is False:
                raise ValueError(
                    "no row in %r matches the update condition" % (table,))
            # any other upsert value: the original branches did nothing
            return
        self._run(generate_sql_string("UPDATE", table, data, condition))
class Scraper(object):
    ''' Parent class of all scrapers '''

    def __init__(self):
        self.NOW = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.connect_db()
        proxies = self.load_proxies()
        headers = self.load_user_headers()
        # Endless round-robin iterators over the loaded values
        self.proxy_pool = cycle(proxies)
        self.header_pool = cycle(headers)

    # + - - - PROXIES & HEADERS - - - +
    def load_proxies(self):
        ''' Load proxies from csv file and return a set of "ip:port" strings.

        Any error while reading is printed and whatever was collected so far
        (possibly an empty set) is returned.
        '''
        proxies = set()
        try:
            df = pd.read_csv("./src/utils/proxy_files/proxies.csv")
            for _, row in df.iterrows():
                # The port column may be parsed as float (8080.0) or as int
                # (8080); normalise through int() rather than chopping the
                # last two characters, which corrupts integer ports.
                port = str(int(float(row['Port'])))
                proxies.add(':'.join([row['IP Address'], port]))
        except Exception as e:
            print(e)
        return proxies

    def load_user_headers(self):
        ''' Load user agents from csv file and return them as a set '''
        df = pd.read_csv("./src/utils/proxy_files/user_agents.csv")
        return {row['User agent'] for _, row in df.iterrows()}

    def get_proxies(self):
        ''' Get the next proxy in proxy pool, keyed for both schemes '''
        proxy = next(self.proxy_pool)
        return {"http": proxy, "https": proxy}

    def get_headers(self):
        ''' Get the next header in header pool '''
        headers = next(self.header_pool)
        return {"User-Agent": headers}

    def reset_proxy_pool(self):
        ''' Download new proxies, save to csv and load csv '''
        download_free_proxies()
        proxies = self.load_proxies()
        self.proxy_pool = cycle(proxies)

    # + - - - DATABASE - - - +
    def connect_db(self):
        ''' Connect to db on cloud '''
        # NOTE(review): host and credentials are hard-coded here; consider
        # moving them to configuration.
        host_name = "bitko.czflnkl5jbgy.ap-southeast-2.rds.amazonaws.com"
        self._connpool = ThreadedConnectionPool(1, 2,
                                                database="bitko",
                                                user="******",
                                                password="******",
                                                host=host_name,
                                                port="5432")

    @contextmanager
    def cursor(self):
        ''' Yield a cursor on an autocommit connection from the conn pool.

        The cursor is closed and the connection returned to the pool when the
        with-block exits, whether or not it raised.
        '''
        conn = self._connpool.getconn()
        try:
            conn.autocommit = True
            cur = conn.cursor()
            try:
                yield cur
            finally:
                cur.close()
        finally:
            # Return the connection back to connection pool
            self._connpool.putconn(conn)

    def query_list(self, sql, values):
        ''' Execute sql and return all rows as a list '''
        with self.cursor() as cur:
            cur.execute(sql, values)
            return list(cur.fetchall())

    def query_one(self, sql, values):
        ''' Execute sql and return the first row (or None) '''
        with self.cursor() as cur:
            cur.execute(sql, values)
            return cur.fetchone()

    def execute(self, sql, values):
        ''' Execute sql command '''
        with self.cursor() as cur:
            cur.execute(sql, values)
class PostGISCatalog(Catalog):
    """Catalog that looks up imagery sources in a PostGIS footprints table."""

    def __init__(self, table="footprints", database_url=None,
                 geometry_column="geom"):
        """Open a connection pool for the catalog.

        table: name of the footprints table, interpolated into the SQL text
            (must come from trusted configuration).
        database_url: connection URL; defaults to the DATABASE_URL environment
            variable, read at construction time rather than at import time.
        geometry_column: stored on the instance; the queries below do not
            reference it.
        """
        if database_url is None:
            database_url = os.getenv("DATABASE_URL")
        if database_url is None:
            raise Exception("Database URL must be provided.")

        # Register the schemes once instead of growing the list per instance
        for scheme in ('postgis', 'postgres'):
            if scheme not in urlparse.uses_netloc:
                urlparse.uses_netloc.append(scheme)
        url = urlparse.urlparse(database_url)

        self._pool = ThreadedConnectionPool(
            1,
            16,
            database=url.path[1:],
            user=url.username,
            password=url.password,
            host=url.hostname,
            port=url.port)
        self._log = logging.getLogger(__name__)
        self.table = table
        self.geometry_column = geometry_column

    def _bbox_params(self, bounds):
        """Return the WGS84 bounding box of ``bounds`` as query parameters."""
        if bounds.crs == WGS84_CRS:
            left, bottom, right, top = bounds.bounds
        else:
            left, bottom, right, top = warp.transform_bounds(
                bounds.crs, WGS84_CRS, *bounds.bounds)

        # Infinite edges are replaced by the corresponding world limit
        return {
            "minx": left if left != Infinity else -180,
            "miny": bottom if bottom != Infinity else -90,
            "maxx": right if right != Infinity else 180,
            "maxy": top if top != Infinity else 90,
        }

    def _stream(self, query, params, to_source):
        """Run ``query`` and yield ``to_source(record)`` for every row.

        Errors are logged and end the iteration early (best effort); the
        connection is always returned to the pool.
        """
        connection = self._pool.getconn()
        try:
            with connection as conn, conn.cursor() as cur:
                cur.execute(query, params)
                for record in cur:
                    yield to_source(record)
        except Exception as e:
            self._log.error(e)
        finally:
            self._pool.putconn(connection)

    def _candidates(self, bounds, resolution, min_zoom, max_zoom,
                    include_geometries=False):
        """Yield every enabled source intersecting ``bounds`` whose zoom range
        overlaps [min_zoom, max_zoom]."""
        self._log.info("Resolution: %s; zoom range: %d-%d", resolution,
                       min_zoom, max_zoom)

        # TODO get sources in native CRS of the target
        query = """
            WITH bbox AS (
              SELECT ST_SetSRID(
                'BOX(%(minx)s %(miny)s, %(maxx)s %(maxy)s)'::box2d,
                4326) geom
            ),
            sources AS (
              SELECT
                url,
                source,
                resolution,
                coalesce(bands, '{{}}'::jsonb) bands,
                coalesce(meta, '{{}}'::jsonb) meta,
                coalesce(recipes, '{{}}'::jsonb) recipes,
                acquired_at,
                priority,
                ST_Multi(footprints.geom) geom,
                filename,
                min_zoom,
                max_zoom
              FROM {table} footprints
              JOIN bbox ON footprints.geom && bbox.geom
              WHERE numrange(min_zoom, max_zoom, '[]') && numrange(%(min_zoom)s, %(max_zoom)s, '[]')
                AND footprints.enabled = true
            )
            SELECT
              url,
              source,
              resolution,
              bands,
              meta,
              recipes,
              acquired_at,
              null band,
              priority,
              null coverage,
              CASE WHEN {include_geometries}
                  THEN ST_AsGeoJSON(geom)
                  ELSE 'null'
              END geom,
              filename,
              min_zoom,
              max_zoom
            FROM sources
        """.format(table=self.table,
                   include_geometries=bool(include_geometries))

        params = self._bbox_params(bounds)
        params.update({
            "min_zoom": min_zoom,
            "max_zoom": max_zoom,
            "resolution": min(resolution),
        })

        def to_source(record):
            # The last four columns are geom, filename, min_zoom, max_zoom
            return Source(*record[:-4],
                          geom=json.loads(record[-4]),
                          filename=record[-3],
                          min_zoom=record[-2],
                          max_zoom=record[-1])

        for source in self._stream(query, params, to_source):
            yield source

    def _fill_bounds(self, bounds, resolution, include_geometries=False):
        """Yield the sources chosen by a recursive query that repeatedly picks
        the best-ranked footprint for the still-uncovered part of ``bounds``."""
        zoom = get_zoom(max(resolution))

        # NOTE: the "--" SQL comments below rely on the line breaks of this
        # literal; joining the query onto one line would comment out the rest
        # of the statement.
        query = """
            WITH RECURSIVE bbox AS (
              SELECT ST_SetSRID(
                'BOX(%(minx)s %(miny)s, %(maxx)s %(maxy)s)'::box2d,
                4326) geom
            ),
            date_range AS (
              SELECT
                COALESCE(min(acquired_at), '1970-01-01') min,
                COALESCE(max(acquired_at), '1970-01-01') max,
                age(COALESCE(max(acquired_at), '1970-01-01'),
                    COALESCE(min(acquired_at), '1970-01-01')) "interval"
              FROM {table}
            ),
            sources AS (
              SELECT * FROM (
                SELECT
                  1 iterations,
                  ARRAY[url] urls,
                  ARRAY[source] sources,
                  ARRAY[resolution] resolutions,
                  ARRAY[coalesce(bands, '{{}}'::jsonb)] bands,
                  ARRAY[coalesce(meta, '{{}}'::jsonb)] metas,
                  ARRAY[coalesce(recipes, '{{}}'::jsonb)] recipes,
                  ARRAY[acquired_at] acquisition_dates,
                  ARRAY[priority] priorities,
                  ARRAY[ST_Area(ST_Intersection(bbox.geom, footprints.geom)) /
                    ST_Area(bbox.geom)] coverages,
                  ARRAY[ST_Multi(footprints.geom)] geometries,
                  ST_Multi(footprints.geom) geom,
                  ST_Difference(bbox.geom, footprints.geom) uncovered
                FROM date_range, {table} footprints
                JOIN bbox ON footprints.geom && bbox.geom
                WHERE %(zoom)s BETWEEN min_zoom AND max_zoom
                  AND footprints.enabled = true
                ORDER BY
                  10 * coalesce(footprints.priority, 0.5) *
                    .1 * (1 - (extract(
                      EPOCH FROM (current_timestamp - COALESCE(
                        acquired_at, '2000-01-01'))) /
                        extract(
                          EPOCH FROM (current_timestamp - date_range.min)))) *
                    50 *
                      -- de-prioritize over-zoomed sources
                      CASE WHEN %(resolution)s / footprints.resolution >= 1
                        THEN 1
                        ELSE 1 / footprints.resolution
                      END *
                    ST_Area(
                        ST_Intersection(bbox.geom, footprints.geom)) /
                      ST_Area(bbox.geom) DESC
                LIMIT 1
              ) AS _
              UNION ALL
              SELECT * FROM (
                SELECT
                  sources.iterations + 1,
                  sources.urls || url urls,
                  sources.sources || source sources,
                  sources.resolutions || resolution resolutions,
                  sources.bands || coalesce(
                    footprints.bands, '{{}}'::jsonb) bands,
                  sources.metas || coalesce(meta, '{{}}'::jsonb) metas,
                  sources.recipes || coalesce(
                    footprints.recipes, '{{}}'::jsonb) recipes,
                  sources.acquisition_dates || footprints.acquired_at
                    acquisition_dates,
                  sources.priorities || footprints.priority priorities,
                  sources.coverages || ST_Area(
                    ST_Intersection(sources.uncovered, footprints.geom)) /
                    ST_Area(bbox.geom) coverages,
                  sources.geometries || footprints.geom,
                  ST_Collect(sources.geom, footprints.geom) geom,
                  ST_Difference(sources.uncovered, footprints.geom) uncovered
                FROM bbox, date_range, {table} footprints
                -- use proper intersection to prevent voids from irregular
                -- footprints
                JOIN sources ON ST_Intersects(
                  footprints.geom, sources.uncovered)
                WHERE NOT (footprints.url = ANY(sources.urls))
                  AND %(zoom)s BETWEEN min_zoom AND max_zoom
                  AND footprints.enabled = true
                ORDER BY
                  10 * coalesce(footprints.priority, 0.5) *
                    .1 * (1 - (extract(
                      EPOCH FROM (current_timestamp - COALESCE(
                        acquired_at, '2000-01-01'))) /
                        extract(
                          EPOCH FROM (current_timestamp - date_range.min)))) *
                    50 *
                      -- de-prioritize over-zoomed sources
                      CASE WHEN %(resolution)s / footprints.resolution >= 1
                        THEN 1
                        ELSE 1 / footprints.resolution
                      END *
                    ST_Area(
                        ST_Intersection(sources.uncovered, footprints.geom)) /
                      ST_Area(bbox.geom) DESC
                LIMIT 1
              ) AS _
            ),
            candidates AS (
                SELECT *
                FROM sources
                ORDER BY iterations DESC
                LIMIT 1
            ),
            candidate_rows AS (
                SELECT
                    unnest(urls) url,
                    unnest(sources) source,
                    unnest(resolutions) resolution,
                    unnest(bands) bands,
                    unnest(metas) meta,
                    unnest(recipes) recipes,
                    unnest(acquisition_dates) acquired_at,
                    unnest(priorities) priority,
                    unnest(coverages) coverage,
                    unnest(geometries) geom
                FROM candidates
            )
            SELECT
              url,
              source,
              resolution,
              bands,
              meta,
              recipes,
              acquired_at,
              null band,
              priority,
              coverage,
              CASE WHEN {include_geometries}
                  THEN ST_AsGeoJSON(geom)
                  ELSE 'null'
              END geom
            FROM candidate_rows
        """.format(table=self.table,
                   include_geometries=bool(include_geometries))

        params = self._bbox_params(bounds)
        params.update({
            "zoom": zoom,
            "resolution": min(resolution),
        })

        def to_source(record):
            # The last column is the GeoJSON geometry (or the string 'null')
            return Source(*record[:-1], geom=json.loads(record[-1]))

        for source in self._stream(query, params, to_source):
            yield source

    def get_sources(self, bounds, resolution, min_zoom=None, max_zoom=None,
                    include_geometries=False):
        """Return an iterator of sources for ``bounds``.

        Without an explicit zoom range the coverage-filling query is used;
        otherwise every candidate within the zoom range is returned.
        """
        if min_zoom is None or max_zoom is None:
            return self._fill_bounds(bounds, resolution,
                                     include_geometries=include_geometries)

        return self._candidates(bounds, resolution, min_zoom, max_zoom,
                                include_geometries=include_geometries)
class PgPool:
    """Thin wrapper around a psycopg2 ThreadedConnectionPool.

    Every method logs failures and signals them through its return value
    (False / (None, None)) instead of raising.
    """

    def __init__(self):
        logger.debug('initializing postgres threaded pool')
        self.host, self.port = None, None
        self.database, self.pool = None, None
        self.user, self.passwd = None, None
        # The password is deliberately kept out of the log output
        logger.debug('Server Addr: {host}:{port}; Database: {db}; User: {user}'.format(
            host=self.host, port=self.port, db=self.database, user=self.user
        ))

    def create_pool(self, conn_dict, limits):
        """ Create a connection pool

        :param conn_dict: connection params dictionary with the keys
            "Host", "Port", "Database", "User" and "Password"; a None host
            or port falls back to 'localhost' / '5432'
        :type conn_dict: dict
        :param limits: pool size limits with the keys "Min" and "Max"
        :type limits: dict
        """
        self.host = 'localhost' if conn_dict["Host"] is None else conn_dict["Host"]
        self.port = '5432' if conn_dict["Port"] is None else conn_dict["Port"]
        self.database = conn_dict["Database"]
        self.user = conn_dict["User"]
        self.passwd = conn_dict["Password"]

        conn_params = "host='{host}' dbname='{db}' user='{user}' password='{passwd}' port='{port}'".format(
            host=self.host, db=self.database, user=self.user, passwd=self.passwd, port=self.port
        )

        try:
            logger.debug('creating pool')
            self.pool = ThreadedConnectionPool(int(limits["Min"]), int(limits["Max"]), conn_params)
        except Exception as e:
            # Exceptions have no .message attribute on Python 3
            logger.exception('could not create connection pool: %s', e)

    def get_conn(self):
        """ Get a connection from pool and return connection and cursor

        :return: conn, cursor -- or None, None on failure
        """
        logger.debug('getting connection from pool')
        conn = None
        try:
            conn = self.pool.getconn()
            cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            return conn, cursor
        except Exception as e:
            logger.exception('could not get connection: %s', e)
            if conn is not None:
                # The caller never sees this connection, so give it back
                self.put_conn(conn)
            return None, None

    @staticmethod
    def execute_query(cursor, query, params):
        """ Execute a query on database

        :param cursor: cursor object
        :param query: database query
        :type query: str
        :param params: query parameters
        :type params: tuple
        :return: all rows for a SELECT, the result of cursor.execute for any
            other statement, or False on error
        """
        logger.info('executing query')
        logger.debug('Cursor: {cursor}, Query: {query}'.format(
            cursor=cursor, query=query))

        try:
            if query.split()[0].lower() == 'select':
                cursor.execute(query, params)
                return cursor.fetchall()
            return cursor.execute(query, params)
        except Exception as e:
            logger.exception('query failed: %s', e)
            return False

    # commit changes to db permanently
    @staticmethod
    def commit_changes(conn):
        """ Commit changes to the database permanently

        :param conn: connection object
        :return: result of conn.commit(), or False on error
        """
        logger.debug('commiting changes to database')
        try:
            return conn.commit()
        except Exception as e:
            logger.exception('commit failed: %s', e)
            return False

    def put_conn(self, conn):
        """ Put connection back to the pool

        :param conn: connection object
        :return: result of pool.putconn(), or False on error
        """
        logger.debug('putting connection {conn} back to pool'.format(conn=conn))
        try:
            return self.pool.putconn(conn)
        except Exception as e:
            logger.exception('could not return connection to pool: %s', e)
            return False

    def close_pool(self):
        """ Closes connection pool

        :return: result of pool.closeall(), or False on error
        """
        logger.debug('closing connections pool')
        try:
            return self.pool.closeall()
        except Exception as e:
            logger.exception('could not close pool: %s', e)
            return False
class PgUtil():
    """Small SQL helper on top of a psycopg2 ThreadedConnectionPool.

    Each public method runs one statement in its own transaction and gives
    the connection back to the pool, also when the statement fails.
    """
    # __metaclass__ = Singleton

    def __init__(self, minconn=4, maxconn=100, database='coda', user='******',
                 password='******', host='127.0.0.1', port=5432):
        # Defaults reproduce the previously hard-coded settings.
        self.conn_pool = ThreadedConnectionPool(minconn=minconn,
                                                maxconn=maxconn,
                                                database=database,
                                                user=user,
                                                password=password,
                                                host=host,
                                                port=port)

    def get_conn(self):
        """Borrow a connection from the pool."""
        return self.conn_pool.getconn()

    def put_conn(self, conn):
        """Return a borrowed connection to the pool."""
        self.conn_pool.putconn(conn)

    def _run(self, cursor_factory, fetch, sql, *params):
        """Execute ``sql`` (with optional bound ``params``) and commit.

        ``fetch`` is None, "one" or "all" and selects what is returned.
        On error the transaction is rolled back and the exception re-raised;
        the cursor is closed and the connection returned to the pool in every
        case.
        """
        conn = self.get_conn()
        try:
            result = None
            try:
                cur = conn.cursor(cursor_factory=cursor_factory)
                try:
                    cur.execute(sql, *params)
                    if fetch == "one":
                        result = cur.fetchone()
                    elif fetch == "all":
                        result = cur.fetchall()
                finally:
                    cur.close()
                conn.commit()
            except Exception:
                # Never hand an aborted transaction back to the pool
                conn.rollback()
                raise
            return result
        finally:
            self.put_conn(conn)

    def execute_insert_sql(self, sql, values):
        """Run an INSERT with bound ``values``."""
        self._run(DictCursor, None, sql, values)

    def query_all_sql(self, sql):
        """Return all rows of ``sql`` as a list (RealDictCursor rows)."""
        return list(self._run(RealDictCursor, "all", sql))

    def query_one_sql(self, sql):
        """Return the first row of ``sql`` as a dict, or {} when empty."""
        row = self._run(DictCursor, "one", sql)
        return dict(row) if row else {}

    def execute_sql(self, sql):
        """Run ``sql`` without parameters and without fetching."""
        self._run(DictCursor, None, sql)

    def select_sql(self, sql, values):
        """Return the first row of ``sql`` executed with ``values``."""
        return self._run(DictCursor, "one", sql, values)

    def select_all_sql(self, sql, values):
        """Return all rows of ``sql`` executed with ``values``."""
        return self._run(DictCursor, "all", sql, values)

    def execute_update_sql(self, sql, values):
        """Run an UPDATE with bound ``values``."""
        self._run(DictCursor, None, sql, values)
class Pool:
    """Connection source for one database URI (sqlite, postgresql or crdb).

    Instances are cached per ``db_uri`` in ``_pools``; use ``get_pool``.
    Only the postgresql flavor uses a real connection pool, the other
    flavors open a fresh connection on every ``enter``.
    """

    _pools = {}

    def __init__(self, cfg):
        db_uri = cfg.get("db_uri", DEFAULT_DB_URI)
        self.cfg = cfg
        uri = urlparse(db_uri)
        dbname = uri.path[1:]
        self.flavor = uri.scheme
        self.pg_schema = None

        if self.flavor == "sqlite":
            self.conn_args = [dbname]
            self.conn_kwargs = {
                "check_same_thread": False,
                "detect_types": sqlite3.PARSE_DECLTYPES,
                "isolation_level": "DEFERRED",
            }
            sqlite3.register_converter("JSONB", json.loads)
            sqlite3.register_converter("INTEGER[]", convert_array(int))
            sqlite3.register_converter("VARCHAR[]", convert_array(str))
            sqlite3.register_converter("FLOAT[]", convert_array(float))
            sqlite3.register_converter(
                "BOOL[]", convert_array(lambda x: x == "True")
            )
        elif self.flavor == "postgresql":
            # The URI fragment selects the schema put on the search_path
            self.pg_schema = uri.fragment
            if psycopg2 is None:
                raise ImportError(
                    'Cannot connect to "%s" without psycopg2 package '
                    "installed" % db_uri
                )
            con_info = "dbname='%s' " % dbname
            if uri.hostname:
                con_info += "host='%s' " % uri.hostname
            if uri.username:
                # The %s placeholder is required: formatting a template
                # without one raises TypeError.
                con_info += "user='%s' " % uri.username
            if uri.password:
                con_info += "password='%s' " % uri.password
            if uri.port:
                con_info += "port='%s' " % uri.port
            self.pg_pool = ThreadedConnectionPool(
                cfg.get("pg_min_pool_size", 1),
                cfg.get("pg_max_pool_size", 10),
                con_info,
            )
        elif self.flavor == "crdb":
            if psycopg2 is None:
                raise ImportError(
                    'Cannot connect to "%s" without psycopg2 package '
                    "installed" % db_uri
                )
            # transform crdb into postgresql in uri scheme to please
            # psycopg2
            uri_parts = list(uri)
            uri_parts[0] = "postgresql"
            self.db_uri = urlunparse(uri_parts)
        else:
            raise ValueError(
                'Unsupported scheme "%s" in uri "%s"' % (uri.scheme, uri)
            )

    def enter(self):
        """Return a ready-to-use connection for this pool's flavor."""
        if self.flavor == "sqlite":
            connection = sqlite3.connect(*self.conn_args, **self.conn_kwargs)
            connection.text_factory = str
            connection.execute("PRAGMA foreign_keys=ON")
            connection.execute("PRAGMA journal_mode=wal")
        elif self.flavor == "crdb":
            connection = psycopg2.connect(self.db_uri)
        elif self.flavor == "postgresql":
            connection = self.pg_pool.getconn()
            if self.pg_schema:
                # NOTE(review): the schema name comes from the configured URI
                # and is interpolated as-is -- keep it out of untrusted input.
                qr = "SET search_path TO %s" % self.pg_schema
                try:
                    cursor = connection.cursor()
                    try:
                        cursor.execute(qr)
                    finally:
                        cursor.close()
                except Exception:
                    # Don't strand the connection outside the pool
                    self.pg_pool.putconn(connection)
                    raise
        else:
            raise ValueError('Unexpected flavor "%s"' % self.flavor)
        return connection

    def leave(self, connection, exc=None):
        """Finish the transaction and release ``connection``.

        Rolls back when ``exc`` is truthy, commits otherwise. The connection
        is returned to the pool (postgresql) or closed (other flavors) even
        if the commit/rollback itself raises.
        """
        try:
            if exc:
                logger.debug("ROLLBACK")
                connection.rollback()
            else:
                logger.debug("COMMIT")
                connection.commit()
        finally:
            if self.flavor == "postgresql":
                self.pg_pool.putconn(connection)
            else:
                connection.close()

    @classmethod
    def disconnect(cls):
        """Close every cached postgresql pool and forget all pools."""
        for pool in cls._pools.values():
            if pool.flavor == "postgresql":
                pool.pg_pool.closeall()
        cls.clear()

    @classmethod
    def clear(cls):
        """Drop the cache of pools without closing anything."""
        cls._pools = {}

    @classmethod
    def get_pool(cls, cfg):
        """Return the pool for ``cfg["db_uri"]``, creating it on first use."""
        db_uri = cfg.get("db_uri", DEFAULT_DB_URI)
        pool = cls._pools.get(db_uri)
        if pool:
            # Return existing pool for current db if any
            return pool

        pool = Pool(cfg)
        cls._pools[db_uri] = pool
        return pool
class PgPool:
    """Thin wrapper around a psycopg2 ThreadedConnectionPool.

    Every method prints failures and signals them through its return value
    (False / (None, None)) instead of raising.
    """

    def __init__(self):
        self.host, self.port = None, None
        self.database, self.pool = None, None
        self.user, self.passwd = None, None

    def create_pool(self, conn_dict, limits=None):
        """ Create a connection pool

        :param conn_dict: connection params dictionary with the keys
            "Host", "Port", "Database", "User" and "Password"; a None host
            or port falls back to 'localhost' / '5432'
        :type conn_dict: dict
        :param limits: optional pool size limits with the keys "Min" and
            "Max"; defaults to 1 and 50
        :type limits: dict
        """
        self.host = 'localhost' if conn_dict["Host"] is None else conn_dict["Host"]
        self.port = '5432' if conn_dict["Port"] is None else conn_dict["Port"]
        self.database = conn_dict["Database"]
        self.user = conn_dict["User"]
        self.passwd = conn_dict["Password"]

        conn_params = "host='{host}' dbname='{db}' user='{user}' password='{passwd}' port='{port}'".format(
            host=self.host, db=self.database, user=self.user, passwd=self.passwd, port=self.port)

        limits = limits or {}
        try:
            self.pool = ThreadedConnectionPool(
                int(limits.get("Min", 1)), int(limits.get("Max", 50)), conn_params)
        except Exception as e:
            # Exceptions have no .message attribute on Python 3
            print(e)

    def get_conn(self):
        """ Get a connection from pool and return connection and cursor

        :return: conn, cursor -- or None, None on failure
        """
        conn = None
        try:
            conn = self.pool.getconn()
            cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            return conn, cursor
        except Exception as e:
            print(e)
            if conn is not None:
                # The caller never sees this connection, so give it back
                self.put_conn(conn)
            return None, None

    @staticmethod
    def execute_query(cursor, query, params):
        """ Execute a query on database

        :param cursor: cursor object
        :param query: database query
        :type query: str
        :param params: query parameters
        :type params: tuple
        :return: all rows for a SELECT, the result of cursor.execute for any
            other statement, or False on error
        """
        try:
            if query.split()[0].lower() == 'select':
                cursor.execute(query, params)
                return cursor.fetchall()
            return cursor.execute(query, params)
        except Exception as e:
            print(e)
            return False

    # commit changes to db permanently
    @staticmethod
    def commit_changes(conn):
        """ Commit changes to the database permanently

        :param conn: connection object
        :return: result of conn.commit(), or False on error
        """
        try:
            return conn.commit()
        except Exception as e:
            print(e)
            return False

    def put_conn(self, conn):
        """ Put connection back to the pool

        :param conn: connection object
        :return: result of pool.putconn(), or False on error
        """
        try:
            return self.pool.putconn(conn)
        except Exception as e:
            print(e)
            return False

    def close_pool(self):
        """ Closes connection pool

        :return: result of pool.closeall(), or False on error
        """
        try:
            return self.pool.closeall()
        except Exception as e:
            print(e)
            return False
class PgConnManager:
    """
    Manage postgresql database connections.

    One instance exists per distinct set of connection options (see
    ``__new__``). Work is done through "contexts": dicts holding a pooled
    connection ("conn"), its cursor ("cursor") and a free-form label
    ("mark") used in log messages.
    """

    # dict to keep track of PgConnManager instances
    _instances = {}

    class DatabaseError(psycopg2.Error):
        """Raised when database connection error occurs."""

    def __new__(cls, *kargs, **kwargs):
        db_opts = kargs[0]
        if "logger" in kwargs:
            logger = kwargs.pop("logger")
        else:
            logger = logging.getLogger("sjutils.pgconnmanager")
        # We can either accept 'database' or 'dbname' as an input
        # (this adds the 'dbname' key to the caller's dict)
        if "database" in db_opts and "dbname" not in db_opts:
            db_opts["dbname"] = db_opts["database"]
        # This is ugly but since dict types cannot be used
        # as keys in another dict, we need to transform it.
        # This transformation was chosen as it is human
        # readable, it could be changed to a more optimised one.
        db_str = (
            "host=%(host)s port=%(port)s user=%(user)s password=%(password)s dbname=%(dbname)s"
            % db_opts)
        if db_str not in cls._instances:
            instance = super(PgConnManager, cls).__new__(cls)
            instance.lock = threading.Lock()
            instance.log = logger
            cls._instances[db_str] = instance
        return cls._instances[db_str]

    def __init__(self, db_opts, logger=None):
        # ``logger`` is consumed by __new__; it must be accepted here too
        # because Python passes the same arguments to both.
        self.__params__ = db_opts
        if "isolation_level" not in self.__params__:
            self.__params__[
                "isolation_level"] = psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED
        # __init__ runs again each time the shared instance is requested:
        # keep an already created pool.
        if not hasattr(self, "__conn_pool__"):
            self.__conn_pool__ = None

    def _log_ctx(self, ctx, action):
        """Debug-log ``action`` for context ``ctx`` in the common format."""
        self.log.debug(
            "ctx:(%s, %s) %s database: %s as %s",
            ctx["mark"], id(ctx), action,
            self.__params__["dbname"], self.__params__["user"])

    def _as_database_error(self, error):
        """Wrap a psycopg2 error into DatabaseError, keeping its traceback."""
        # We do not want our users to have to 'import psycopg2' to
        # handle the module's underlying database errors
        return self.DatabaseError(error).with_traceback(error.__traceback__)

    def decorate(self, func, *args, **kw):
        """Call ``func(self, ctx_list, *args, **kw)`` with a fresh context list.

        Every context still held afterwards is rolled back and released, on
        success as well as on error. psycopg2 errors are re-raised as
        DatabaseError.
        """
        try:
            ctx_list = self.connect()
            try:
                try:
                    ret = func(self, ctx_list, *args, **kw)
                except psycopg2.OperationalError:
                    # We got a database disconnection not catched by user, wiping all connection because
                    # psycopg2 does not fill correctly the database connection 'closed' attribute in case of disconnection
                    self.release_all(ctx_list, close=True)
                    raise
                # Connection(s) wasn't released by user, so we have to release it/them
                self.release_all(ctx_list, rollback=True)
                return ret
            except Exception:
                self.release_all(ctx_list, rollback=True)
                raise
        except psycopg2.Error as error:
            raise self._as_database_error(error)

    def _new_ctx(self, mark=None):
        """Create a new context object."""
        if not mark:
            mark = "none"
        ret = {"conn": None, "cursor": None, "mark": mark}
        ret["conn"] = self.__conn_pool__.getconn()
        self._log_ctx(ret, "Creating context to")
        return ret

    def connect(self, ctx_list=None, mark=None):
        """Connect to database.

        Creates the connection pool on first use, appends a new context to
        ``ctx_list`` (a new list when none is given) and returns the list.
        """
        self.log.debug("Connecting to database: %(dbname)s as %(user)s" %
                       self.__params__)
        try:
            with self.lock:
                if not self.__conn_pool__:
                    connector = (
                        "host=%(host)s port=%(port)s user=%(user)s password=%(password)s dbname=%(dbname)s"
                        % self.__params__)
                    self.__conn_pool__ = ThreadedConnectionPool(
                        1, 20, connector)

            if not ctx_list:
                ctx_list = []
            ctx_list.append(self._new_ctx(mark))

            self._log_ctx(ctx_list[-1], "Connected to")
            return ctx_list
        except psycopg2.Error as error:
            raise self._as_database_error(error)

    def execute(self, ctx, query, options=None):
        """Execute an SQL query.

        Lazily (re)acquires the context's connection and cursor and applies
        the configured isolation level first.
        """
        self._log_ctx(ctx, "Executing query on")
        try:
            if not ctx["conn"]:
                ctx["conn"] = self.__conn_pool__.getconn()
            if ("isolation_level" in self.__params__
                    and self.__params__["isolation_level"] !=
                    ctx["conn"].isolation_level):
                ctx["conn"].set_isolation_level(
                    self.__params__["isolation_level"])
            if not ctx["cursor"]:
                ctx["cursor"] = ctx["conn"].cursor(
                    cursor_factory=psycopg2.extras.DictCursor)
            try:
                if options:
                    ctx["cursor"].execute(query, options)
                else:
                    ctx["cursor"].execute(query)
            except psycopg2.OperationalError:
                # We got a database disconnection, wiping connection
                self.release(ctx, close=True)
                raise
        except psycopg2.Error as error:
            raise self._as_database_error(error)

    def commit(self, ctx):
        """Commit changes to database."""
        self._log_ctx(ctx, "Commiting changes on")
        try:
            try:
                ctx["conn"].commit()
            except psycopg2.OperationalError:
                # We got a database disconnection, wiping connection
                self.release(ctx, close=True)
                raise
        except psycopg2.Error as error:
            raise self._as_database_error(error)

    def rollback(self, ctx):
        """Rollback changes to database."""
        self._log_ctx(ctx, "Reverting changes on")
        # FIXME: ctx['conn'] is None if execute failed, see bug #3085
        if ctx["conn"] and not ctx["conn"].closed:
            ctx["conn"].rollback()

    def release(self, ctx, rollback=False, close=False):
        """Release database connection.

        Optionally rolls back first; the connection is closed instead of
        recycled when ``close`` is set or it is already closed.
        """
        self._log_ctx(ctx, "Releasing connection on")
        if rollback:
            self.rollback(ctx)
        ctx["cursor"] = None
        if ctx["conn"]:
            self._log_ctx(ctx, "Disposing connection on")
            closeit = close or (ctx["conn"].closed > 0)
            self.__conn_pool__.putconn(ctx["conn"], close=closeit)
            ctx["conn"] = None

    def release_all(self, ctx_list, rollback=False, close=False):
        """Release all database connections from a context list."""
        self.log.debug(
            "Releasing ALL connections on database: %(dbname)s as %(user)s" %
            self.__params__)
        for ctx in ctx_list:
            self.release(ctx, rollback=rollback, close=close)

    def fetchall(self, ctx):
        """Return all rows of current request."""
        self._log_ctx(ctx, "fetchall on")
        return ctx["cursor"].fetchall()

    def fetchmany(self, ctx, arraysize=1000):
        """Return @arraysize rows of current request."""
        self._log_ctx(ctx, "fetchmany on")
        return ctx["cursor"].fetchmany(arraysize)

    def fetchgenerator(self, ctx):
        """A basic generator to iterate through the current request's result rows."""
        self._log_ctx(ctx, "fetchgenerator on")
        while True:
            results = self.fetchmany(ctx)
            if not results:
                return
            for result in results:
                yield result

    def get_rowcount(self, ctx):
        """Get the cursor's rowcount attribute of the given context @ctx."""
        return ctx["cursor"].rowcount

    def fetchone(self, ctx):
        """Return one row of current request."""
        self._log_ctx(ctx, "fetchone on")
        return ctx["cursor"].fetchone()

    def set_isolation_level(self, isolation_level):
        """Set the isolation level."""
        self.__params__["isolation_level"] = isolation_level

    def get_isolation_level(self):
        """Get the isolation level."""
        return self.__params__["isolation_level"]
class Database:
    """Jumbo's abstract PostgreSQL client manager.

    Allows to manage multiple independent connections to a PostgreSQL
    database and to handle arbitrary transactions between clients and
    database.

    Usage:

    .. code-block:: python

        # Initialize abstract database manager
        database = jumbo.database.Database()

        # Open a connection pool to the PostgreSQL database:
        with database.open() as pool:

            # Use a connection from the pool
            with pool.connect():

                # Execute SQL query on the database
                pool.send(SQL_query)

        # context managers ensure connections are properly returned to the
        # pool, and that the pool is properly closed.
    """

    def __init__(self, env_path: Optional[str] = None) -> None:
        """Initializes manager to handle connections to a given PostgreSQL
        database.

        Args:
            env_path: path where to look for the jumbo.env configuration
                file. If not provided, looks for configuration file in the
                working directory of the script invoking this constructor.

        Attributes:
            config (jumbo.config.Config): jumbo's database configuration
                settings.
            pool (psycopg2.pool): connection pool to the PostgreSQL
                database. Initially closed.
        """
        # Database configuration settings
        self.config = Config(env_path)
        # Placeholder pool: with minconn=0 no connection is attempted, so this
        # only provides an object exposing `.closed` until open_pool() runs.
        self.pool = ThreadedConnectionPool(0, 0)
        self.pool.closed = True  # keep it closed on construction
        # Log configuration settings
        logger.info(f"Jumbo Connection Manager created:\n{self.config}")

    @contextmanager
    def open(self, minconns: int = 1, maxconns: Optional[int] = None) -> None:
        """Context manager opening a connection pool to the PostgreSQL
        database. The pool is automatically closed on exit and all
        connections are properly handled.

        Example:

        .. code-block:: python

            with self.open() as pool:
                # do something with the pool
            # pool is automatically closed here

        Args:
            minconns: minimum amount of available connections in the pool,
                created on startup.
            maxconns: maximum amount of available connections supported by
                the pool. Defaults to minconns.

        Raises:
            OperationalError: if the pool cannot be created. (Previously the
                error was swallowed before the ``yield``, which made
                contextlib raise an opaque ``RuntimeError: generator didn't
                yield`` instead.)
        """
        try:
            try:
                # Create connection pool
                self.open_pool(minconns, maxconns)
            except OperationalError as e:
                logger.error(f"Error while opening connection pool: {e}")
                raise
            try:
                yield self
            except OperationalError as e:
                # Best-effort by design: database errors raised by the body
                # are logged and suppressed, as in the original code.
                logger.error(f"Database error while connection pool was "
                             f"open: {e}")
        finally:
            # Close all connections in the pool
            self.close_pool()

    def open_pool(self, minconns: int = 1,
                  maxconns: Optional[int] = None) -> None:
        """Initialises and opens a connection pool to the PostgreSQL database
        (psycopg2.pool.ThreadedConnectionPool).

        'minconns' new connections are created immediately. The connection
        pool will support a maximum of about 'maxconns' connections. Does
        nothing if a pool is already open.

        Args:
            minconns: minimum amount of available connections in the pool,
                created on startup.
            maxconns: maximum amount of available connections supported by
                the pool. Defaults to 'minconns'.
        """
        # Nothing to do if the pool has already been opened
        if not self.pool.closed:
            return

        # initialize max number of supported connections
        if maxconns is None:
            maxconns = minconns

        # create a connection pool based on jumbo's configuration settings
        self.pool = ThreadedConnectionPool(
            minconns, maxconns,
            host=self.config.DATABASE_HOST,
            user=self.config.DATABASE_USERNAME,
            password=self.config.DATABASE_PASSWORD,
            port=self.config.DATABASE_PORT,
            dbname=self.config.DATABASE_NAME,
            sslmode='disable')

        logger.info(f"Connection pool created to PostgreSQL database: "
                    f"{maxconns} connections available.")

    def close_pool(self) -> None:
        """Closes all connections in the pool, making it unusable by clients.
        """
        # If the pool hasn't been closed yet
        if not self.pool.closed:
            self.pool.closeall()
            logger.info("All connections in the pool have been closed "
                        "successfully.")

    @contextmanager
    def connect(self, key: int = 1) -> None:
        """Context manager opening a connection to the PostgreSQL database
        using a connection [key] from the pool. The connection is
        automatically returned on exit and all transactions are properly
        handled.

        Example:

        .. code-block:: python

            # check-out connection from the pool
            with self.connect(key):
                # do something with the connection e.g.
                # self.send(sql_query, key)
            # connection is automatically returned here

        Args:
            key (optional): key to identify the connection being opened.
                Required for proper book keeping.
        """
        # check-out an available connection from the pool
        self.get_connection(key=key)
        try:
            yield self
        except (Exception, KeyboardInterrupt) as e:
            # NOTE(review): this deliberately logs and suppresses every error
            # raised by the body, including Ctrl-C. Kept for compatibility.
            logger.error(f"Error raised during connection [{key}] "
                         f"transactions: {e}")
        finally:
            # return the connection to the pool
            self.put_back_connection(key=key)

    def get_connection(self, key: int = 1) -> None:
        """Connect to a Postgres database using an available connection from
        pool. The connection is assigned to 'key' on checkout.

        Exits the interpreter (``sys.exit()``) if the pool raises a
        ``PoolError``, e.g. when it is exhausted.

        Args:
            key (optional): key to assign to the connection being opened.
                Required for proper book keeping.
        """
        if self.pool.closed:
            logger.warning(f"No pool to the PostgreSQL database: cannot "
                           f"retrieve a connection. Try to .open() a pool.")
            return

        try:
            if key in self.pool._used:
                logger.warning(f"Pool connection [{key}] is already in "
                               f"use by another client. Try a different "
                               f"key.")
                return
            # Connect to PostgreSQL database
            self.pool.getconn(key)
            logger.info(f"Connection retrieved successfully: pool "
                        f"connection [{key}] now in use.")
            # perform handshake
            self.on_connection(key)
        except PoolError as error:
            logger.error(f"Error while retrieving connection from "
                         f"pool:\t{error}")
            sys.exit()

    def on_connection(self, key: int = 1) -> None:
        """Client-database handshaking script to perform on retrieval of a
        PostgreSQL connection from the pool.

        Args:
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        # return database information
        info = self.connection_info(key=key)
        logger.info(f"You are connected to - {info}")

    def put_back_connection(self, key: int = 1) -> None:
        """Puts back a connection in the connection pool.

        The connection is reset to a neutral state first. If the reset fails
        (e.g. the server dropped the connection) the connection is still
        handed back to the pool, but closed, so that the key does not stay
        flagged as "in use" forever.

        Args:
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        if key not in self.pool._used:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"cannot put it back in the pool.")
            return

        conn = self.pool._used[key]
        try:
            # Reset connection to neutral state
            conn.reset()
        except Exception as error:
            logger.error(f"Pool connection [{key}] could not be reset "
                         f"({error}): discarding it.")
            self.pool.putconn(conn, key, close=True)
            return

        # Put back connection in the pool
        self.pool.putconn(conn, key)
        logger.info(f"Connection returned successfully: pool connection "
                    f"[{key}] now available again.")

    def send(self, query: Union[str, sql.Composed],
             subs: Optional[Tuple[str, ...]] = None,
             cur_method: int = 0, file: Optional[IO] = None,
             fetch_method: int = 2,
             key: int = 1) -> Union[DictRow, None]:
        """Sends an arbitrary PostgreSQL query to the PostgreSQL database.
        Transactions are auto-committed on execution.

        Example:

        .. code-block:: python

            # A simple query with no substitutions
            query = 'SELECT * FROM table_name;'

            results = self.send(query)

            # A more complex query with dynamic substitutions
            query = 'INSERT INTO table_name (column_name, another_column_name) VALUES (%s, %s);'
            subs = (value, another_value)

            results = self.send(query, subs)

        Args:
            query: PostgreSQL command string (can be template with psycopg2
                %s fields).
            subs (optional): tuple of values to substitute in SQL query
                template (cf. psycopg2 %s formatting)
            cur_method (optional): code to select which psycopg2 cursor
                execution method to use for the SQL query:
                    0:  cursor.execute()
                    1:  cursor.copy_expert()
            file (optional): file-like object to read or write to (only
                relevant if cur_method:1).
            fetch_method (optional): code to select which psycopg2 result
                retrieval method to use (fetch*()):
                    0: cur.fetchone()
                    2: cur.fetchall()
                Any other value fetches nothing and yields an empty list.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].

        Returns:
            Query results (if any), accessible as dictionaries; an empty list
            for statements that return no rows; ``None`` if the connection is
            not open or the transaction failed and was rolled back.
        """
        if key not in self.pool._used:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"not available for transactions.")
            return None

        conn = self.pool._used[key]
        try:
            with conn.cursor(cursor_factory=DictCursor) as cur:
                # Bind arguments to query string (if present)
                if subs is not None:
                    query = cur.mogrify(query, subs)
                else:
                    query = cur.mogrify(query)

                # Execute query
                if cur_method == 0:
                    cur.execute(query)
                elif cur_method == 1:
                    cur.copy_expert(sql=query, file=file)

                # Fetch query results. Default to "no rows" so an unsupported
                # fetch_method cannot leave `records` unbound.
                records = []
                try:
                    if fetch_method == 0:
                        records = cur.fetchone()
                    elif fetch_method == 2:
                        records = cur.fetchall()
                except ProgrammingError:
                    # SQL statements that don't return any results
                    # (INSERT, UPDATE, etc...)
                    records = []

                # Commit transaction
                conn.commit()

                # Display success message (and shorten query if too long).
                # mogrify() returns bytes: decode so the log shows SQL text
                # rather than a b'...' repr.
                if isinstance(query, bytes):
                    shown = query.decode(errors="replace")
                else:
                    shown = str(query)
                if len(shown) > 78:
                    shown = shown[:75] + '...'
                success_msg = f"Successfully sent: {shown} "
                if cur.rowcount >= 0:
                    success_msg += f": {cur.rowcount} rows affected."
                logger.info(success_msg)

                return records  # dictionaries

        except Exception as e:
            # Rollback transaction if any problem
            conn.rollback()
            logger.error(f"Error while sending query {query}:{e}. "
                         f"Transaction rolled-back.")
            return None

    def listen_on_channel(self, channel_name: str, key: int = 1) -> None:
        """Subscribes to a PostgreSQL notification channel by listening for
        NOTIFYs.

        .. code-block:: postgresql

            -- Command executed:
            LISTEN channel_name;

        Args:
            channel_name: channel on which to LISTEN. PostgreSQL database
                should be configured to send NOTIFYs on this channel. The
                name is concatenated into the statement unquoted, so it must
                come from trusted code.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        query = "LISTEN " + channel_name + ";"
        self.send(query, key=key)

    def connection_info(self, key: int = 1) -> DictRow:
        """Fetches PostgreSQL database version.

        .. code-block:: postgresql

            -- Command executed:
            SELECT version();

        Args:
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].

        Returns:
            query result. Contains PostgreSQL database version information.
        """
        query = "SELECT version();"
        info = self.send(query, fetch_method=0, key=key)  # fetchone()
        return info

    def copy_to_table(self, query: sql.Composed, file: IO, db_table: str,
                      schema: Optional[str] = None, replace: bool = True,
                      key: int = 1) -> None:
        """Utility wrapper to send a SQL query to copy data to database table.
        Allows to replace table contents if it already exists in the
        database.

        .. code-block:: postgresql

            -- Command type expected:
            COPY table_name [ ( column_name [, ...] ) ]
            FROM STDIN
            [ [ WITH ] ( option [, ...] ) ]

            -- Ancillary command pre-executed:
            TRUNCATE table_name;

        Example:

        .. code-block:: python

            # Copy csv data from file to a table in the database
            query = "COPY table_name FROM STDIN WITH CSV DELIMITER '\\t'"
            with open("C:\\data.csv") as csv_file:
                self.copy_to_table(query, file=csv_file,
                                   db_table='table_name')

        Args:
            query: PostgreSQL COPY command string.
            file: open file-like object to read data from.
            db_table: the name (not schema-qualified) of an existing database
                table.
            schema (optional): schema to which ``db_table`` belongs. If
                ``None``, use default schema.
            replace (optional): replaces table contents if True. Appends data
                to table contents otherwise.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        # Empty the table already existing in the database
        if replace:
            # schema-qualify table name if needed
            if schema is not None:
                identifier = sql.Identifier(schema, db_table)
            else:
                identifier = sql.Identifier(db_table)
            # pass table name dynamically to query
            query_tmp = sql.SQL("TRUNCATE {};").format(identifier)
            self.send(query_tmp, key=key)

        # Copy the table from file (cur_method:1 = cur.copy_expert)
        self.send(query, cur_method=1, file=file, key=key)

    def copy_df(self, df: pd.DataFrame, db_table: str,
                schema: Optional[str] = None, replace: bool = False,
                key: int = 1) -> None:
        """Utility wrapper to efficiently copy a pandas.DataFrame to a
        PostgreSQL database table.

        This method is faster than pandas' native *.to_sql()* method and
        exploits the PostgreSQL COPY command. Provides a useful mean of
        saving results from a pandas-centred data analysis pipeline directly
        to the database.

        Args:
            df: dataframe to be copied.
            db_table: the name (not schema-qualified) of the table to write
                to.
            schema (optional): schema to which ``db_table`` belongs. If
                ``None``, use default schema.
            replace (optional): replaces table contents if True. Appends data
                to table contents otherwise.
            key (optional): key of the pool connection being used in the
                transaction. Defaults to [1].
        """
        if key not in self.pool._used:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"cannot use it to copy Dataframe to database.")
            return

        try:
            # Create headless csv from pandas dataframe
            io_file = io.StringIO()
            df.to_csv(io_file, sep='\t', header=False, index=False)
            io_file.seek(0)

            # Quickly create a table with correct number of columns / data
            # types. Unfortunately we need to build a throw-away sqlalchemy
            # engine around the pooled connection for this hack to work.
            replacement_method = 'replace' if replace else 'append'
            engine = create_engine('postgresql+psycopg2://',
                                   creator=lambda: self.pool._used[key])
            df.head(0).to_sql(name=db_table, con=engine, schema=schema,
                              if_exists=replacement_method, index=False)

            # But then exploit postgreSQL COPY command instead of slow
            # pandas .to_sql(). Note that replace is set to False in
            # copy_to_table as we want to preserve the header table created
            # above.
            # schema-qualify table name if needed
            if schema is not None:
                identifier = sql.Identifier(schema, db_table)
            else:
                identifier = sql.Identifier(db_table)
            sql_copy_expert = sql.SQL(
                "COPY {} FROM STDIN WITH CSV DELIMITER '\t'").format(
                identifier)
            self.copy_to_table(sql_copy_expert, file=io_file,
                               db_table=db_table, schema=schema,
                               replace=False, key=key)
            logger.info("DataFrame copied successfully to PostgreSQL "
                        "table.")

        except Exception as error:
            logger.error(f"Error while copying DataFrame to PostgreSQL "
                         f"table: {error}")
class DatabaseManager:
    """Provide a small abstraction over the underlying PostgreSQL database.

    Every helper borrows a connection from a ``ThreadedConnectionPool`` and
    always hands it back, even when the query raises.

    NOTE(review): all statements are assembled with ``str.format``; table
    names, column names, values and ``where_clause`` must therefore come from
    trusted code only, never from end-user input.
    """

    def __init__(self, db_name="test_db", db_pass="", host="127.0.0.1", port="5432"):
        """Create the connection pool (10 to 50 connections) and the logger."""
        self.connection_pool = ThreadedConnectionPool(10, 50, database=db_name, user="******",
                                                      password=db_pass, host=host, port=port)
        self.logger = get_logger()

    def __execute_query(self, query):
        """Run a data-modifying statement and commit it.

        ``ProgrammingError`` and ``IntegrityError`` are logged and swallowed
        (best-effort writes); any other exception propagates. A failed
        statement is rolled back instead of committed, and the connection is
        returned to the pool in every case.
        """
        connection = self.connection_pool.getconn()
        cursor = None
        self.logger.debug("Going to execute query {}".format(query))
        try:
            cursor = connection.cursor()
            cursor.execute(query)
        except ProgrammingError:
            connection.rollback()
            self.logger.error("Error occurred while executing query {}".format(query))
        except IntegrityError:
            connection.rollback()
            self.logger.error("Query failed. Duplicate row for query {}".format(query))
        except Exception:
            # Leave the connection clean before it goes back to the pool.
            connection.rollback()
            raise
        else:
            connection.commit()
        finally:
            if cursor is not None:
                cursor.close()
            self.connection_pool.putconn(connection)

    def _run_select(self, query):
        """Execute a read query and return ``(column_names, rows)``.

        The borrowed connection is released in a ``finally`` so that a failing
        query cannot leak it from the pool.
        """
        connection = self.connection_pool.getconn()
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(query)
                column_names = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()
            finally:
                cursor.close()
        finally:
            self.connection_pool.putconn(connection)
        return column_names, rows

    def insert_batch(self, table_name, column_headers, rows):
        """Insert multiple rows in ``table_name``.

        ``column_headers`` is a tuple of column names. ``rows`` is a list of
        tuples, each holding the values of one row ordered like
        ``column_headers``. The VALUES list is the Python repr of ``rows``
        without the surrounding brackets.
        """
        query = "INSERT INTO {} {} VALUES {}".format(table_name,
                                                     '(' + ','.join(column_headers) + ')',
                                                     str(rows)[1:-1])
        self.__execute_query(query)

    def update(self, table_name, column_vs_value, uid):
        """Update the row whose ``id`` is ``uid`` with the ``column_vs_value`` dict."""
        update_str = ''.join('{}={},'.format(key, val)
                             for key, val in column_vs_value.items())[:-1]
        query = "UPDATE {} SET {} WHERE id = {} ".format(table_name, update_str, uid)
        self.__execute_query(query)

    def delete_batch(self, table_name, uids, uid_column_name='id'):
        """Delete all rows of ``table_name`` whose uid is in the tuple ``uids``."""
        query = "DELETE from {} WHERE {} in {}".format(table_name, uid_column_name, str(uids))
        self.__execute_query(query)

    def get_row(self, table_name, uid, uid_column_name='id'):
        """Return the first row matching ``uid`` as a dict (empty if none)."""
        query = "Select * from {} where {} = {}".format(table_name, uid_column_name, uid)
        column_names, values = self._run_select(query)
        if not values:
            return {}
        # Built-in zip replaces itertools.izip, which does not exist on Python 3.
        return dict(zip(column_names, values[0]))

    def get_all_values_for_attr(self, table_name, column_name):
        """Return all distinct values of ``column_name`` from ``table_name``."""
        query = "Select distinct {} from {}".format(column_name, table_name)
        _, rows = self._run_select(query)
        return [row[0] for row in rows]

    def get_all_rows(self, table_name, where_clause='1=1', limit=20, order_by=None):
        """Return up to ``limit`` rows of ``table_name`` satisfying ``where_clause``.

        Rows come back as dicts whose values are converted with ``str``. When
        ``order_by`` is given, results are sorted on it in descending order.
        """
        query = "Select * from {} where {} ".format(table_name, where_clause)
        if order_by:
            query = '{} order by {} desc'.format(query, order_by)
        query = '{} limit {}'.format(query, limit)
        column_names, rows = self._run_select(query)
        return [
            {name: str(value) for name, value in zip(column_names, row)}
            for row in rows
        ]

    def get_connection(self):
        """Borrow a connection from the pool; the caller must release it."""
        return self.connection_pool.getconn()

    def release_connection(self, connection):
        """Release ``connection`` back to the pool."""
        self.connection_pool.putconn(connection)
class DB(object):
    """SQL helper over a pooled MySQL, PostgreSQL or Oracle connection.

    The backend is chosen by ``config['datatype']`` (``'MYSQL'``,
    ``'POSTGRESQL'`` or ``'ORACLE'``; defaults to ``'POSTGRESQL'``). Call
    :meth:`open` before running statements and :meth:`close` afterwards.

    String values wrapped in double braces (``'{{now()}}'``) are treated as
    raw SQL expressions and inlined verbatim instead of being bound/quoted.

    NOTE(review): ``insert``/``update``/``delete`` interpolate values into the
    SQL text by wrapping strings in single quotes without escaping. Values
    containing ``'`` break the statement and untrusted input is an injection
    risk. Behaviour is kept for compatibility; feed trusted data only.
    """

    # Number of handles currently opened through open(), across instances.
    connections = 0

    def __init__(self, config):
        """Create the connection pool described by ``config``.

        Args:
            config: mapping of pool keyword arguments plus the optional
                ``datatype`` (and, for Oracle, ``NLS_LANG``) entries, or
                another ``DB`` instance whose pool is then shared.

        Raises:
            BaseError: code 701 when the pool cannot be created.
        """
        self.debug = True if (DEBUG and Debug_Level >= 3) else False
        self.cnx = self.cur = None
        self.__conn = self.__cursor = None
        if isinstance(config, DB):
            # Clone: share the pool, but start without an open connection.
            self.DataName = config.DataName
            self.pool = config.pool
            return

        # Work on a copy so the caller's mapping is never mutated (the
        # original deleted keys in place and failed to restore them on error).
        settings = dict(config) if config else {}
        self.DataName = settings.pop('datatype', 'POSTGRESQL')

        if self.DataName == 'MYSQL':
            try:
                self.pool = mysql.connector.pooling.MySQLConnectionPool(
                    **settings)
            except mysql.connector.Error as err:
                # The operation log should be recorded here.
                logging.debug(err.msg)
                raise BaseError(701) from err  # database connection error
        elif self.DataName == 'POSTGRESQL':
            try:
                self.pool = ThreadedConnectionPool(**settings)
            except Exception as err:
                raise BaseError(701) from err  # database connection error
        elif self.DataName == 'ORACLE':
            nls_lang = settings.pop('NLS_LANG', None)
            if nls_lang:
                os.environ['NLS_LANG'] = str(nls_lang)
            try:
                self.pool = cx_Oracle.SessionPool(**settings)
            except Exception as err:
                raise BaseError(701) from err  # database connection error

    def setDebug(self, debug):
        """Set the debug flag."""
        self.debug = debug

    def clone(self):
        """Return a new ``DB`` sharing this instance's pool."""
        db = DB(self)
        return db

    def open(self, auto=False):
        """Check a connection out of the pool and create its cursor.

        Args:
            auto: value assigned to the connection's ``autocommit``.

        Raises:
            BaseError: code 702 when no connection can be obtained.
        """
        try:
            if self.DataName == 'ORACLE':
                self.__conn = self.pool.acquire()
                self.__cursor = self.__conn.cursor()
            elif self.DataName == 'POSTGRESQL':
                self.__conn = self.pool.getconn()
                self.__cursor = self.__conn.cursor()
            else:  # defaults to MySQL
                self.__conn = self.pool.get_connection()
                self.__cursor = self.__conn.cursor(buffered=True)
            self.__conn.autocommit = auto
            self.cnx = self.__conn
            self.cur = self.__cursor
        except Exception as err:
            raise BaseError(702) from err  # cannot obtain a pooled connection
        # Count only handles that were really opened.
        DB.connections += 1

    def close(self):
        """Commit, close the cursor and give the connection back to the pool.

        The connection is released even if the commit or cursor close fails.
        """
        DB.connections -= 1
        try:
            self.__conn.commit()
            if self.__cursor is not None:
                self.__cursor.close()
        finally:
            if self.DataName == 'POSTGRESQL':
                # Put the connection back into the pool.
                self.pool.putconn(self.__conn)
            else:
                self.__conn.close()

    def begin(self):
        """Start an explicit transaction by switching autocommit off."""
        self.__conn.autocommit = False

    def commit(self):
        """Commit the current transaction."""
        self.__conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.__conn.rollback()

    # ---------------------------------------------------------------------

    def findBySql(self, sql, params=None, limit=0, join='AND', lock=False):
        """Run a custom SELECT, appending a WHERE clause built from ``params``.

        Args:
            sql: statement without WHERE clause.
            params: ``dict(field=value)`` conditions.
            limit: maximum number of rows to fetch (0 = all).
            join: ``'AND'`` or ``'OR'``.
            lock: unused.

        Returns:
            A list of row dicts, or ``None`` when nothing matched.

        Raises:
            BaseError: code 706 on any failure.
        """
        params = params or {}
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            cursor.execute(sql, self.__bind_values(params))
            rows = cursor.fetchmany(
                size=limit) if limit > 0 else cursor.fetchall()
            if not rows:
                return None
            names = self.__column_names(cursor)
            return [dict(zip(names, row)) for row in rows]
        except Exception as err:
            raise BaseError(706) from err

    def countBySql(self, sql, params=None, join='AND'):
        """Run a custom counting statement; return its first column or 0.

        Raises:
            BaseError: code 707 on any failure.
        """
        params = params or {}
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            cursor.execute(sql, self.__bind_values(params))
            result = cursor.fetchone()
            return result[0] if result else 0
        except Exception as err:
            raise BaseError(707) from err

    def deleteByCond(self, table, cond):
        """Delete rows matching the raw SQL condition ``cond``.

        Returns:
            Number of deleted rows.

        Raises:
            BaseError: code 704 on any failure.
        """
        try:
            sql = "DELETE FROM %s where %s" % (table, cond)
            cursor = self.__getCursor()
            self.__display_Debug_IO(sql, ())  # DEBUG
            cursor.execute(sql)
            return cursor.rowcount
        except Exception as err:
            raise BaseError(704, self.__error_detail(err)) from err

    def deleteByAttr(self, table, params=None, join='AND'):
        """Delete rows whose columns equal ``params``, using bound values.

        Returns:
            Number of deleted rows.

        Raises:
            BaseError: code 704 on any failure.
        """
        params = params or {}
        try:
            sql = "DELETE FROM %s " % table
            sql = self.__joinWhere(sql, params, join)
            cursor = self.__getCursor()
            values = self.__bind_values(params)
            self.__display_Debug_IO(sql, values)  # DEBUG
            cursor.execute(sql, values)
            return cursor.rowcount
        except Exception as err:
            raise BaseError(704, self.__error_detail(err)) from err

    # NOTE: the class used to define deleteByPk twice. The earlier definition
    # (delegating to deleteByAttr) was always shadowed by the later one and
    # has been dropped; see deleteByPk below.

    def findByAttr(self, table, criteria=None):
        """Return the first record matching ``criteria`` as a dict."""
        return self.__query(table, criteria)

    def findByPk(self, table, id, pk='id'):
        """Return the record whose primary key ``pk`` equals ``id``."""
        return self.findByAttr(table, {'where': pk + '=' + str(id)})

    def findAllByAttr(self, table, criteria=None):
        """Return every record matching ``criteria`` as a list of dicts."""
        return self.__query(table, criteria, True)

    def exit(self, table, params=None, join='AND'):
        """Tell whether at least one row matches ``params``."""
        return self.count(table, params, join) > 0

    # Public helpers --------------------------------------------------------

    def count(self, table, params=None, join='AND'):
        """Count the rows of ``table`` matching ``params``.

        Raises:
            BaseError: code 707 on any failure.
        """
        try:
            sql = 'SELECT COUNT(*) FROM %s' % table
            # Must exist even without params; it used to be unbound, which
            # made every unfiltered count fail.
            whereValues = []
            if params:
                where, whereValues = self.__contact_where(params, join)
                sql += ' WHERE ' + where if where else ''
            cursor = self.__getCursor()
            self.__display_Debug_IO(sql, tuple(whereValues))  # DEBUG
            if self.DataName == 'ORACLE':
                cursor.execute(sql % tuple(whereValues))
            else:
                cursor.execute(sql, tuple(whereValues))
            result = cursor.fetchone()
            return result[0] if result else 0
        except Exception as err:
            raise BaseError(707) from err

    def getToListByPk(self, table, criteria=None, id=None, pk='id'):
        """Return the first matching record as a row sequence."""
        criteria = self.__with_pk_filter(criteria, id, pk)
        return self.__query(table, criteria, isDict=False)

    def getAllToList(self, table, criteria=None, id=None, pk='id', join='AND'):
        """Return all matching records as a list of row sequences."""
        criteria = self.__with_pk_filter(criteria, id, pk)
        return self.__query(table, criteria, all=True, isDict=False)

    def getToObjectByPk(self, table, criteria=None, id=None, pk='id'):
        """Return the first matching record as a dict."""
        criteria = self.__with_pk_filter(criteria, id, pk)
        return self.__query(table, criteria)

    def getAllToObject(self, table, criteria=None, id=None, pk='id', join='AND'):
        """Return all matching records as a list of dicts."""
        criteria = self.__with_pk_filter(criteria, id, pk)
        return self.__query(table, criteria, all=True)

    def insert(self, table, data, commit=True):
        """Insert one record and return its generated id.

        Args:
            table: target table (``schema.table`` allowed).
            data: ``dict(column=value)``; ``'{{expr}}'`` values are inlined
                as SQL expressions.
            commit: commit after the insert.

        Raises:
            BaseError: code 705 (with the original error args) on failure.
        """
        try:
            # Split data into SQL-expression columns (funData) and plain
            # value columns (newData).
            funData, newData = self.__split_expression(data)

            # Plain columns first, then expression columns; skipping empty
            # parts avoids a dangling leading comma.
            fields = ', '.join(list(newData.keys()) + list(funData.keys()))
            values = ', '.join(['%s'] * len(newData) + list(funData.values()))

            sql = 'INSERT INTO %s (%s) VALUES (%s)' % (table, fields, values)
            cursor = self.__getCursor()

            literals = tuple(self.__sql_literal(v) for v in newData.values())
            self.__display_Debug_IO(sql, literals)  # DEBUG
            sql = sql % literals

            if self.DataName == 'POSTGRESQL':
                sql += ' RETURNING id'

            cursor.execute(sql)

            if self.DataName == 'ORACLE':
                # 1. commit must be False
                # 2. Oracle sequences are named [user.]SEQ_<table>_ID
                # 3. every main table is expected to have an ID column
                t_list = table.split('.')
                if len(t_list) > 1:
                    SEQ_Name = t_list[0] + '.SEQ_' + t_list[1] + '_ID'
                else:
                    SEQ_Name = 'SEQ_' + t_list[0] + '_ID'
                cursor.execute('SELECT %s.CURRVAL FROM dual' %
                               SEQ_Name.upper())
                result = cursor.fetchone()
                insert_id = result[0] if result else 0
            elif self.DataName == 'MYSQL':
                insert_id = cursor.lastrowid
            elif self.DataName == 'POSTGRESQL':
                item = cursor.fetchone()
                insert_id = item[0]

            if commit:
                self.commit()
            return insert_id

        except Exception as err:
            raise BaseError(705, err.args) from err

    def update(self, table, data, params=None, join='AND', commit=True,
               lock=True):
        """Update the rows matching ``params`` with ``data``.

        Args:
            table: target table.
            data: ``dict(column=value)`` of new values.
            params: ``dict(column=value)`` conditions.
            join: unused (conditions are joined with AND).
            commit: wrap the update in its own transaction and commit it.
            lock: take a ``SELECT ... FOR UPDATE`` row lock first (only when
                conditions are given).

        Returns:
            Number of updated rows.

        Raises:
            BaseError: code 705 (with the original error args) on failure.
        """
        try:
            fields, values = self.__contact_fields(data)
            sqlWhere = ''
            whereValues = []
            if params:
                where, whereValues = self.__contact_where(params)
                values.extend(whereValues)
                sqlWhere = ' WHERE ' + where if where else ''

            cursor = self.__getCursor()

            if commit:
                self.begin()

            if lock and params:
                sqlSelect = "SELECT %s From %s %s for update" % (
                    ','.join(params.keys()), table, sqlWhere)
                # Quote the condition values exactly like the UPDATE below;
                # raw interpolation produced invalid SQL for strings.
                sqlSelect = sqlSelect % tuple(
                    self.__sql_literal(v) for v in whereValues)
                cursor.execute(sqlSelect)  # take the row lock

            sqlUpdate = "UPDATE %s SET %s " % (table, fields) + sqlWhere
            literals = tuple(self.__sql_literal(v) for v in values)
            self.__display_Debug_IO(sqlUpdate, literals)  # DEBUG
            sqlUpdate = sqlUpdate % literals
            cursor.execute(sqlUpdate)

            if commit:
                self.commit()

            return cursor.rowcount

        except Exception as err:
            raise BaseError(705, err.args) from err

    def updateByPk(self, table, data, id, pk='id', commit=True, lock=True):
        """Update by primary key (``id`` column by default)."""
        return self.update(table, data, {pk: id}, commit=commit, lock=lock)

    def delete(self, table, params=None, join='AND', commit=True, lock=True):
        """Delete the rows matching ``params``.

        Args:
            table: target table.
            params: ``dict(column=value)`` conditions; when empty every row
                of the table is deleted.
            join: unused (conditions are joined with AND).
            commit: wrap the delete in its own transaction and commit it.
            lock: unused.

        Returns:
            Number of deleted rows.

        Raises:
            BaseError: code 705 (with the original error args) on failure.
        """
        try:
            sqlWhere = ''
            values = []
            if params:
                where, whereValues = self.__contact_where(params)
                values.extend(whereValues)
                sqlWhere = ' WHERE ' + where if where else ''

            cursor = self.__getCursor()

            if commit:
                self.begin()

            sqlDelete = "DELETE FROM %s %s" % (table, sqlWhere)
            literals = tuple(self.__sql_literal(v) for v in values)
            self.__display_Debug_IO(sqlDelete, literals)  # DEBUG
            sqlDelete = sqlDelete % literals
            cursor.execute(sqlDelete)

            if commit:
                self.commit()

            return cursor.rowcount

        except Exception as err:
            raise BaseError(705, err.args) from err

    def deleteByPk(self, table, id, pk='id', commit=True, lock=True):
        """Delete by primary key (``id`` column by default)."""
        return self.delete(table, {pk: id}, commit=commit, lock=lock)

    # Private helpers -------------------------------------------------------

    def __display_Debug_IO(self, sql, params):
        """Debug hook for SQL statements; currently a deliberate no-op.

        SQL echoing was switched off to reduce console output, so nothing is
        printed regardless of ``self.debug``.
        """
        return

    def __get_connection(self):
        return self.pool.get_connection()

    def __getCursor(self):
        """Return the current cursor, creating it when needed."""
        if self.__cursor is None:
            self.__cursor = self.__conn.cursor()
        return self.__cursor

    def getCursor(self):
        """Return the current cursor, creating it when needed."""
        if self.__cursor is None:
            self.__cursor = self.__conn.cursor()
        return self.__cursor

    @staticmethod
    def __column_names(cursor):
        """Return result column names for any DB-API cursor.

        ``column_names`` only exists on mysql-connector cursors; psycopg2 and
        cx_Oracle expose the names through the standard ``description``.
        """
        names = getattr(cursor, 'column_names', None)
        if names:
            return names
        return [col[0] for col in (cursor.description or ())]

    @staticmethod
    def __error_detail(err):
        """Return mysql-connector's full message when present, else args."""
        return getattr(err, '_full_msg', err.args)

    @staticmethod
    def __sql_literal(val):
        """Render ``val`` for direct interpolation into SQL text.

        Strings are wrapped in single quotes (no escaping), ``None`` becomes
        ``null`` and everything else is passed through unchanged.
        """
        if isinstance(val, str):
            return "'" + val + "'"
        if val is None:
            return "null"
        return val

    @staticmethod
    def __with_pk_filter(criteria, pk_value, pk):
        """Return a copy of ``criteria`` with a pk filter when none is set.

        Copying keeps the caller's dict (and the former shared default
        argument) from being modified.
        """
        criteria = dict(criteria) if criteria else {}
        if ('where' not in criteria) and (pk_value is not None):
            criteria['where'] = pk + "='" + str(pk_value) + "'"
        return criteria

    @staticmethod
    def __parse_limit(limit):
        """Split a MySQL-style ``[offset,]rowcount`` limit.

        Returns:
            ``(offset, rowcount)`` as strings. A single value is a row count
            with offset 0 (it used to be forced to one row).
        """
        parts = str(limit).split('limit ').pop().split(',')
        if len(parts) > 1:
            return parts[0], parts[1]
        return '0', parts[0]

    def __bind_values(self, params):
        """Return the values to bind for a clause built by ``__joinWhere``.

        ``{{expr}}`` entries are inlined into the SQL text, so only the plain
        values have a matching ``%s`` placeholder.
        """
        _, plain = self.__split_expression(params)
        return tuple(plain.values())

    def __joinWhere(self, sql, params, join):
        """Append a WHERE clause built from ``params`` to ``sql``.

        Plain values become ``col=%s`` placeholders; ``{{expr}}`` values are
        inlined as ``col=expr``. Conditions are joined with AND when ``join``
        is ``'AND'``, with OR otherwise.
        """
        if params:
            glue = ' AND ' if join == 'AND' else ' OR '
            funParams, newParams = self.__split_expression(params)

            # plain-value conditions
            keys, _keys = self.__tParams(newParams)
            newWhere = glue.join(k + '=' + _k for k, _k in zip(keys, _keys))

            # SQL-expression conditions
            funWhere = glue.join(k + '=' + v for k, v in funParams.items())

            # final clause
            if funWhere:
                where = (newWhere + glue if newWhere else '') + funWhere
            else:
                where = newWhere

            sql += ' WHERE ' + where
        return sql

    def __tParams(self, params):
        """Return column names (braces stripped) and matching placeholders."""
        keys = [k if k[:2] != '{{' else k[2:-2] for k in params.keys()]
        _keys = ['%s' for k in params.keys()]
        return keys, _keys

    def __query(self, table, criteria, all=False, isDict=True, join='AND'):
        """Run a SELECT described by ``criteria``.

        Args:
            table: table name.
            criteria: dict with optional ``select``, ``where``, ``group``,
                ``having``, ``order``, ``limit`` and ``offset`` entries.
            all: return every row when True; only the first row otherwise.
            isDict: return dict rows when True, raw row sequences otherwise.
            join: forwarded to the SQL builder.

        Raises:
            BaseError: code 706 on any failure.
        """
        try:
            # Copy: the limit below must not leak into the caller's dict.
            criteria = dict(criteria or {})
            if all is not True:
                criteria['limit'] = 1  # fetch a single record

            sql, params = self.__contact_sql(table, criteria, join)

            cursor = self.__getCursor()
            self.__display_Debug_IO(sql, params)  # DEBUG

            cursor.execute(sql, params if params else ())

            rows = cursor.fetchall() if all else cursor.fetchone()

            if isDict:
                names = self.__column_names(cursor)
                if all:
                    result = [dict(zip(names, row)) for row in rows]
                else:
                    result = dict(zip(names, rows)) if rows else {}
            else:
                if all:
                    result = [row for row in rows]
                else:
                    result = rows if rows else []
            return result
        except Exception as err:
            raise BaseError(706) from err

    def __contact_sql(self, table, criteria, join='AND'):
        """Build the SELECT statement for ``criteria``.

        Returns:
            ``(sql, whereValues)`` where ``whereValues`` is the list of values
            to bind, or ``None`` when the statement has no placeholders.
        """
        sql = 'SELECT '
        # Defined up-front: it used to be unbound for empty criteria.
        whereValues = None
        if criteria and isinstance(criteria, dict):
            # select fields
            if 'select' in criteria:
                fields = criteria['select'].split(',')
                sql += ','.join(
                    field.strip()[2:-2]
                    if '{{' == field.strip()[:2] and '}}' == field.strip()[-2:]
                    else field.strip()
                    for field in fields)
            else:
                sql += ' * '

            # table
            sql += ' FROM %s' % table

            # where
            if 'where' in criteria:
                if 'str' in str(type(criteria['where'])):
                    # a ready-made condition string
                    sql += ' WHERE ' + criteria['where']
                else:
                    # a dict of column conditions
                    params = criteria['where']
                    where, whereValues = self.__contact_where(params, join)
                    sql += ' WHERE ' + where if where else ''

            # group by
            if 'group' in criteria:
                sql += ' GROUP BY ' + criteria['group']

            # having
            if 'having' in criteria:
                sql += ' HAVING ' + criteria['having']

            if self.DataName == 'MYSQL':
                # order by
                if 'order' in criteria:
                    sql += ' ORDER BY ' + criteria['order']
                # limit
                if 'limit' in criteria:
                    sql += ' LIMIT ' + str(criteria['limit'])
                # offset
                if 'offset' in criteria:
                    sql += ' OFFSET ' + str(criteria['offset'])
            elif (self.DataName == 'POSTGRESQL'):
                # order by
                if 'order' in criteria:
                    sql += ' ORDER BY ' + criteria['order']
                if 'limit' in criteria:
                    strOffset, strRowcount = self.__parse_limit(
                        criteria['limit'])
                    sql += ' LIMIT %s OFFSET %s' % (strRowcount, strOffset)
            elif (self.DataName == 'ORACLE') and ('limit' in criteria):
                strOffset, strRowcount = self.__parse_limit(criteria['limit'])
                # order by
                if 'order' in criteria:
                    strOrder = criteria['order']
                else:
                    strOrder = 'ROWNUM'
                # Paging form chosen for efficiency on large Oracle tables.
                sql = "select * from(select * from(select t.*,row_number() over(order by %s) as rownumber from(%s) t) p where p.rownumber>%s) where rownum<=%s" % (
                    strOrder, sql, strOffset, strRowcount)
            elif (self.DataName == 'ORACLE') and ('order' in criteria):
                sql += ' ORDER BY ' + criteria['order']
        else:
            sql += ' * FROM %s' % table
        return sql, whereValues

    def __split_expression(self, data):
        """Separate SQL expressions from plain values.

        Returns:
            ``(funData, newData)``: ``funData`` maps columns to the expression
            text of ``'{{expr}}'`` values (braces removed); ``newData`` holds
            every other entry unchanged.
        """
        funData = {}
        newData = {}
        for (k, v) in data.items():
            if 'str' in str(type(v)) and '{{' == v[:2] and '}}' == v[-2:]:
                funData[k] = v[2:-2]
            else:
                newData[k] = v
        return (funData, newData)

    def __contact_fields(self, data):
        """Build the SET list of an UPDATE.

        Returns:
            ``(fields, values)``: ``fields`` is ``col=%s`` for plain values
            followed by ``col=expr`` for SQL expressions; ``values`` lists the
            plain values in placeholder order.
        """
        funData, newData = self.__split_expression(data)
        fields = ','.join(k + '=%s' for k in newData.keys())
        if funData:
            funFields = ','.join(k + '=%s' % (v) for k, v in funData.items())
            # merge plain fields and expression fields
            fields = ','.join([fields, funFields]) if fields else funFields
        values = list(newData.values())
        return (fields, values)

    def __hasKeyword(self, key):
        """Tell whether an expression already carries its own operator."""
        if '{{}}' in key:
            return True
        if 'in (' in key:
            return True
        if 'like ' in key:
            return True
        if '>' in key:
            return True
        if '<' in key:
            return True
        return False

    def __contact_where(self, params, join='AND'):
        """Build a WHERE condition (without the keyword) from ``params``.

        Returns:
            ``(where, values)``: the condition text with ``%s`` placeholders
            for plain values, and the list of those values.
        """
        glue = ' AND ' if join == 'AND' else ' OR '
        funParams, newParams = self.__split_expression(params)

        # plain-value conditions
        keys, _keys = self.__tParams(newParams)
        newWhere = glue.join(k + '=' + _k for k, _k in zip(keys, _keys))
        values = list(newParams.values())

        # SQL-expression conditions: no '=' when the expression brings its
        # own operator or when the column name is empty.
        funWhere = glue.join(
            k + (' ' if self.__hasKeyword(v) else '=' if k else '') + v
            for k, v in funParams.items())

        # final condition
        if funWhere:
            where = (newWhere + glue if newWhere else '') + funWhere
        else:
            where = newWhere
        return (where, values)

    def get_ids(self, list):
        """Extract ids from a ``getAllToList`` result as a comma-joined string.

        When the first element is itself indexable its first item is used for
        every element; otherwise the elements themselves are used.
        """
        try:
            list[0][0]
            nested = True
        except Exception:
            nested = False
        if nested:
            return ','.join(str(item[0]) for item in list)
        return ','.join(str(item) for item in list)
class Psycopg2Backend(Backend):
    """Backend for accessing data stored in a Postgres database.

    Connections are taken from a ``ThreadedConnectionPool`` that is created
    on first use. The pool is dropped when the object is pickled and is
    rebuilt when it is unpickled.
    """

    display_name = "PostgreSQL"
    # Class-level default; instances get their own pool in
    # _create_connection_pool().
    connection_pool = None
    auto_create_extensions = True

    def __init__(self, connection_params):
        """Create the backend and, if needed, its connection pool.

        Args:
            connection_params: passed to the parent constructor; the pool is
                later built from ``self.connection_params``.
        """
        super().__init__(connection_params)

        if self.connection_pool is None:
            self._create_connection_pool()

        if self.auto_create_extensions:
            self._create_extensions()

    def _create_connection_pool(self):
        """Build the pool (1 to 16 connections).

        Raises:
            BackendError: if the driver reports an error while connecting.
        """
        try:
            self.connection_pool = ThreadedConnectionPool(
                1, 16, **self.connection_params)
        except Error as ex:
            raise BackendError(str(ex)) from ex

    def _create_extensions(self):
        """Try to create every extension in EXTENSIONS; warn on failure."""
        for ext in EXTENSIONS:
            try:
                query = "CREATE EXTENSION IF NOT EXISTS {}".format(ext)
                with self.execute_sql_query(query):
                    pass
            except OperationalError:
                warnings.warn("Database is missing extension {}".format(ext))

    def create_sql_query(self, table_name, fields, filters=(),
                         group_by=None, order_by=None,
                         offset=None, limit=None,
                         use_time_sample=None):
        """Assemble a SELECT statement from already-quoted SQL fragments.

        Args:
            table_name: table expression placed after FROM.
            fields: iterable of column expressions.
            filters: iterable of conditions, joined with AND.
            group_by: iterable of GROUP BY expressions, or None.
            order_by: iterable of ORDER BY expressions, or None.
            offset: value for OFFSET, or None.
            limit: value for LIMIT, or None.
            use_time_sample: argument for ``TABLESAMPLE system_time``,
                or None.

        Returns:
            The query as a single string.
        """
        sql = ["SELECT", ', '.join(fields), "FROM", table_name]
        if use_time_sample is not None:
            sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
        if filters:
            sql.extend(["WHERE", " AND ".join(filters)])
        if group_by is not None:
            sql.extend(["GROUP BY", ", ".join(group_by)])
        if order_by is not None:
            sql.extend(["ORDER BY", ",".join(order_by)])
        if offset is not None:
            sql.extend(["OFFSET", str(offset)])
        if limit is not None:
            sql.extend(["LIMIT", str(limit)])
        return " ".join(sql)

    @contextmanager
    def execute_sql_query(self, query, params=None):
        """Execute ``query`` and yield the cursor holding its result.

        The connection is committed and handed back to the pool when the
        ``with`` block exits, whether or not an error occurred.

        Args:
            query: SQL text, optionally with driver placeholders.
            params: values for the placeholders, or None.

        Yields:
            The cursor on which the query was executed.
        """
        connection = self.connection_pool.getconn()
        try:
            # Everything after getconn() sits inside the try block so that
            # a failure in cursor() or mogrify() still releases the
            # connection.
            cur = connection.cursor()
            utfquery = cur.mogrify(query, params).decode('utf-8')
            log.debug("Executing: %s", utfquery)
            t = time()
            cur.execute(query, params)
            yield cur
            log.info("%.2f ms: %s", 1000 * (time() - t), utfquery)
        finally:
            # A failing commit must not prevent the connection from going
            # back to the pool.
            try:
                connection.commit()
            finally:
                self.connection_pool.putconn(connection)

    def quote_identifier(self, name):
        """Return ``name`` wrapped in double quotes."""
        return '"%s"' % name

    def unquote_identifier(self, quoted_name):
        """Strip the surrounding quotes added by quote_identifier()."""
        if quoted_name.startswith('"'):
            return quoted_name[1:-1]
        return quoted_name

    def list_tables_query(self, schema=None):
        """Return SQL that lists (schema, name) of the available relations.

        Args:
            schema: restrict the listing to this schema; when falsy, only
                relations visible on the search path are listed.
        """
        if schema:
            # Double embedded apostrophes so the schema name cannot
            # terminate the SQL string literal.
            schema_clause = "AND n.nspname = '{}'".format(
                schema.replace("'", "''"))
        else:
            schema_clause = "AND pg_catalog.pg_table_is_visible(c.oid)"
        return """SELECT n.nspname as "Schema",
                          c.relname AS "Name"
                       FROM pg_catalog.pg_class c
                  LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                      WHERE c.relkind IN ('r','v','m','S','f','')
                        AND n.nspname <> 'pg_catalog'
                        AND n.nspname <> 'information_schema'
                        AND n.nspname !~ '^pg_toast'
                        {}
                        AND NOT c.relname LIKE '\\_\\_%'
                   ORDER BY 1;""".format(schema_clause)

    def create_variable(self, field_name, field_metadata,
                        type_hints, inspect_table=None):
        """Return a variable for ``field_name`` with ``to_sql`` attached.

        The variable is taken from ``type_hints`` when the field is listed
        there, otherwise it is guessed from ``field_metadata``.
        """
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(field_name, field_metadata,
                                       inspect_table)

        field_name_q = self.quote_identifier(field_name)
        if var.is_continuous:
            if isinstance(var, TimeVariable):
                var.to_sql = ToSql("extract(epoch from {})"
                                   .format(field_name_q))
            else:
                var.to_sql = ToSql("({})::double precision"
                                   .format(field_name_q))
        else:  # discrete or string
            var.to_sql = ToSql("({})::text"
                               .format(field_name_q))
        return var

    def _guess_variable(self, field_name, field_metadata, inspect_table):
        """Choose a variable type from the column's type code.

        ``field_metadata[0]`` is compared against the numeric type codes
        below. Integer and character columns become discrete when
        ``inspect_table`` is given and ``get_distinct_values`` returns a
        non-empty result. Unrecognised codes fall back to a string variable.
        """
        type_code = field_metadata[0]

        FLOATISH_TYPES = (700, 701, 1700)  # real, float8, numeric
        INT_TYPES = (20, 21, 23)  # bigint, smallint, int
        CHAR_TYPES = (25, 1042, 1043,)  # text, char, varchar
        BOOLEAN_TYPES = (16,)  # bool
        DATE_TYPES = (1082, 1114, 1184, )  # date, timestamp, timestamptz
        # time, timestamp, timestamptz, timetz
        TIME_TYPES = (1083, 1114, 1184, 1266,)

        if type_code in FLOATISH_TYPES:
            return ContinuousVariable(field_name)

        if type_code in TIME_TYPES + DATE_TYPES:
            tv = TimeVariable(field_name)
            tv.have_date |= type_code in DATE_TYPES
            tv.have_time |= type_code in TIME_TYPES
            return tv

        if type_code in INT_TYPES:
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                if values:
                    return DiscreteVariable(field_name, values)
            return ContinuousVariable(field_name)

        if type_code in BOOLEAN_TYPES:
            return DiscreteVariable(field_name, ['false', 'true'])

        if type_code in CHAR_TYPES:
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                if values:
                    return DiscreteVariable(field_name, values)

        return StringVariable(field_name)

    def count_approx(self, query):
        """Return the row estimate reported by ``EXPLAIN`` for ``query``.

        The first ``rows=<n>`` figure in the plan text is used.
        """
        sql = "EXPLAIN " + query
        with self.execute_sql_query(sql) as cur:
            s = ''.join(row[0] for row in cur.fetchall())
        return int(re.findall(r'rows=(\d*)', s)[0])

    def __getstate__(self):
        # Drop connection_pool from state as it cannot be pickled
        state = dict(self.__dict__)
        state.pop('connection_pool', None)
        return state

    def __setstate__(self, state):
        # Create a new connection pool if none exists
        self.__dict__.update(state)
        if self.connection_pool is None:
            self._create_connection_pool()
class Psycopg2Backend(Backend):
    """Backend for accessing data stored in a Postgres database.

    Connections are taken from a ``ThreadedConnectionPool`` that is created
    on first use. The pool is dropped when the object is pickled and is
    rebuilt when it is unpickled. Driver errors raised while running a
    query are re-raised as ``BackendError``.
    """

    display_name = "PostgreSQL"
    # Class-level default; instances get their own pool in
    # _create_connection_pool().
    connection_pool = None
    auto_create_extensions = True

    def __init__(self, connection_params):
        """Create the backend and, if needed, its connection pool.

        Args:
            connection_params: passed to the parent constructor; the pool is
                later built from ``self.connection_params``.
        """
        super().__init__(connection_params)

        if self.connection_pool is None:
            self._create_connection_pool()

        # Names of the extensions that could not be created.
        self.missing_extension = []
        if self.auto_create_extensions:
            self._create_extensions()

    def _create_connection_pool(self):
        """Build the pool (1 to 16 connections).

        Raises:
            BackendError: if the driver reports an error while connecting.
        """
        try:
            self.connection_pool = ThreadedConnectionPool(
                1, 16, **self.connection_params)
        except Error as ex:
            raise BackendError(str(ex)) from ex

    def _create_extensions(self):
        """Try to create every extension in EXTENSIONS.

        Failures produce a warning and are recorded in
        ``self.missing_extension``.
        """
        for ext in EXTENSIONS:
            try:
                query = "CREATE EXTENSION IF NOT EXISTS {}".format(ext)
                with self.execute_sql_query(query):
                    pass
            except BackendError:
                warnings.warn("Database is missing extension {}".format(ext))
                self.missing_extension.append(ext)

    def create_sql_query(self, table_name, fields, filters=(),
                         group_by=None, order_by=None,
                         offset=None, limit=None,
                         use_time_sample=None):
        """Assemble a SELECT statement from already-quoted SQL fragments.

        Args:
            table_name: table expression placed after FROM.
            fields: iterable of column expressions.
            filters: iterable of conditions, joined with AND.
            group_by: iterable of GROUP BY expressions, or None.
            order_by: iterable of ORDER BY expressions, or None.
            offset: value for OFFSET, or None.
            limit: value for LIMIT, or None.
            use_time_sample: argument for ``TABLESAMPLE system_time``,
                or None.

        Returns:
            The query as a single string.
        """
        sql = ["SELECT", ', '.join(fields), "FROM", table_name]
        if use_time_sample is not None:
            sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
        if filters:
            sql.extend(["WHERE", " AND ".join(filters)])
        if group_by is not None:
            sql.extend(["GROUP BY", ", ".join(group_by)])
        if order_by is not None:
            sql.extend(["ORDER BY", ",".join(order_by)])
        if offset is not None:
            sql.extend(["OFFSET", str(offset)])
        if limit is not None:
            sql.extend(["LIMIT", str(limit)])
        return " ".join(sql)

    @contextmanager
    def execute_sql_query(self, query, params=None):
        """Execute ``query`` and yield the cursor holding its result.

        The connection is committed and handed back to the pool when the
        ``with`` block exits, whether or not an error occurred.

        Args:
            query: SQL text, optionally with driver placeholders.
            params: values for the placeholders, or None.

        Yields:
            The cursor on which the query was executed.

        Raises:
            BackendError: wrapping any driver ``Error``.
        """
        connection = self.connection_pool.getconn()
        try:
            # Everything after getconn() sits inside the try block so that
            # a failure in cursor() or mogrify() still releases the
            # connection.
            cur = connection.cursor()
            utfquery = cur.mogrify(query, params).decode('utf-8')
            log.debug("Executing: %s", utfquery)
            t = time()
            cur.execute(query, params)
            yield cur
            log.info("%.2f ms: %s", 1000 * (time() - t), utfquery)
        except (Error, ProgrammingError) as ex:
            raise BackendError(str(ex)) from ex
        finally:
            # A failing commit must not prevent the connection from going
            # back to the pool.
            try:
                connection.commit()
            finally:
                self.connection_pool.putconn(connection)

    def quote_identifier(self, name):
        """Return ``name`` wrapped in double quotes."""
        return '"%s"' % name

    def unquote_identifier(self, quoted_name):
        """Strip the surrounding quotes added by quote_identifier()."""
        if quoted_name.startswith('"'):
            return quoted_name[1:-1]
        return quoted_name

    def list_tables_query(self, schema=None):
        """Return SQL that lists (schema, name) of the available relations.

        Args:
            schema: restrict the listing to this schema; when falsy, only
                relations visible on the search path are listed.
        """
        if schema:
            # Double embedded apostrophes so the schema name cannot
            # terminate the SQL string literal.
            schema_clause = "AND n.nspname = '{}'".format(
                schema.replace("'", "''"))
        else:
            schema_clause = "AND pg_catalog.pg_table_is_visible(c.oid)"
        return """SELECT n.nspname as "Schema",
                          c.relname AS "Name"
                       FROM pg_catalog.pg_class c
                  LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                      WHERE c.relkind IN ('r','v','m','S','f','')
                        AND n.nspname <> 'pg_catalog'
                        AND n.nspname <> 'information_schema'
                        AND n.nspname !~ '^pg_toast'
                        {}
                        AND NOT c.relname LIKE '\\_\\_%'
                   ORDER BY 1;""".format(schema_clause)

    def create_variable(self, field_name, field_metadata,
                        type_hints, inspect_table=None):
        """Return a variable for ``field_name`` with ``to_sql`` attached.

        The variable is taken from ``type_hints`` when the field is listed
        there, otherwise it is guessed from ``field_metadata``.
        """
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(field_name, field_metadata,
                                       inspect_table)

        field_name_q = self.quote_identifier(field_name)
        if var.is_continuous:
            if isinstance(var, TimeVariable):
                var.to_sql = ToSql("extract(epoch from {})"
                                   .format(field_name_q))
            else:
                var.to_sql = ToSql("({})::double precision"
                                   .format(field_name_q))
        else:  # discrete or string
            var.to_sql = ToSql("({})::text"
                               .format(field_name_q))
        return var

    def _guess_variable(self, field_name, field_metadata, inspect_table):
        """Choose a variable type from the column's type code.

        ``field_metadata[0]`` is compared against the numeric type codes
        below. Integer and character columns become discrete when
        ``inspect_table`` is given and ``get_distinct_values`` returns a
        non-empty result. Unrecognised codes fall back to a string variable.
        """
        type_code = field_metadata[0]

        FLOATISH_TYPES = (700, 701, 1700)  # real, float8, numeric
        INT_TYPES = (20, 21, 23)  # bigint, smallint, int
        CHAR_TYPES = (25, 1042, 1043,)  # text, char, varchar
        BOOLEAN_TYPES = (16,)  # bool
        DATE_TYPES = (1082, 1114, 1184, )  # date, timestamp, timestamptz
        # time, timestamp, timestamptz, timetz
        TIME_TYPES = (1083, 1114, 1184, 1266,)

        if type_code in FLOATISH_TYPES:
            return ContinuousVariable.make(field_name)

        if type_code in TIME_TYPES + DATE_TYPES:
            tv = TimeVariable.make(field_name)
            tv.have_date |= type_code in DATE_TYPES
            tv.have_time |= type_code in TIME_TYPES
            return tv

        if type_code in INT_TYPES:
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                if values:
                    return DiscreteVariable.make(field_name, values)
            return ContinuousVariable.make(field_name)

        if type_code in BOOLEAN_TYPES:
            return DiscreteVariable.make(field_name, ['false', 'true'])

        if type_code in CHAR_TYPES:
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                # remove trailing spaces
                values = [v.rstrip() for v in values]
                if values:
                    return DiscreteVariable.make(field_name, values)

        return StringVariable.make(field_name)

    def count_approx(self, query):
        """Return the row estimate reported by ``EXPLAIN`` for ``query``.

        The first ``rows=<n>`` figure in the plan text is used.
        """
        sql = "EXPLAIN " + query
        with self.execute_sql_query(sql) as cur:
            s = ''.join(row[0] for row in cur.fetchall())
        return int(re.findall(r'rows=(\d*)', s)[0])

    def __getstate__(self):
        # Drop connection_pool from state as it cannot be pickled
        state = dict(self.__dict__)
        state.pop('connection_pool', None)
        return state

    def __setstate__(self, state):
        # Create a new connection pool if none exists
        self.__dict__.update(state)
        if self.connection_pool is None:
            self._create_connection_pool()
class Database():
    """Thin wrapper around a pooled database connection.

    Every statement is attempted once and, if it fails, retried once on a
    second pooled connection. Both connections are always handed back to
    the pool.
    """

    def __init__(self, config, verbose_start=False):
        """Read the connection settings from ``config`` and open the pool.

        Args:
            config: object exposing ``getProperty(section, key, ...)``.
            verbose_start: when true, print progress and status messages.

        Raises:
            Exception: if a required setting cannot be read or the pool
                cannot be created; the original error is chained as the
                cause.
        """
        self.config = config

        if verbose_start:
            print('Connecting to database...'.ljust(5), end='')

        # get DB parameters
        try:
            self.database = config.getProperty('Database', 'name').lower()
            self.host = config.getProperty('Database', 'host')
            self.port = config.getProperty('Database', 'port')
            self.user = config.getProperty('Database', 'user').lower()
            self.password = config.getProperty('Database', 'password')
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(f'Incomplete database credentials provided in configuration file (message: "{str(e)}").') from e

        try:
            self._createConnectionPool()
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(f'Could not connect to database (message: "{str(e)}").') from e

        if verbose_start:
            LogDecorator.print_status('ok')

    def _createConnectionPool(self):
        """Create the pool; its upper bound comes from the configuration."""
        self.connectionPool = ThreadedConnectionPool(
            1,
            self.config.getProperty('Database', 'max_num_connections', type=int, fallback=20),
            host=self.host,
            database=self.database,
            port=self.port,
            user=self.user,
            password=self.password,
            connect_timeout=2
        )

    def runServer(self):
        ''' Dummy function for compatibility reasons '''
        return

    @contextmanager
    def _get_connection(self):
        """Yield a pooled connection in autocommit mode and return it after."""
        conn = self.connectionPool.getconn()
        conn.autocommit = True
        try:
            yield conn
        finally:
            self.connectionPool.putconn(conn, close=False)

    @staticmethod
    def _fetch_results(cursor, numReturn):
        """Read up to ``numReturn`` rows from ``cursor``.

        Returns None when ``numReturn`` is None or when fetching fails (the
        error is printed), all rows for ``'all'``, otherwise a list with at
        most ``numReturn`` rows.
        """
        try:
            if numReturn is None:
                return None
            if numReturn == 'all':
                return cursor.fetchall()
            returnValues = []
            for _ in range(numReturn):
                rv = cursor.fetchone()
                if rv is None:
                    break
                returnValues.append(rv)
            return returnValues
        except Exception as e:
            print(e)
            return None

    def execute(self, query, arguments, numReturn=None):
        """Run ``query`` and optionally return rows.

        Args:
            query: SQL text.
            arguments: parameters passed to ``cursor.execute``.
            numReturn: None (return nothing), ``'all'`` or a row count.

        Returns:
            None, or a list of rows as produced by the cursor.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            try:
                cursor.execute(query, arguments)
                conn.commit()
            except Exception as e:
                if not conn.closed:
                    conn.rollback()

                # Retry on a second connection while the first is still
                # checked out, so the pool cannot hand the same one back.
                # The context manager guarantees the retry connection is
                # returned to the pool as well.
                with self._get_connection() as retry_conn:
                    try:
                        cursor = retry_conn.cursor(cursor_factory=RealDictCursor)
                        cursor.execute(query, arguments)
                        retry_conn.commit()
                    except Exception:
                        if not retry_conn.closed:
                            retry_conn.rollback()
                        print(e)
                    # Fetch before the retry connection goes back to the pool.
                    return self._fetch_results(cursor, numReturn)

            return self._fetch_results(cursor, numReturn)

    def execute_cursor(self, query, arguments):
        """Run ``query`` and return the cursor instead of the rows.

        Returns:
            The executed cursor, or None if both the first attempt and the
            retry failed (the retry error is printed).
        """
        # NOTE(review): the cursor is returned after its connection went back
        # to the pool; confirm callers consume it before the pool reuses or
        # closes that connection.
        with self._get_connection() as conn:
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            try:
                cursor.execute(query, arguments)
                conn.commit()
                return cursor
            except Exception:
                if not conn.closed:
                    conn.rollback()

                # retry execution on a second, properly released connection
                with self._get_connection() as retry_conn:
                    try:
                        cursor = retry_conn.cursor(cursor_factory=RealDictCursor)
                        cursor.execute(query, arguments)
                        retry_conn.commit()
                        return cursor
                    except Exception as e:
                        if not retry_conn.closed:
                            retry_conn.rollback()
                        print(e)
                        return None

    def insert(self, query, values):
        """Bulk-insert ``values`` with ``execute_values``, retrying once.

        Args:
            query: INSERT statement understood by ``execute_values``.
            values: sequence of rows to insert.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            try:
                execute_values(cursor, query, values)
                conn.commit()
            except Exception:
                if not conn.closed:
                    conn.rollback()

                # retry execution on a second, properly released connection
                with self._get_connection() as retry_conn:
                    try:
                        cursor = retry_conn.cursor(cursor_factory=RealDictCursor)
                        execute_values(cursor, query, values)
                        retry_conn.commit()
                    except Exception as e:
                        if not retry_conn.closed:
                            retry_conn.rollback()
                        print(e)
rows = cur.fetchall() # all rows in table if len(rows) < 1: return render_template('main.html', model=json.dumps(active_orders)) phone = rows[0][0] return render_template('main.html', model=json.dumps(active_orders), phone=phone) except Exception, e: print Exception, ":", e return render_template('main.html', model="") finally: cur.close() pg_pool.putconn(conn) # return 'Hello!' + user_id # API @app.route('/login', methods=['POST']) def login(): phone = request.form['phone'].encode(encoding='utf-8') password_user = request.form['password'].encode(encoding='utf-8') try: conn = pg_pool.getconn() if conn is None: return '["false","db_connect_error"]'
VALUES (%s, %s, %s, %s) ''' def save_review(conn, user_id, data): try: cur = conn.cursor() cur.execute(insert_review_sql, (user_id, data['book_id'], data['rating'], data['review'])) conn.commit() except Exception as err: print('Error Saving review:', err, user_id) if __name__ == "__main__": global tcp params = config() tcp = ThreadedConnectionPool(1, 10, **params) conn = tcp.getconn() review_info = [{ 'book_id': 24, 'user_id': 'AG3HZ5KORQIXZHUBLNMZNU35SCEA', 'name': '123', 'rating': 4.0, 'review': "This" }] save_user(conn, review_info) tcp.putconn(conn)
class Database(DatabaseInterface):
    """PostgreSQL database handle backed by a threaded connection pool.

    Instances are cached per database name in ``_databases``, so creating
    ``Database(name)`` twice yields the same object.
    """

    _databases = {}
    _connpool = None
    _list_cache = None
    _list_cache_timestamp = None
    _version_cache = {}
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        if name in cls._databases:
            return cls._databases[name]
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name='template1'):
        super(Database, self).__init__(name=name)
        self._databases.setdefault(name, self)
        self._search_path = None
        self._current_user = None

    @classmethod
    def dsn(cls, name):
        """Build the connection string for database ``name``.

        Host, port, user and password come from the configured
        ``database.uri``; parts missing from the URI are left empty.
        """
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = "host=%s" % uri.hostname if uri.hostname else ''
        port = "port=%s" % uri.port if uri.port else ''
        name = "dbname=%s" % name
        user = "user=%s" % uri.username if uri.username else ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
            if uri.password else '')
        return '%s %s %s %s %s' % (host, port, name, user, password)

    def connect(self):
        """Create the connection pool if needed and return ``self``."""
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.name)
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        self._connpool = ThreadedConnectionPool(minconn, maxconn,
            self.dsn(self.name))
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Take a connection from the pool.

        While the pool is exhausted the call sleeps one second and tries
        again, up to the configured ``database.retry`` count, then re-raises
        the ``PoolError``.

        Args:
            autocommit: use autocommit isolation instead of repeatable read.
            readonly: issue ``SET TRANSACTION READ ONLY`` on the connection.
        """
        if self._connpool is None:
            self.connect()
        for count in range(config.getint('database', 'retry'), -1, -1):
            try:
                conn = self._connpool.getconn()
                break
            except PoolError:
                if count and not self._connpool.closed:
                    logger.info('waiting a connection')
                    time.sleep(1)
                    continue
                raise
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        if readonly:
            cursor = conn.cursor()
            cursor.execute('SET TRANSACTION READ ONLY')
        return conn

    def put_connection(self, connection, close=False):
        """Return ``connection`` to the pool, optionally closing it."""
        self._connpool.putconn(connection, close=close)

    def close(self):
        """Close every pooled connection and forget the pool."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, connection, database_name):
        """Create ``database_name`` from template0 and reset the list cache."""
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
            'TEMPLATE template0 ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        """Drop ``database_name`` and reset the list cache."""
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, connection):
        """Return the server version as a tuple of ints (cached per name)."""
        if self.name not in self._version_cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.name] = tuple(
                map(int, RE_VERSION.search(version).groups()))
        return self._version_cache[self.name]

    @staticmethod
    def dump(database_name):
        """Return a custom-format ``pg_dump`` of ``database_name``.

        Raises:
            Exception: if ``pg_dump`` exits with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create ``database_name`` and load the dump ``data`` into it.

        Returns:
            True on success.

        Raises:
            Exception: if ``pg_restore`` fails or the restored database does
                not pass ``test()``.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        connection = database.get_connection(autocommit=True)
        database.create(connection, database_name)
        database.close()

        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            # On Windows the dump is handed over as a temporary file
            # instead of being written to the child's stdin.
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_command_pipe(*args2, env=env)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.get_connection().cursor()
        if not database.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    def list(self):
        """Return the names of the databases that pass ``_test()``.

        The result is cached for ``session.timeout`` seconds. Databases that
        cannot be connected to or probed are skipped.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    conn = connect(self.dsn(db_name))
                    try:
                        if self._test(conn):
                            res.append(db_name)
                    finally:
                        # "with connection" only ends the transaction; the
                        # probe connection has to be closed explicitly.
                        conn.close()
                except Exception:
                    continue
        finally:
            self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        """Run ``init.sql`` and register the 'ir' and 'res' modules."""
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if (len(line) > 0) and (not line.isspace()):
                        cursor.execute(line)

            for module in ('ir', 'res'):
                state = 'uninstalled'
                if module in ('ir', 'res'):
                    state = 'to install'

                info = get_module_info(module)
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute(
                    'INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute(
                        'INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            # Hand the connection back even when a statement fails.
            self.put_connection(connection)

    def test(self):
        """Return True if this database contains any of the core tables."""
        connection = self.get_connection()
        try:
            return self._test(connection)
        finally:
            self.put_connection(connection)

    @classmethod
    def _test(cls, connection):
        """Check ``information_schema`` for any of the core table names."""
        cursor = connection.cursor()
        cursor.execute(
            'SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
                    'res_user', 'res_group', 'ir_module',
                    'ir_module_dependency', 'ir_translation',
                    'ir_lang'), ))
        return len(cursor.fetchall()) != 0

    def nextid(self, connection, table):
        """Advance and return the id sequence of ``table``."""
        cursor = connection.cursor()
        cursor.execute("SELECT NEXTVAL('" + table + "_id_seq')")
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        """Set the id sequence of ``table`` to ``value``."""
        cursor = connection.cursor()
        cursor.execute("SELECT SETVAL('" + table + "_id_seq', %d)" % value)

    def currid(self, connection, table):
        """Return the last value of the id sequence of ``table``."""
        cursor = connection.cursor()
        cursor.execute('SELECT last_value FROM "' + table + '_id_seq"')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        """Take an exclusive lock on ``table`` without waiting."""
        cursor = connection.cursor()
        cursor.execute('LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table)

    def has_constraint(self):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema on the search path containing the table.

        Returns None when no schema on the search path has ``table_name``.
        """
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute(
                'SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema

    @property
    def current_user(self):
        """Name reported by ``SELECT current_user`` (queried once)."""
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """Schemas from ``SHOW search_path`` as a list (queried once)."""
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(
                        replace_special_values(p.strip(), **special_values))
                    for p in path.split(',')
                ]
            finally:
                self.put_connection(connection)
        return self._search_path
class Database(object): def __init__(self, db_connection, connection_number, app, min_db_connections=1): self.db_connection = db_connection self.min_connections = min_db_connections self.connection_number = connection_number self.pool = ThreadedConnectionPool(min_db_connections, connection_number, db_connection) self.app = app self.last_seen_process_id = os.getpid() self.needs_change = True def dispose(self): self.pool.closeall() def get_connection(self): if self.needs_change: current_pid = os.getpid() if not (current_pid == self.last_seen_process_id): self.last_seen_process_id = current_pid self.pool.closeall() self.pool = ThreadedConnectionPool(self.min_connections, self.connection_number, self.db_connection) self.needs_change = False return self.pool.getconn() def run_script(self, script): from app import app success = True conn = self.get_connection() try: cursor = conn.cursor() cursor.execute(script) conn.commit() self.pool.putconn(conn) except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(script)) conn.rollback() self.pool.putconn(conn) success = False return success def add_tracks(self, dicts, table): conn = self.get_connection() cursor = conn.cursor() for track in tracks: columns = track.keys() values = [track[column] for column in columns] sql = "INSERT INTO tracks (%s) VALUES %s" cursor.execute(sql, (AsIs(','.join(columns)), tuple(values))) conn.commit() main_database.Database.run_script def execute_update(self, sql, vars=None): conn = self.get_connection() success = True try: cursor = conn.cursor() cursor.execute(sql, vars) conn.commit() self.pool.putconn(conn) except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) conn.rollback() self.pool.putconn(conn) success = False return success def value_exists(self, table, field, value): exists = False conn = self.get_connection() try: cursor = conn.cursor() sql = "SELECT * FROM %s WHERE %s = %s " vars = (AsIs(table), AsIs(field), value) 
cursor.execute(sql, vars) results = cursor.fetchall() self.pool.putconn(conn) if len(results) > 0: exists = True except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) conn.rollback() self.pool.putconn(conn) return exists def delete_table(self, table): '''Deletes the table from the database table: Thr name of the table to delete Returns: True - if the table was deleted, otherwise False. ''' conn = self.get_connection() sql = "DROP TABLE public.{}".format(table) try: cursor = conn.cursor() cursor.execute(sql) conn.commit() self.pool.putconn(conn) return True except Exception as e: self.app.logger.exception( "The table {} could not be deleted using the SQL:\n{}".format( table, sql)) conn.rollback() self.pool.putconn(conn) return False def remove_columns(self, table, columns): '''Deletes columns from the specified table Args: columns(list[str]): A list of column names to delete Returns: True - if the columns were deleted, otherwise False ''' conn = self.get_connection() results = [] sql = "" try: cursor = conn.cursor() for col in columns: sql = "ALTER TABLE public.{} DROP COLUMN {}".format(table, col) cursor.execute(sql) conn.commit() self.pool.putconn(conn) return True except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) conn.rollback() self.pool.putconn(conn) return False def add_columns(self, table, columns): '''Adds columns to to the specified table Args: columns: A list of dictionaries with 'name' and 'datatype' (text,integer or double) and optionally default (specifies default value). 
There is no check on the column name, so it has to be compatible with postgresql Returns: True - if the columns were added, otherwise False ''' conn = self.get_connection() results = [] sql = "" try: cursor = conn.cursor() for col in columns: if col['datatype'] == "double": col['datatype'] = "double precision" sql = "ALTER TABLE public.{} ADD COLUMN {} {}".format( table, col['name'], col['datatype']) if col.get("default"): val = col.get("default") if col['datatype'] == "text": val = "'" + val + "'" sql += " default " + val cursor.execute(sql) conn.commit() self.pool.putconn(conn) return True except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) conn.rollback() self.pool.putconn(conn) return False def execute_query(self, sql, vars=None): conn = self.get_connection() results = [] try: cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) cursor.execute(sql, vars) results = cursor.fetchall() self.pool.putconn(conn) except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) self.pool.putconn(conn) return results def get_sql(self, sql, vars=None): conn = self.get_connection() try: cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) result = cursor.mogrify(sql, vars) self.pool.putconn(conn) return result except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) self.pool.putconn(conn) def delete_by_id(self, table, ids): conn = self.get_connection() try: cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) sql = "DELETE FROM %s WHERE id = ANY (%s)" cursor.execute(sql, (AsIs(table), ids)) conn.commit() self.pool.putconn(conn) except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) conn.rollback() self.pool.putconn(conn) def execute_delete(self, sql, vars=None): conn = self.get_connection() try: cursor = conn.cursor() cursor.execute(sql, vars) conn.commit() 
self.pool.putconn(conn) except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) conn.rollback() self.pool.putconn(conn) return False return True def execute_insert(self, sql, vars=None, ret_id=True): if ret_id: sql += " RETURNING id" conn = self.get_connection() new_id = -1 try: cursor = conn.cursor() cursor.execute(sql, vars) conn.commit() if ret_id: new_id = cursor.fetchone()[0] self.pool.putconn(conn) except Exception as e: self.app.logger.exception( "The SQL could not be run:\n{}".format(sql)) conn.rollback() self.pool.putconn(conn) return new_id def get_tracks(self, track_ids=[], field="track_id", proxy=True): '''Get track info for all the supplies track ids Args: track_ids(list): A list of track_ids (or ids) to retreive field(Optional[str]): The field to query the databse (track_id by default) proxy(Optional[boolean]) If true (default) the the track's proxy url as defined in the config will be returned ''' from app import app conn = self.get_connection() cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) sql = "SELECT id,track_id,url,type,short_label,color FROM tracks WHERE %s = ANY(%s)" cursor.execute(sql, (AsIs(field), track_ids)) results = cursor.fetchall() self.pool.putconn(conn) if not results: return [] t_p = app.config.get("TRACK_PROXIES") if t_p and proxy: for item in results: for p in t_p: item['url'] = item['url'].replace(p, t_p[p]) return results def insert_dicts_into_table(self, dicts, table): conn = self.get_connection() try: cursor = conn.cursor() for item in dicts: sql = "INSERT INTO {} (%s) VALUES %s".format(table) columns = item.keys() values = [item[column] for column in columns] cursor.execute(sql, (AsIs(','.join(columns)), tuple(values))) conn.commit() cursor.close() self.pool.putconn(conn) except Exception as e: conn.rollback() self.pool.putconn(conn) raise def update_table_with_dicts(self, dicts, table): '''Updates the table with with the values in the supplied dictionaries (the 
keys being the column names) Args: dicts (list[dict]): A list of dictionaries containing key(column name) to value (new value). Each dictionary should also contain an id key with the id of the row. Returns: True if the update was successful, otherwise False. ''' conn = self.get_connection() sql = "" complete = True try: cursor = conn.cursor() for item in dicts: id = item['id'] del item['id'] sql = "UPDATE {} SET (%s) = %s WHERE id= %s".format(table) columns = item.keys() values = [item[column] for column in columns] cursor.execute(sql, (AsIs(','.join(columns)), tuple(values), id)) conn.commit() cursor.close() except: self.app.logger.exception("update could not be run") conn.rollback() complete = False self.pool.putconn(conn) return complete def insert_dict_into_table(self, dic, table): conn = self.get_connection() try: cursor = conn.cursor() sql = "INSERT INTO {} (%s) VALUES %s RETURNING id".format(table) columns = dic.keys() values = [dic[column] for column in columns] cursor.execute(sql, (AsIs(','.join(columns)), tuple(values))) conn.commit() new_id = cursor.fetchone()[0] cursor.close() self.pool.putconn(conn) return new_id except: self.app.logger.exception("update could not be run") self.pool.putconn(conn) return -1
class Database():
    """Wrapper around a psycopg2 ``ThreadedConnectionPool``.

    Connection parameters are read from the ``Database`` section of the
    supplied configuration object. ``execute`` and ``insert`` run a
    statement and retry it once on a second connection if the first
    attempt fails; errors of the retry are printed, not raised.
    """

    def __init__(self, config, verbose_start=False):
        """Read the connection parameters and create the connection pool.

        Args:
            config: configuration object offering
                ``getProperty(section, name, type=..., fallback=...)``.
            verbose_start: if True, progress and status are printed.

        Raises:
            Exception: if the credentials are incomplete or the pool
                could not be created.
        """
        self.config = config

        if verbose_start:
            print('Connecting to database...'.ljust(
                LogDecorator.get_ljust_offset()), end='')

        # get DB parameters
        try:
            self.database = config.getProperty('Database', 'name').lower()
            self.host = config.getProperty('Database', 'host')
            self.port = config.getProperty('Database', 'port')
            self.user = config.getProperty('Database', 'user',
                                           fallback=None)
            self.password = config.getProperty('Database', 'password',
                                               fallback=None)

            if self.user is None or self.password is None:
                # load from credentials file instead
                credentials = config.getProperty('Database', 'credentials')
                with open(credentials, 'r') as c:
                    for line in c:
                        line = line.lstrip().rstrip('\r\n')
                        if line.startswith(('#', ';')):
                            continue
                        # split at the first '=' only, so that values may
                        # themselves contain '='
                        key, separator, value = line.partition('=')
                        if not separator:
                            continue
                        field = key.strip().lower()
                        if field == 'username':
                            self.user = value
                        elif field == 'password':
                            self.password = value

            self.user = self.user.lower()
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(
                f'Incomplete database credentials provided in configuration file (message: "{str(e)}").'
            ) from e

        try:
            self._createConnectionPool()
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(
                f'Could not connect to database (message: "{str(e)}").'
            ) from e

        if verbose_start:
            LogDecorator.print_status('ok')

    def _createConnectionPool(self):
        """Create the pool from the parameters read in ``__init__``."""
        self.connectionPool = ThreadedConnectionPool(
            0,
            max(
                2,
                self.config.getProperty('Database',
                                        'max_num_connections',
                                        type=int,
                                        fallback=20)
            ),  # 2 connections are needed as minimum for retrying of execution
            host=self.host,
            database=self.database,
            port=self.port,
            user=self.user,
            password=self.password,
            connect_timeout=10)

    def canConnect(self):
        """Return True if an open connection can be obtained from the pool.

        Errors raised while obtaining the connection are propagated.
        """
        with self._get_connection() as conn:
            return conn is not None and not conn.closed

    @contextmanager
    def _get_connection(self):
        """Yield a pooled connection in autocommit mode; always return it."""
        conn = self.connectionPool.getconn()
        try:
            # inside the try so the connection is handed back even if
            # switching to autocommit fails
            conn.autocommit = True
            yield conn
        finally:
            self.connectionPool.putconn(conn, close=False)

    @staticmethod
    def _fetchResults(cursor, numReturn):
        """Fetch rows from ``cursor`` according to ``numReturn``.

        Returns None if ``numReturn`` is None or fetching fails (the
        error is printed), all rows for ``'all'``, otherwise a list of at
        most ``numReturn`` rows.
        """
        try:
            if numReturn is None:
                return None
            if numReturn == 'all':
                return cursor.fetchall()
            returnValues = []
            for _ in range(numReturn):
                rv = cursor.fetchone()
                if rv is None:
                    break
                returnValues.append(rv)
            return returnValues
        except Exception as e:
            print(e)
            return None

    def _executeWithRetry(self, operation, numReturn, cursorFactory=None):
        """Run ``operation(cursor)``, retrying once on a second connection.

        Both connections are obtained through ``_get_connection`` and are
        therefore returned to the pool. The same cursor type is used for
        both attempts so the row type does not depend on whether a retry
        happened.
        """
        def attempt(conn):
            # returns the cursor and the error of this attempt (or None)
            if cursorFactory is None:
                cursor = conn.cursor()
            else:
                cursor = conn.cursor(cursor_factory=cursorFactory)
            try:
                operation(cursor)
                conn.commit()
            except Exception as e:
                if not conn.closed:
                    conn.rollback()
                return cursor, e
            return cursor, None

        with self._get_connection() as conn:
            cursor, error = attempt(conn)
            if error is None:
                return self._fetchResults(cursor, numReturn)

            # retry execution; the first connection stays checked out so
            # that the retry is guaranteed to run on a different one
            with self._get_connection() as retryConn:
                cursor, error = attempt(retryConn)
                if error is not None:
                    print(error)
                # fetch while the connection is still checked out
                return self._fetchResults(cursor, numReturn)

    def execute(self, query, arguments, numReturn=None):
        """Execute ``query`` with ``arguments`` using a ``RealDictCursor``.

        Args:
            query: SQL statement.
            arguments: parameters passed to ``cursor.execute``.
            numReturn: None (return nothing), ``'all'`` (return every
                row) or the maximum number of rows to return.

        Returns:
            None, or a list of rows; None as well if fetching failed.
        """
        return self._executeWithRetry(
            lambda cursor: cursor.execute(query, arguments),
            numReturn,
            RealDictCursor)

    def insert(self, query, values, numReturn=None):
        """Insert ``values`` with ``execute_values`` using a default cursor.

        Args:
            query: SQL statement handed to ``execute_values``.
            values: sequence of rows to insert.
            numReturn: None (return nothing), ``'all'`` (return every
                row) or the maximum number of rows to return.

        Returns:
            None, or a list of rows; None as well if fetching failed.
        """
        return self._executeWithRetry(
            lambda cursor: execute_values(cursor, query, values),
            numReturn)
class Database(DatabaseInterface):
    """PostgreSQL database backed by a psycopg2 threaded connection pool.

    Instances are cached per database name: creating a ``Database`` for
    a name that already exists returns the existing instance.
    """

    _databases = {}
    _connpool = None
    _list_cache = None
    _list_cache_timestamp = None
    _version_cache = {}
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        if name in cls._databases:
            return cls._databases[name]
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name='template1'):
        super(Database, self).__init__(name=name)
        self._databases.setdefault(name, self)
        self._search_path = None
        self._current_user = None

    @classmethod
    def dsn(cls, name):
        """Return the libpq connection string for the database `name`,
        built from the configured database URI."""
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = "host=%s" % uri.hostname if uri.hostname else ''
        port = "port=%s" % uri.port if uri.port else ''
        name = "dbname=%s" % name
        user = "user=%s" % uri.username if uri.username else ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
            if uri.password else '')
        return '%s %s %s %s %s' % (host, port, name, user, password)

    def connect(self):
        """Create the connection pool if needed and return self."""
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.name)
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        self._connpool = ThreadedConnectionPool(
            minconn, maxconn, self.dsn(self.name))
        return self

    def get_connection(self, autocommit=False, readonly=False):
        """Return a pooled connection set up as requested.

        The isolation level is autocommit if `autocommit` is true and
        repeatable read otherwise; `readonly` makes the transaction read
        only. The caller must give the connection back with
        `put_connection`.
        """
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        try:
            if autocommit:
                conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
            else:
                conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
            if readonly:
                cursor = conn.cursor()
                cursor.execute('SET TRANSACTION READ ONLY')
            conn.cursor_factory = PerfCursor
        except Exception:
            # Do not leak a connection that could not be set up.
            self._connpool.putconn(conn, close=True)
            raise
        return conn

    def put_connection(self, connection, close=False):
        """Give `connection` back to the pool."""
        self._connpool.putconn(connection, close=close)

    def close(self):
        """Close every connection of the pool and drop the pool."""
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, connection, database_name):
        """Create the database `database_name` and reset the list cache."""
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
            'TEMPLATE template0 ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        """Drop the database `database_name` and reset the list cache."""
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, connection):
        """Return the server version as a tuple of integers (cached per
        database name)."""
        if self.name not in self._version_cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.name] = tuple(map(int,
                RE_VERSION.search(version).groups()))
        return self._version_cache[self.name]

    @staticmethod
    def _pg_tool_options():
        """Return the connection options and the environment shared by
        the pg_dump and pg_restore command lines."""
        options = []
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            options.append('--username=' + uri.username)
        if uri.hostname:
            options.append('--host=' + uri.hostname)
        if uri.port:
            options.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        return options, env

    @staticmethod
    def dump(database_name):
        """Return the pg_dump (custom format) output of `database_name`.

        Raises an Exception if pg_dump exits with a non-zero status.
        """
        from trytond.tools import exec_command_pipe

        options, env = Database._pg_tool_options()
        cmd = ['pg_dump', '--format=c', '--no-owner'] + options
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        """Create `database_name` and load the dump `data` into it.

        Returns True on success and raises an Exception if pg_restore
        fails or the restored database does not pass `test`.
        """
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        try:
            connection = database.get_connection(autocommit=True)
            database.create(connection, database_name)
        finally:
            # Closing the pool also releases the connection above.
            database.close()

        options, env = Database._pg_tool_options()
        cmd = ['pg_restore', '--no-owner'] + options
        cmd.append('--dbname=' + database_name)

        if os.name == "nt":
            # On Windows the dump is handed over as a file, not on stdin.
            # NOTE(review): os.tmpnam() only exists on Python 2.
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            cmd.append(' ' + tmpfile)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        if os.name != "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        try:
            if not database.test():
                raise Exception('Couldn\'t restore database!')
        finally:
            database.close()
        Database._list_cache = None
        return True

    def list(self):
        """Return the names of the databases that pass `_test`.

        The result is cached for the configured session timeout.
        """
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            cursor.execute('SELECT datname FROM pg_database '
                'WHERE datistemplate = false ORDER BY datname')
            res = []
            for db_name, in cursor:
                try:
                    conn = connect(self.dsn(db_name))
                    try:
                        if self._test(conn):
                            res.append(db_name)
                    finally:
                        # Close explicitly: leaving a `with` block on a
                        # psycopg2 connection only ends the transaction.
                        conn.close()
                except Exception:
                    continue
        finally:
            self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        """Run init.sql and register the 'ir' and 'res' modules."""
        from trytond.modules import get_module_info

        connection = self.get_connection()
        try:
            cursor = connection.cursor()
            sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
            with open(sql_file) as fp:
                for line in fp.read().split(';'):
                    if line.strip():
                        cursor.execute(line)

            # Both base modules are registered as 'to install'.
            state = 'to install'
            for module in ('ir', 'res'):
                info = get_module_info(module)
                cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
                module_id = cursor.fetchone()[0]
                cursor.execute('INSERT INTO ir_module '
                    '(id, create_uid, create_date, name, state) '
                    'VALUES (%s, %s, now(), %s, %s)',
                    (module_id, 0, module, state))
                for dependency in info.get('depends', []):
                    cursor.execute('INSERT INTO ir_module_dependency '
                        '(create_uid, create_date, module, name) '
                        'VALUES (%s, now(), %s, %s)',
                        (0, module_id, dependency))

            connection.commit()
        finally:
            self.put_connection(connection)

    def test(self):
        """Return True if this database contains Tryton tables."""
        connection = self.get_connection()
        try:
            return self._test(connection)
        finally:
            self.put_connection(connection)

    @classmethod
    def _test(cls, connection):
        """Return True if at least one of the core tables exists."""
        cursor = connection.cursor()
        cursor.execute('SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
                    'res_user', 'res_group', 'ir_module',
                    'ir_module_dependency', 'ir_translation',
                    'ir_lang'),))
        return len(cursor.fetchall()) != 0

    # NOTE(review): the helpers below splice the table name into the
    # statement; it must never come from untrusted input.

    def nextid(self, connection, table):
        """Return the next value of the id sequence of `table`."""
        cursor = connection.cursor()
        cursor.execute("SELECT NEXTVAL('" + table + "_id_seq')")
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        """Set the id sequence of `table` to `value`."""
        cursor = connection.cursor()
        cursor.execute("SELECT SETVAL('" + table + "_id_seq', %d)" % value)

    def currid(self, connection, table):
        """Return the last value of the id sequence of `table`."""
        cursor = connection.cursor()
        cursor.execute('SELECT last_value FROM "' + table + '_id_seq"')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        """Lock `table` in exclusive mode without waiting."""
        cursor = connection.cursor()
        cursor.execute('LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table)

    def has_constraint(self):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        """Return the first schema of the search path that contains
        `table_name`, or None if there is none."""
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute('SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema
        return None

    @property
    def current_user(self):
        """The database user of the connections (queried once)."""
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        """The schema search path as a list of names (queried once)."""
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(replace_special_values(
                            p.strip(), **special_values))
                    for p in path.split(',')]
            finally:
                self.put_connection(connection)
        return self._search_path