Example #1
	def putconn(self, conn):
		"""Return the connection to the pool."""

		# calledBy = traceback.extract_stack()[-2]
		# logging.info("PUTCONN - FILE: " + calledBy[0] + ", LINE: " + str(calledBy[1]) + ", METHOD: " + calledBy[2])

		ThreadedConnectionPool.putconn(self, conn)
Example #2
class Database:
    def __init__(self, connect_param):
        self.__connect_param = connect_param
        self.__pool = ThreadedConnectionPool(0, 10, self.__connect_param)
        # get cursor and test it
        # cur = self.cursor()
        # cur.execute('SHOW transaction_read_only')
        # standby = cur.fetchone()
        # cur.close()

    def get_connection(self):
        return self.__pool.getconn()

    def put_connection(self, connection):
        self.__pool.putconn(connection)
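
A minimal usage sketch for the wrapper above. The connect string is a placeholder, not part of the original example; always return the connection in a finally block or the pool will eventually run dry.

db = Database("dbname=test user=postgres")  # hypothetical connect string
conn = db.get_connection()
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        print(cur.fetchone())
    conn.commit()
finally:
    db.put_connection(conn)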
Example #3
class Database():
    def __init__(self, config):
        logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)

        self._pool = ThreadedConnectionPool(1, 10,
                database=config['DB_DATABASE'],
                user=config['DB_USER'],
                password=config['DB_PASSWORD'],
                host=config['DB_HOST'],
                async_=False)  # `async` is a reserved word in Python 3.7+; psycopg2 >= 2.7 uses `async_`

    def get_connection(self):
        return self._pool.getconn()

    def put_away_connection(self, con):
        self._pool.putconn(con)
Example #4
class DB:
    def __init__(self, *args, **kwargs):
        self.pool_params = (args, kwargs)
        self.pool = None

        self.campaigns = Campaigns(self)
        self.worksets = Worksets(self)
        self.tasks = Tasks(self)
        self.labels = Labels(self)
        self.logger = logging.getLogger(__name__)

    def _initialize_pool(self):
        if self.pool is None:
            self.logger.info("Initializing connection pool.")
            args, kwargs = self.pool_params
            self.pool = ThreadedConnectionPool(
                *args, cursor_factory=RealDictCursor, **kwargs)

    def execute(self, sql):
        with self.transaction() as transactor:
            cursor = transactor.cursor()
            cursor.execute(sql)
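            # Note: the connection returns to the pool when the
            # transaction scope exits, so the cursor handed back here
            # must be consumed before another thread reuses the connection.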
            return cursor

    @contextmanager
    def transaction(self):
        """Provides a transactional scope around a series of operations."""
        self._initialize_pool()
        conn = self.pool.getconn()
        try:
            yield conn
            conn.commit()
        except:
            conn.rollback()
            raise
        finally:
            self.pool.putconn(conn)

    @classmethod
    def from_config(cls, config):
        # Copy config as kwargs
        params = {k: v for k, v in config['database'].items()}
        params['minconn'] = params.get('minconn', 1)
        params['maxconn'] = params.get('maxconn', 5)

        return cls(**params)
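
A short usage sketch for the transactional scope above, assuming the Campaigns/Worksets/Tasks/Labels collaborators are importable and the config dict carries psycopg2 connection keys (the values shown are placeholders):

config = {'database': {'host': 'localhost', 'dbname': 'labels'}}
db = DB.from_config(config)
with db.transaction() as conn:
    with conn.cursor() as cursor:
        cursor.execute("SELECT 1")
        row = cursor.fetchone()
# commit/rollback and putconn are handled by the context manager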
Example #5
class ConnectionPool:  # no test coverage
    """https://gist.github.com/jeorgen/4eea9b9211bafeb18ada"""

    is_setup = False

    def setup(self):
        self.last_seen_process_id = os.getpid()
        self._init()
        self.is_setup = True

    def _init(self):
        self._pool = ThreadedConnectionPool(1,
                                            10,
                                            database=config.db_name,
                                            **credentials)

    def _getconn(self) -> connection:
        current_pid = os.getpid()
        if current_pid != self.last_seen_process_id:
            self._init()
            log.debug(
                f"New id is {current_pid}, old id was {self.last_seen_process_id}"
            )
            self.last_seen_process_id = current_pid
        conn = self._pool.getconn()
        return conn

    def _putconn(self, conn: connection):
        return self._pool.putconn(conn)

    def closeall(self):
        self._pool.closeall()

    @contextmanager
    def get_connection(self) -> t.Generator[connection, None, None]:
        conn = self._getconn()
        try:
            yield conn
        finally:
            self._putconn(conn)

    @contextmanager
    def get_cursor(self,
                   commit=False) -> t.Generator[LoggingCursor, None, None]:
        with self.get_connection() as conn:
            cursor = conn.cursor(cursor_factory=LoggingCursor)
            try:
                yield cursor
                if commit:
                    conn.commit()
            finally:
                cursor.close()
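
A usage sketch, assuming the module-level config, credentials, log, and LoggingCursor referenced above exist. The PID check in _getconn means the pool transparently rebuilds itself after a fork:

pool = ConnectionPool()
pool.setup()
with pool.get_cursor(commit=True) as cur:
    cur.execute("UPDATE jobs SET done = true WHERE id = %s", (42,))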
Example #6
class ConnectionPool(object):
    def __init__(self, conn_params, minconn=5, maxconn=5):
        self._conn_params = conn_params.copy()
        self._conn_params['minconn'] = minconn
        self._conn_params['maxconn'] = maxconn
        self._conn_pool = None

    def initialize(self):
        self._conn_pool = ThreadedConnectionPool(**self._conn_params)

    @contextmanager
    def cursor(self):
        conn = self._conn_pool.getconn()
        cursor = conn.cursor()
        try:
            yield cursor
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            cursor.close()
            self._conn_pool.putconn(conn)
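
Usage sketch for the pool above; conn_params mirrors psycopg2.connect() keyword arguments, and the values are placeholders. Note that initialize() must be called before the first cursor():

pool = ConnectionPool({'dbname': 'test', 'user': 'postgres'}, minconn=1, maxconn=4)
pool.initialize()
with pool.cursor() as cur:
    cur.execute("SELECT now()")
    print(cur.fetchone())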
Example #7
class PgConnectionPool:

    def __init__(self, *args, min_conns=1, keep_conns=10, max_conns=10,
                 **kwargs):
        self._pool = ThreadedConnectionPool(
            min_conns, max_conns, *args, **kwargs)
        self._keep_conns = keep_conns

    def acquire(self):
        pool = self._pool
        conn = pool.getconn()
        pool.minconn = min(self._keep_conns, len(pool._used))
        return conn

    def release(self, conn):
        self._pool.putconn(conn)

    def close(self):
        if hasattr(self, '_pool'):
            self._pool.closeall()

    __del__ = close
Example #8
class PoolWrapper:
    """Exists to provide an acquire method for easy usage.

        pool = PoolWrapper(...)
        with pool.acquire() as connection:
            connection.execute(...)
    """
    def __init__(self, max_pool_size: int, *, dsn):
        self._pool = ThreadedConnectionPool(
            1,
            max_pool_size,
            dsn=dsn,
            cursor_factory=RealDictCursor,
        )

    @contextmanager
    def acquire(self):
        connection = self._pool.getconn()
        try:
            yield connection
        finally:
            self._pool.putconn(connection)
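
Usage sketch matching the docstring above (the DSN is a placeholder). Because the pool is built with cursor_factory=RealDictCursor, rows come back as dicts:

pool = PoolWrapper(5, dsn="dbname=test user=postgres")
with pool.acquire() as connection:
    with connection.cursor() as cur:
        cur.execute("SELECT 1 AS one")
        print(cur.fetchone()["one"])
    connection.commit()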
Example #9
def from_qid_author(qid: int,
                    author: str,
                    pool: ThreadedConnectionPool = None):
    """
    Retrieve data from the database and construct an ``Answer`` from it.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as curs:
            curs.execute(
                "SELECT body, upvotes, downvotes "
                "FROM answers "
                "WHERE author=%(username)s AND qid=%(qid)s", {
                    'username': author,
                    'qid': qid
                })
            body, upvotes, downvotes = curs.fetchone()
    finally:
        pool.putconn(conn)
    return Answer(author=author,
                  body=body,
                  upvotes=upvotes,
                  downvotes=downvotes,
                  qid=qid,
                  pool=pool)
Example #10
class ProcessSafePoolManager:
    def __init__(self, *args, **kwargs):
        self.last_seen_process_id = os.getpid()
        self.args = args
        self.kwargs = kwargs
        self._init()

    def _init(self):
        self._pool = ThreadedConnectionPool(*self.args, **self.kwargs)

    def getconn(self):
        current_pid = os.getpid()
        if current_pid != self.last_seen_process_id:
            self._init()
            print("New id is %s, old id was %s" %
                  (current_pid, self.last_seen_process_id))
            self.last_seen_process_id = current_pid
        return self._pool.getconn()

    def putconn(self, conn):
        return self._pool.putconn(conn)
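
The PID check above exists because a psycopg2 pool must not be shared across a fork(): the child would reuse the parent's sockets. A sketch of the intended pattern, with placeholder connection parameters:

import multiprocessing as mp

pm = ProcessSafePoolManager(1, 10, dsn="dbname=test")

def worker():
    # In the child the PID differs, so getconn() rebuilds the pool
    # instead of reusing the parent's connections.
    conn = pm.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT pg_backend_pid()")
    finally:
        pm.putconn(conn)

if __name__ == "__main__":
    mp.Process(target=worker).start()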
Example #11
class PostgresThreadPool:
    provides = ['db_connection_pool', 'postgres']
    requires_configured = ['json_settings']

    def __init__(self, settings):
        from psycopg2.pool import ThreadedConnectionPool
        dbsettings = settings['database']
        self.pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=settings['database']['conn_pool_size'],
            database=dbsettings['name'],
            user=dbsettings['username'],
            password=dbsettings['password'],
            host=dbsettings['host'],
            port=dbsettings.get('port')
        )

    def getconn(self):
        return self.pool.getconn()

    def putconn(self, connection):
        return self.pool.putconn(connection)
Example #12
class Pool:
    def __init__(self):
        self.pool = None

    def init_app(self, app):
        self.pool = ThreadedConnectionPool(minconn=1,
                                           maxconn=10,
                                           dsn=app.config["DATABASE_URL"])
        app.teardown_appcontext(self.return_conn)

    def get_conn(self):
        if 'db_conn' not in g:
            g.db_conn = self.pool.getconn()
        return g.db_conn

    def return_conn(self, exception):
        db_conn = g.pop('db_conn', None)
        if db_conn is not None:
            self.pool.putconn(db_conn)

    def unwrap_pg_statement(self, func, *args, **kwargs):
        items = func(*args, **kwargs)
        if type(items) is str:
            return items, []
        elif type(items) is tuple:
            if len(items) == 1:
                return items[0], []
            else:
                return items[0], items[1]

    def execute(self, cursor_method="fetchall"):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with self.get_conn() as conn:
                    with conn.cursor(cursor_factory=RealDictCursor) as c:
                        sql, values = self.unwrap_pg_statement(
                            func, *args, **kwargs)
                        c.execute(sql, values)
                        if sql.strip().upper().startswith("SELECT"):
                            if cursor_method == "fetchall":
                                result = c.fetchall()
                            elif cursor_method == "fetchone":
                                result = c.fetchone()
                                # print("SQL", sql)
                                # print("FETCH ONE RESULT", result)
                            if type(result) is list and not len(result):
                                return []
                            return result
                        return True

            return wrapper

        return decorator

    def executemany(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self.get_conn() as conn:
                with conn.cursor() as c:
                    sql, values = self.unwrap_pg_statement(
                        func, *args, **kwargs)
                    c.executemany(sql, values)
                    return True

        return wrapper
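
A sketch of how the execute decorator above is meant to be wired into a Flask app; the table and function names are illustrative only:

pool = Pool()

@pool.execute(cursor_method="fetchone")
def get_user(user_id):
    # The decorated function only returns the SQL and its parameters;
    # the decorator acquires the connection and runs the query.
    return "SELECT * FROM users WHERE id = %s", [user_id]

# app = Flask(__name__)
# app.config["DATABASE_URL"] = "postgresql://..."
# pool.init_app(app)
# get_user(1) must then be called inside an application context.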
Example #13
class Database:
    def __init__(self, db_config, table_raw=None, max_connections=10):
        from psycopg2.pool import ThreadedConnectionPool

        self.table_raw = table_raw
        try:
            self.pool = ThreadedConnectionPool(
                minconn=1,
                maxconn=max_connections,
                dsn="dbname={db_name} user={db_user} host={db_host} password={db_pass}"
                    .format(**db_config))
        except Exception:
            logger.exception("Error in db connection")
            sys.exit(1)

        logger.debug(
            "Connected to database: {host}".format(host=db_config['db_host']))

    @contextmanager
    def getcursor(self, **kwargs):
        conn = self.pool.getconn()
        try:
            yield conn.cursor(**kwargs)
            conn.commit()

        except Exception as e:
            conn.rollback()
            raise e.with_traceback(sys.exc_info()[2])

        finally:
            self.pool.putconn(conn)

    def insert(self, table, data_list, return_cols='id'):
        """
        Create a bulk insert statement, which is much faster (~2x in tests with 10k & 100k rows and n cols)
        for inserting data than executemany()

        TODO: Is there a limit on how long the query can be? If so, handle it.
        """
        # Make sure that `data_list` is a list
        if not isinstance(data_list, list):
            data_list = [data_list]
        # Make sure data_list has content
        if len(data_list) == 0:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        # Make sure return_cols is a list
        if not isinstance(return_cols, list):
            return_cols = [return_cols]
        # Make sure return_cols has data
        if len(return_cols) == 0 or return_cols[0] is None:
            # No need to continue
            logger.critical("`return_cols` cannot be None/empty")
            # TODO: raise some error here rather than returning None
            return None

        try:
            with self.getcursor() as cur:
                query = "INSERT INTO {table} ({fields}) VALUES {values} RETURNING {return_cols}"\
                        .format(table=table,
                                fields='"{0}"'.format('", "'.join(data_list[0].keys())),
                                values=','.join(['%s'] * len(data_list)),
                                return_cols=', '.join(return_cols),
                                )
                query = cur.mogrify(query,
                                    [tuple(v.values()) for v in data_list])
                cur.execute(query)

                return cur.fetchall()

        except Exception as e:
            logger.exception("Error inserting data")
            logger.debug("Error inserting data: {data}".format(data=data_list))
            raise e.with_traceback(sys.exc_info()[2])

    def upsert(self,
               table,
               data_list,
               on_conflict_fields,
               on_conflict_action='update',
               update_fields=None,
               return_cols='id'):
        """
        Create a bulk upsert statement, which is much faster (~6x in tests with 10k & 100k rows and n cols)
        for upserting data than executemany()

        TODO: Is there a limit on how long the query can be? If so, handle it.
        """
        # Make sure that `data_list` is a list
        if not isinstance(data_list, list):
            data_list = [data_list]
        # Make sure data_list has content
        if len(data_list) == 0:
            # No need to continue
            return []
        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # TODO: raise some error here rather than returning None
            return None

        # Make sure on_conflict_fields is a list
        if not isinstance(on_conflict_fields, list):
            on_conflict_fields = [on_conflict_fields]
        # Make sure on_conflict_fields has data
        if len(on_conflict_fields) == 0 or on_conflict_fields[0] is None:
            # No need to continue
            logger.critical("Must pass in `on_conflict_fields` argument")
            # TODO: raise some error here rather than returning None
            return None

        # Make sure return_cols is a list
        if not isinstance(return_cols, list):
            return_cols = [return_cols]
        # Make sure return_cols has data
        if len(return_cols) == 0 or return_cols[0] is None:
            # No need to continue
            logger.critical("`return_cols` cannot be None/empty")
            # TODO: raise some error here rather than returning None
            return None

        # Make sure update_fields is a list/valid
        if on_conflict_action == 'update':
            if not isinstance(update_fields, list):
                update_fields = [update_fields]
            # If nothing is passed in, set `update_fields` to all (data_list - on_conflict_fields)
            if len(update_fields) == 0 or update_fields[0] is None:
                update_fields = list(
                    set(data_list[0].keys()) - set(on_conflict_fields))
                # If update_fields is empty here that could only mean that all fields are set as conflict_fields
                if len(update_fields) == 0:
                    logger.critical(
                        "Not all the fields can be `on_conflict_fields` when doing an update"
                    )
                    # TODO: raise some error here rather than returning None
                    return None

            # If everything is good to go, build the SET list from the update fields
            fields_update_tmp = []
            for key in update_fields:
                fields_update_tmp.append('"{0}"="excluded"."{0}"'.format(key))
            conflict_action_sql = 'UPDATE SET {update_fields}'\
                                  .format(update_fields=', '.join(fields_update_tmp))
        else:
            # Do nothing on conflict
            conflict_action_sql = 'NOTHING'

        try:
            with self.getcursor() as cur:
                query = """INSERT INTO {table} ({insert_fields})
                           SELECT {values}
                           ON CONFLICT ({on_conflict_fields}) DO
                           {conflict_action_sql}
                           RETURNING {return_cols}
                        """.format(
                    table=table,
                    insert_fields='"{0}"'.format('", "'.join(
                        data_list[0].keys())),
                    values=','.join(['unnest(%s)'] * len(data_list[0])),
                    on_conflict_fields=', '.join(on_conflict_fields),
                    conflict_action_sql=conflict_action_sql,
                    return_cols=', '.join(return_cols),
                )
                # Get all the values for each row and create a lists of lists
                values = [list(v.values()) for v in data_list]
                # Transpose list of lists
                values = list(map(list, zip(*values)))
                query = cur.mogrify(query, values)

                cur.execute(query)

                return cur.fetchall()

        except Exception as e:
            logger.exception("Error inserting data")
            logger.debug("Error inserting data: {data}".format(data=data_list))
            raise e.with_traceback(sys.exc_info()[2])

    def update(self, table, data_list, matched_field=None, return_cols='id'):
        """
        Create a bulk update statement, which is much faster (~2x in tests with 10k & 100k rows and 4 cols)
        for updating data than executemany()

        TODO: Is there a limit on how long the query can be? If so, handle it.
        """
        if matched_field is None:
            # Assume the id field
            logger.info("Matched field not defined, assuming the `id` field")
            matched_field = 'id'

        # Make sure that `data_list` is a list
        if not isinstance(data_list, list):
            data_list = [data_list]

        if len(data_list) == 0:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        try:
            with self.getcursor() as cur:
                query_list = []
                # TODO: change to return data from the database, not just what you passed in
                return_list = []
                for row in data_list:
                    if row.get(matched_field) is None:
                        logger.debug(
                            "Cannot update row. Missing field {field} in data {data}"
                            .format(field=matched_field, data=row))
                        logger.error(
                            "Cannot update row. Missing field {field} in data".
                            format(field=matched_field))
                        continue

                    # Pull matched_value from data to be updated and remove that key
                    matched_value = row.get(matched_field)
                    del row[matched_field]

                    query = "UPDATE {table} SET {data} WHERE {matched_field}=%s RETURNING {return_cols}"\
                            .format(table=table,
                                    data=','.join("%s=%%s" % u for u in row.keys()),
                                    matched_field=matched_field,
                                    return_cols=return_cols
                                    )
                    values = list(row.values())
                    values.append(matched_value)

                    query = cur.mogrify(query, values)
                    query_list.append(query)
                    return_list.append(matched_value)

                final_query = b';'.join(query_list)
                cur.execute(final_query)

                return return_list

        except Exception as e:
            logger.exception("Error updating data")
            logger.debug("Error updating data: {data}".format(data=data_list))
            raise e.with_traceback(sys.exc_info()[2])
Example #14
class MT_Terrain:
    def __init__(self, n_querier, n_calc, working_area, dataset, sample_ratio,
                 plot):
        self.workers_query_result_q = mp.Queue()
        DSN = "postgresql://*****:*****@192.168.184.102/postgres"
        self.tcp = ThreadedConnectionPool(1, 100, DSN)
        conn = self.tcp.getconn()
        cur = conn.cursor()
        self.querier = []
        self.calc = []
        self.go = mp.Value('b', True)
        self.dataset = dataset
        self.sample_ratio = sample_ratio
        self.working_area = working_area
        self.plot = plot
        self.processed_link = mp.Value('i', 0)
        self._counter_lock = mp.Lock()
        tf = terrain_RF(cur=cur,
                        dataset=self.dataset,
                        working_area=self.working_area)
        buildings = tf.get_buildings()
        self.n_buildings = len(buildings) // sample_ratio
        self.sim_name = "%s_%d_%d_" % (
            (self.dataset, self.n_buildings, time.time()))
        shuffle(buildings)
        buildings = buildings[:self.n_buildings]
        self.write_node_latlon(buildings)
        gid, h, centroid_x, centroid_y = zip(*buildings)
        self.dict_h = dict(zip(gid, h))
        self.tcp.putconn(conn, close=True)
        buildings_pair = set(itertools.combinations(gid, 2))
        self.link_filename = "../data/%s_links.csv" % (self.sim_name)
        self.tot_link = len(buildings_pair)
        print("%d links left to estimate" % len(buildings_pair))
        chunk_size = self.tot_link // n_querier
        chunks = list(chunked(buildings_pair, chunk_size))
        with open(self.link_filename + "_0", 'a') as fl:
            print("b1,b2,status,loss,status_downscale,loss_downscale, status_srtm, loss_srtm", file=fl)
        self.start_time = time.time()
        for i in range(n_querier):
            t = mp.Process(target=self.queryWorker, args=[i, chunks[i]])
            self.querier.append(t)
            t.daemon = True
            t.start()
        for i in range(n_calc):
            t = mp.Process(target=self.calcWorker, args=[i])
            self.calc.append(t)
            t.daemon = True
            t.start()
        t = mp.Process(target=self.monitor)
        t.daemon = True
        t.start()
        for q in self.querier:
            q.join()
        print("Finished Query")
        self.go.value = False
        print("Set go to False")
        for c in self.calc:
            c.join()

    def write_node_latlon(self, buildings):
        gid, h, centroid_x, centroid_y = zip(*buildings)
        coords = zip(gid, centroid_x, centroid_y)
        latlong_filename = "../data/%s_latlong.csv" % (self.sim_name)
        with open(latlong_filename, 'w') as f:
            print("id,lat,lon", file=f)
            for l in coords:
                print("%s,%f,%f" % l, file=f)

    def monitor(self):
        time.sleep(1)
        widgets = [
            'Test: ',
            progressbar.Percentage(),
            ' ',
            progressbar.Bar(marker='0', left='[', right=']'),
            ' ',
            progressbar.ETA(),
            ' ',
            progressbar.FileTransferSpeed(unit="Link"),
        ]
        bar = progressbar.ProgressBar(widgets=widgets,
                                      max_value=self.tot_link).start()
        while (self.go.value):
            bar.update(self.processed_link.value, force=True)
            time.sleep(1)
        bar.finish()

    def queryWorker(self, worker_id, buildings_pairs):
        conn = self.tcp.getconn()
        cur = conn.cursor()
        tf = terrain_RF(cur=cur,
                        dataset=self.dataset,
                        working_area=self.working_area)
        tf_srtm = terrain_RF(cur=cur,
                             dataset=self.dataset + "_srtm",
                             working_area=self.working_area)
        for buildings_pair in buildings_pairs:
            id1 = buildings_pair[0]
            id2 = buildings_pair[1]
            try:
                profile = tf.profile_osm(id1, id2)
            except ProfileException:
                profile = None
            try:
                profile_srtm = tf_srtm.profile_osm(id1, id2)
            except ProfileException:
                profile_srtm = None
            result = {
                "id1": id1,
                "id2": id2,
                "profile": profile,
                "profile_srtm": profile_srtm,
                "p1": (0.0, self.dict_h[id1]),
                "p2": (tf.distance(id1, id2), self.dict_h[id2])
            }
            self.workers_query_result_q.put(result)
        self.tcp.putconn(conn, close=True)

    def calcWorker(self, worker_id):
        while (self.go.value):
            # Take ORDER
            if self.workers_query_result_q.qsize() > 3:
                print("Warning: the queue contains %d elements" %
                      self.workers_query_result_q.qsize())
            order = self.workers_query_result_q.get()
            profile = order["profile"]
            profile_srtm = order["profile_srtm"]
            id1 = order["id1"]
            id2 = order["id2"]
            # Normal profile
            try:
                link = Link(profile)
                loss, status = link.loss_calculator()

            except (ZeroDivisionError, ProfileException):
                loss = 0
                status = -1
            # Downscaled profile
            try:
                link_ds = Link(profile)
                loss_ds, status_ds = link_ds.loss_calculator(downscale=3)
            except (ZeroDivisionError, ProfileException):
                loss_ds = 0
                status_ds = -1

            # SRTM Profile
            try:
                link_srtm = Link(profile_srtm)
                loss_srtm, status_srtm = link_srtm.loss_calculator()
            except ProfileException:
                status_srtm = 1
                loss_srtm = 0
Example #15
class BatchCopy(object):
    def __init__(self,
                 dw_conf,
                 event_path,
                 worker_id,
                 table_names,
                 redis_conf,
                 logger,
                 sep='\x02'):
        self._tables = {}
        self._pool = ThreadedConnectionPool(1,
                                            2,
                                            host=dw_conf.get("host"),
                                            port=dw_conf.get("port"),
                                            database=dw_conf.get("database"),
                                            user=dw_conf.get("user"),
                                            password=dw_conf.get("password"))
        self._redis = redis_conf
        self._logger = logger
        self._sep = sep
        conn = self._pool.getconn()
        cur = conn.cursor(cursor_factory=RealDictCursor)
        for table_name in table_names:
            worker_dir = "%s/%s" % (event_path, worker_id)
            if not os.path.exists(worker_dir):
                os.makedirs(worker_dir)
            file_path = "%s/%s.udw" % (worker_dir, table_name)
            self._tables[table_name] = {
                "columns": {},
                "path": file_path,
                "file": open(file_path, 'a+'),
                "data": []
            }
            sql = ("SELECT ordinal_position, column_name FROM information_schema.columns "
                   "WHERE table_name = %s")
            cur.execute(sql, (table_name,))
            columns = cur.fetchall()
            if len(columns) == 0:
                self._logger.error(table_name)
                continue
            for column in columns:
                self._tables[table_name]["columns"][
                    column["column_name"]] = column["ordinal_position"]
        cur.close()
        self._pool.putconn(conn)

    def json2copy(self, event):
        values = []
        hash_values = {}
        for k, v in event.value.items():
            if type(v) == int:
                hash_values[self._tables[event.key]["columns"][k]] = str(v)
            if type(v) == str:
                hash_values[self._tables[event.key]["columns"][k]] = v
            if type(v) == list or type(v) == dict:
                hash_values[self._tables[event.key]["columns"]
                            [k]] = "'%s'" % str(v)
        range_stop = len(hash_values) + 1
        for i in range(1, range_stop):
            values.append(hash_values[i])
        return self._sep.join(values)

    def writelines(self, events):
        for event in events:
            try:
                row = self.json2copy(event)
            except Exception as e:
                self._logger.error(e)
                self._logger.error(self._tables[event.key])
            else:
                self._tables[event.key]["data"].append(row + '\n')
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].writelines(v["data"])
            del v["data"][:]

    def redis_sieve(self, db, values):
        rc = RedisConnector(self._redis, db)
        values.seek(0)
        values = values.read().split('\n')[:-1]
        return rc.multi_add(values)

    def copy_sink(self):
        for k, v in self._tables.items():
            result = self.redis_sieve(k, v['file'])
            if isinstance(result, list):
                result = '\n'.join(result)
                self._logger.info('duplicate data')
                v['file'].close()
                os.remove(v['path'])
                v['file'] = open(v['path'], 'a+')
                v['file'].write(result)
            # When there are no duplicate rows
            v["file"].seek(0)
            conn = self._pool.getconn()
            cur = conn.cursor(cursor_factory=RealDictCursor)
            try:
                cur.copy_from(v["file"], k, sep=self._sep)
                conn.commit()
            # A DataError means wrong data types or missing fields
            except psycopg2.DataError as e:
                self._logger.error(str(e))
                v["file"].seek(0)
                with open(
                        '/data/log/unresolved_data/cleaned#' + k + '#' +
                        datetime.now().strftime("%Y%m%d%H%M%S"),
                        'w') as err_log:
                    err_log.write(v["file"].read())
            # Slow database response: the query errored out or returned nothing
            except psycopg2.DatabaseError as e:
                self._logger.error(e)
                v["file"].seek(0)
                with open(
                        '/data/log/unresolved_data/cleaned#' + k + '#' +
                        datetime.now().strftime("%Y%m%d%H%M%S"),
                        'w') as interrupted:
                    interrupted.write(v["file"].read())
                sleep(300)
            finally:
                cur.close()
                self._pool.putconn(conn)
        self.flush_copy()

    def flush_copy(self):
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].close()
            os.remove(v["path"])
            v["file"] = open(v["path"], 'a+')

    def close(self):
        for v in self._tables.values():
            if not isinstance(v, dict):
                continue
            v["file"].close()
            os.remove(v["path"])
        if not self._pool.closed:
            self._pool.closeall()
Example #16
    else:
        cur.execute(sql, (ids,))


_original_hosts = []
cnx = cnxpool.getconn()
with cnx.cursor(cursor_factory=RealDictCursor) as cursor:
    cursor.execute(r"""SELECT
                    original_host,
                    display_name as original_host_display_name,
                    region as original_host_region
                FROM findopendata.original_hosts
                WHERE enabled
                ORDER BY display_name""")
    _original_hosts = cursor.fetchall()
cnxpool.putconn(cnx)


@app.route('/api/original-hosts', methods=['GET'])
def original_hosts():
    return jsonify(_original_hosts)

_data_formats = []
cnx = cnxpool.getconn()
with cnx.cursor(cursor_factory=RealDictCursor) as cursor:
    cursor.execute(r"""SELECT DISTINCT format
                   FROM findopendata.package_files
                   WHERE format != ''
                   """)
    format_results = cursor.fetchall()
    pre_filters = ["XLS", "XLSX", "XSL", "ZIP", "PDF", "ODS", "CSV", "TEXT", "JSON", "XML", "API", "HTML", "TXT", "DOCX", "PPTX", "TAR"]
Example #17
class Database:

    def __init__(self, db_config, table_raw=None, max_connections=10):
        from psycopg2.pool import ThreadedConnectionPool

        self.table_raw = table_raw
        try:
            self.pool = ThreadedConnectionPool(minconn=1,
                                               maxconn=max_connections,
                                               dsn="dbname={db_name} user={db_user} host={db_host} password={db_pass}"
                                                   .format(**db_config))
        except Exception:
            logger.exception("Error in db connection")
            sys.exit(1)

        logger.debug("Connected to database: {host}".format(host=db_config['db_host']))

    @contextmanager
    def getcursor(self):
        conn = self.pool.getconn()
        try:
            yield conn.cursor()
            conn.commit()

        except Exception as e:
            conn.rollback()
            raise e.with_traceback(sys.exc_info()[2])

        finally:
            self.pool.putconn(conn)

    def insert(self, table, data_list, id_col='id'):
        """
        TODO: rename `id_col` -> `return_col`
        Create a bulk insert statement, which is much faster (~2x in tests with 10k & 100k rows and 4 cols)
        for inserting data than executemany()

        TODO: Is there a limit on how long the query can be? If so, handle it.
        """
        # Make sure that `data_list` is a list
        if not isinstance(data_list, list):
            data_list = [data_list]

        if len(data_list) == 0:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        try:
            with self.getcursor() as cur:
                query = "INSERT INTO {table} ({fields}) VALUES {values} RETURNING {id_col}"\
                        .format(table=table,
                                fields='"{0}"'.format('", "'.join(data_list[0].keys())),
                                values=','.join(['%s'] * len(data_list)),
                                id_col=id_col
                                )
                query = cur.mogrify(query, [tuple(v.values()) for v in data_list])
                cur.execute(query)

                return [t[0] for t in cur.fetchall()]

        except Exception as e:
            logger.exception("Error inserting data")
            logger.debug("Error inserting data: {data}".format(data=data_list))
            raise e.with_traceback(sys.exc_info()[2])

    def update(self, table, data_list, matched_field=None, return_col='id'):
        """
        Create a bulk update statement, which is much faster (~2x in tests with 10k & 100k rows and 4 cols)
        for updating data than executemany()

        TODO: Is there a limit on how long the query can be? If so, handle it.
        """
        if matched_field is None:
            # Assume the id field
            logger.info("Matched field not defined, assuming the `id` field")
            matched_field = 'id'

        # Make sure that `data_list` is a list
        if not isinstance(data_list, list):
            data_list = [data_list]

        if len(data_list) == 0:
            # No need to continue
            return []

        # Data in the list must be dicts (just check the first one)
        if not isinstance(data_list[0], dict):
            logger.critical("Data must be a list of dicts")
            # Do not return here, let the exception handle the error that will be thrown when the query runs

        try:
            with self.getcursor() as cur:
                query_list = []
                # TODO: change to return data from the database, not just what you passed in
                return_list = []
                for row in data_list:
                    if row.get(matched_field) is None:
                        logger.debug("Cannot update row. Missing field {field} in data {data}"
                                     .format(field=matched_field, data=row))
                        logger.error("Cannot update row. Missing field {field} in data".format(field=matched_field))
                        continue

                    # Pull matched_value from data to be updated and remove that key
                    matched_value = row.get(matched_field)
                    del row[matched_field]

                    query = "UPDATE {table} SET {data} WHERE {matched_field}=%s RETURNING {return_col}"\
                            .format(table=table,
                                    data=','.join("%s=%%s" % u for u in row.keys()),
                                    matched_field=matched_field,
                                    return_col=return_col
                                    )
                    values = list(row.values())
                    values.append(matched_value)

                    query = cur.mogrify(query, values)
                    query_list.append(query)
                    return_list.append(matched_value)

                final_query = b';'.join(query_list)
                cur.execute(final_query)

                return return_list

        except Exception as e:
            logger.exception("Error updating data")
            logger.debug("Error updating data: {data}".format(data=data_list))
            raise e.with_traceback(sys.exc_info()[2])
Example #18
class DbPool(object):
	"""DB class that makes connections transparently. Thread-safe: every
	thread gets its own database connection. Not meant to be used directly;
	there is no reason to have more than one instance - the global variable Db
	in this module."""

	def __init__(self, config):
		"""Configures the Db, connection is not created yet.
		@param config: instance of config.NotaryServerConfig."""

		self.host = config.db_host
		self.port = config.db_port
		self.user = config.db_user
		self.password = config.db_password
		self.db_name = config.db_name
		self.min_connections = config.db_min_conn
		self.max_connections = config.db_max_conn

		self.pool = ThreadedConnectionPool(
			minconn = self.min_connections,
			maxconn = self.max_connections,
			host = self.host,
			port = self.port,
			user = self.user,
			password = self.password,
			database = self.db_name)

	def cursor(self, **kwargs):
		"""Creates and returns cursor for current thread's connection.
		Cursor is a "dict" cursor, so you can access the columns by
		names (not just indices), e.g.:

		cursor.execute("SELECT id, name FROM ... WHERE ...", sql_args)
		row = cursor.fetchone()
		id = row['id']
		
		Server-side cursors (named cursors) should be closed explicitly.
		
		@param kwargs: currently string parameter 'name' is supported.
		Named cursors are for server-side cursors, which
		are useful when fetching result of a large query via fetchmany()
		method. See http://initd.org/psycopg/docs/usage.html#server-side-cursors
		"""
		return self.connection().cursor(cursor_factory=DictCursor, **kwargs)
	
	def connection(self):
		"""Return connection for this thread"""
		return self.pool.getconn(id(threading.current_thread()))

	def commit(self):
		"""Commit all the commands in this transaction in this thread's
		connection. If errors (e.g. duplicate key) arose, this will
		cause transaction rollback.
		"""
		self.connection().commit()

	def rollback(self):
		"""Rollback last transaction on this thread's connection"""
		self.connection().rollback()
	
	def putconn(self):
		"""Return the connection used by this thread. Must be called when a
		spawned thread finishes, otherwise new threads won't get a connection
		once the pool is depleted."""
		conn = self.connection()
		self.pool.putconn(conn, id(threading.current_thread()))
	
	def close(self):
		"""Close connection."""
		self.connection().close()
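
A usage sketch for the thread-keyed pool above, assuming a config object with the db_* attributes named in __init__. As the putconn docstring warns, each spawned thread must return its connection before exiting:

def worker(db):
    try:
        cur = db.cursor()
        cur.execute("SELECT 1")
        db.commit()
    finally:
        db.putconn()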
Example #19
class GetEachMatchOdds:
    defaultEncoding = 'utf-8'

    def __init__(self, loglevel, db_config_file, section_name):
        self.pool_size = 2
        self.min_db_conn_size = 2
        logging.config.fileConfig('logger_test.conf')
        self.logger = logging.getLogger(loglevel)
        self.config = configparser.ConfigParser()
        self.config.read(db_config_file)
        self.host = self.config[section_name]['host']
        self.port = self.config[section_name]['port']
        self.user = self.config[section_name]['user']
        self.password = self.config[section_name]['password']
        self.database = self.config[section_name]['database']
        self.conn_pool = ThreadedConnectionPool(self.min_db_conn_size,
                                                self.pool_size,
                                                host=self.host,
                                                port=self.port,
                                                user=self.user,
                                                password=self.password,
                                                database=self.database)
        self.main_db_conn = db_connecter.DBConnecter(db_config_file,
                                                     section_name)
        self.main_db_conn.connect()

    def basic_get_match_info_task(self, queue):
        headers = {
            'User-Agent': 'Mozilla/4.0 (compatible; MSIE 5.5; Windows NT)',
            'Referer':
            'https://developer.mozilla.org/en-US/docs/Web/JavaScript'
        }
        euro_sql_str = '''insert into euro_odds
        (match_id,
        result,
        company,
        priodds3,
        priodds1,
        priodds0,
        nowodds3,
        nowodds1,
        nowodds0,
        prichance3,
        prichance1,
        prichance0,
        nowchance3,
        nowchance1,
        nowchance0,
        priyrr,
        nowyrr,
        prikelly3,
        prikelly1,
        prikelly0,
        nowkelly3,
        nowkelly1,
        nowkelly0
        )
        values
        (
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s
        )
        '''
        asia_sql_str = '''insert into asia_odds
        (match_id,
        result,
        company,
        priodds3,
        priconcede,
        priodds0,
        nowodds3,
        nowconcede,
        nowodds0
        )
        values
        (
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s,
        %s
        )
        '''
        threadname = threading.current_thread().name
        while True:
            rows = queue.get()
            db_conn = self.conn_pool.getconn()
            curs = db_conn.cursor()
            for row in rows:
                while True:
                    try:
                        self.logger.info(
                            threadname + ":" +
                            str([row[0], row[1], row[2], row[3], row[4]]))
                        match_id = row[0]
                        if row[1] > row[2]:
                            result = 3
                        elif row[1] == row[2]:
                            result = 1
                        else:
                            result = 0

                        if re.search(r'^http:', row[3]):
                            asia_data_url = row[3]
                        else:
                            asia_data_url = 'http:' + row[3]

                        if re.search(r'^http:', row[4]):
                            euro_data_url = row[4]
                        else:
                            euro_data_url = 'http:' + row[4]

                        euro_r = requests.get(euro_data_url,
                                              headers=headers,
                                              timeout=5)
                        euro_r.encoding = euro_r.apparent_encoding
                        sleep(0.5)
                        asia_r = requests.get(asia_data_url,
                                              headers=headers,
                                              timeout=5)
                        asia_r.encoding = asia_r.apparent_encoding
                        _euro_html = html.fromstring(euro_r.text)
                        if (len(
                                _euro_html.xpath(
                                    '//table[@class="pl_table_data"]')) > 0):
                            all_euro_odds = self.parse_euro_new_html(euro_r)
                            all_asia_odds = self.parse_asia_new_html(asia_r)

                            for e_odds in all_euro_odds:
                                curs.execute(
                                    euro_sql_str,
                                    tuple([match_id, result] + e_odds))

                            for a_odds in all_asia_odds:
                                curs.execute(
                                    asia_sql_str,
                                    tuple([match_id, result] + a_odds))

                        else:
                            all_euro_odds = self.parse_euro_old_html(euro_r)
                            all_asia_odds = self.parse_asia_old_html(asia_r)

                            for e_odds in all_euro_odds:
                                curs.execute(
                                    euro_sql_str,
                                    tuple([match_id, result] + e_odds))

                            for a_odds in all_asia_odds:
                                curs.execute(
                                    asia_sql_str,
                                    tuple([match_id, result] + a_odds))

                        db_conn.commit()
                    except GetHtmlFailed as err:
                        self.logger.error(
                            "Get Html Failed: {0}. 2 seconds later retry to get...."
                            .format(err))
                        sleep(2)
                        continue
                    except Exception as err:
                        self.logger.error(
                            "Get Html Failed: {0}. 5 seconds later retry to get...."
                            .format(err))
                        sleep(5)
                        continue
                    break
            self.conn_pool.putconn(db_conn)
            queue.task_done()

    def basic_work(self, group_size, start_offset, end_offset):
        self.logger.info("Start getting match information...")
        start_time = datetime.datetime.now()
        queue = Queue.Queue(maxsize=self.pool_size)
        killer = GracefulKiller()
        for i in range(self.pool_size):
            t = threading.Thread(target=self.basic_get_match_info_task,
                                 name='Task-' + str(i),
                                 args=(queue, ))
            t.daemon = True
            t.start()

        current_offset = start_offset
        headers = {
            'User-Agent': 'Mozilla/4.0 (compatible; MSIE 5.5; Windows NT)',
            'Referer':
            'https://developer.mozilla.org/en-US/docs/Web/JavaScript'
        }
        while (current_offset <= end_offset):
            sql_str = 'select id, home_team_score, away_team_score, asia_data_url, euro_data_url from matchs limit %s offset %s'
            self.main_db_conn.select_record(sql_str,
                                            (group_size, current_offset))
            rows = self.main_db_conn.fetchall()
            queue.put(rows)

            if killer.kill_now:
                self.logger.info(
                    "The program will stop after all tasks are done; the current offset is: %s",
                    current_offset + group_size)
                break

            current_offset += group_size

        if not killer.kill_now:
            self.logger.info("Getting match info completed!")

        self.main_db_conn.close()
        queue.join()
        end_time = datetime.datetime.now()
        last_time = end_time - start_time
        self.logger.info(
            "Getting match information finished; the program ran for %s seconds",
            last_time.total_seconds())

    def parse_euro_new_html(self, response):
        _html = html.fromstring(response.text)

        company_names = _html.xpath('//span[@class="quancheng"]/text()')
        tables = _html.xpath('//table[@class="pl_table_data"]')

        if (len(_html.xpath('//table[@class="pl_table_data"]')) == 0):
            raise GetHtmlFailed("Can't get Euro odds")

        if (len(company_names) < 1):
            return []

        odds_arr = []
        for tb in tables[3::4]:
            trs = tb.xpath('./tbody/tr')
            pri_odds = trs[0].xpath('./td/text()')
            last_odds = trs[1].xpath('./td/text()')
            odds_arr.append([pri_odds, last_odds])

        chance_arr = []
        for tb in tables[4::4]:
            trs = tb.xpath('./tbody/tr')
            pri_chance = trs[0].xpath('./td/text()')
            last_chance = trs[1].xpath('./td/text()')
            chance_arr.append([pri_chance, last_chance])

        yrr_arr = []
        for tb in tables[5::4]:
            trs = tb.xpath('./tbody/tr')
            pri_yrr = trs[0].xpath('./td/text()')
            last_yrr = trs[1].xpath('./td/text()')
            yrr_arr.append([pri_yrr, last_yrr])

        kelly_arr = []
        for tb in tables[6::4]:
            trs = tb.xpath('./tbody/tr')
            pri_kelly = trs[0].xpath('./td/text()')
            last_kelly = trs[1].xpath('./td/text()')
            kelly_arr.append([pri_kelly, last_kelly])

        all_company_odds = []
        for c, odds, chance, yrr, kelly in zip(company_names, odds_arr,
                                               chance_arr, yrr_arr, kelly_arr):
            all_company_odds.append([c.encode(GetEachMatchOdds.defaultEncoding),odds[0][0], odds[0][1], odds[0][2], \
            odds[1][0], odds[1][1], odds[1][2], chance[0][0], chance[0][1], chance[0][2], chance[1][0], chance[1][1],\
            chance[1][2], yrr[0][0], yrr[1][0], kelly[0][0], kelly[0][1], kelly[0][2], kelly[1][0], kelly[1][1], kelly[1][2]])
        return all_company_odds

    def parse_euro_old_html(self, response):
        _html = html.fromstring(response.text)

        trs = _html.xpath('//table[@id="datatb"]/tr')

        all_company_odds = []

        if (len(_html.xpath('//table[@id="datatb"]')) == 0):
            raise GetHtmlFailed("Can't get Euro odds")

        if (len(trs) < 4):
            return []

        for elem, next_elem in zip(trs[2::2], trs[3::2] + [trs[2]]):
            elem_tds = elem.xpath('./td')
            next_elem_tds = next_elem.xpath('./td')
            if len(elem_tds) > 12:
                if (len(elem_tds[1].xpath('./a/text()')) > 0):
                    company = elem_tds[1].xpath('./a/text()')[0].encode(
                        GetEachMatchOdds.defaultEncoding)
                else:
                    company = elem_tds[1].text.encode(
                        GetEachMatchOdds.defaultEncoding)
                all_company_odds.append([company,  elem_tds[2].text, elem_tds[3].text, elem_tds[4].text, next_elem_tds[0].text, \
                next_elem_tds[1].text, next_elem_tds[2].text, elem_tds[5].text, elem_tds[6].text, elem_tds[7].text, next_elem_tds[3].text, \
                next_elem_tds[4].text, next_elem_tds[5].text, elem_tds[8].text, next_elem_tds[6].text, elem_tds[9].xpath('./span/text()')[0], \
                elem_tds[10].xpath('./span/text()')[0], elem_tds[11].text, next_elem_tds[7].xpath('./span/text()')[0], \
                next_elem_tds[8].xpath('./span/text()')[0], next_elem_tds[9].text])
        return all_company_odds

    def parse_asia_new_html(self, response):
        _html = html.fromstring(response.text)

        company_names = _html.xpath('//span[@class="quancheng"]/text()')

        if (len(_html.xpath('//table[@class="pl_table_data"]')) == 0):
            raise GetHtmlFailed("Can't get Asia odds")

        if (len(company_names) < 1):
            return []

        tables = _html.xpath('//table[@class="pl_table_data"]')

        now_odds_arr = []
        pri_odds_arr = []
        for tb1, tb2 in zip(tables[2::2], tables[3::2]):
            tb1_tds = tb1.xpath('./tbody/tr/td/text()')
            tb2_tds = tb2.xpath('./tbody/tr/td/text()')
            now_odds_arr.append(tb1_tds[0:3])
            pri_odds_arr.append(tb2_tds[0:3])

        all_company_odds = []
        for c, pri_odds, now_odds in zip(company_names, pri_odds_arr,
                                         now_odds_arr):
            if (len(pri_odds) < 3) or (len(now_odds) < 3):
                continue
            all_company_odds.append([c.encode(GetEachMatchOdds.defaultEncoding), pri_odds[0].encode(GetEachMatchOdds.defaultEncoding), \
            pri_odds[1].replace(u' \xa0', '').encode(GetEachMatchOdds.defaultEncoding), pri_odds[2].encode(GetEachMatchOdds.defaultEncoding), \
            now_odds[0].replace(u'\u2193', '').replace(u'\u2191', '').encode(GetEachMatchOdds.defaultEncoding), \
            now_odds[1].replace(u' \xa0', '').encode(GetEachMatchOdds.defaultEncoding), \
            now_odds[2].replace(u'\u2193', '').replace(u'\u2191', '').encode(GetEachMatchOdds.defaultEncoding)])

        return all_company_odds

    def parse_asia_old_html(self, response):
        _html = html.fromstring(response.text)

        trs = _html.xpath('//table[@id="datatb"]//tr')

        if (len(_html.xpath('//table[@id="datatb"]')) == 0):
            raise GetHtmlFailed("Can't get Asia odds")

        all_company_odds = []
        number_re = re.compile(r'\d\.\d')
        for tr in trs[1:-3]:
            tds = tr.xpath('./td')
            if (len(tds) > 8):
                pri_home_odds = tds[6].text
                pri_away_odds = tds[8].text

                if len(tds[2].xpath('./span/text()')) < 1 or len(
                        tds[4].xpath('./span/text()')) < 1:
                    continue

                new_home_odds = tds[2].xpath('./span/text()')[0]
                new_away_odds = tds[4].xpath('./span/text()')[0]

                if (not number_re.search(pri_home_odds)
                        or not number_re.search(pri_away_odds)
                        or not number_re.search(new_home_odds)
                        or not number_re.search(new_away_odds)
                        or not tds[1].text or not tds[3].text or not tds[7].text):
                    continue

                all_company_odds.append([
                    tds[1].text.replace(u'\ufffd', '').encode(GetEachMatchOdds.defaultEncoding),
                    pri_home_odds,
                    tds[7].text.replace(u' \xa0', '').encode(GetEachMatchOdds.defaultEncoding),
                    pri_away_odds,
                    new_home_odds,
                    tds[3].text.replace(u' \xa0', '').encode(GetEachMatchOdds.defaultEncoding),
                    new_away_odds])

        return all_company_odds
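
The slice-and-zip pairing used in parse_asia_new_html above is easy to miss; a minimal sketch against synthetic markup (the cell values and the two leading header tables are invented for illustration):

from lxml import html

# Synthetic page: two header tables, then alternating current/opening odds
# tables, mirroring the tables[2::2] / tables[3::2] pairing above.
cells = ["head", "head", "2.10", "1.95", "1.88", "2.02"]
markup = "<div>" + "".join(
    '<table class="pl_table_data"><tbody><tr><td>%s</td></tr></tbody></table>' % c
    for c in cells) + "</div>"

tables = html.fromstring(markup).xpath('//table[@class="pl_table_data"]')
for now_tb, pri_tb in zip(tables[2::2], tables[3::2]):
    # Prints ['2.10'] ['1.95'], then ['1.88'] ['2.02']
    print(now_tb.xpath('./tbody/tr/td/text()'),
          pri_tb.xpath('./tbody/tr/td/text()'))
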
class ExampleSearch(object):
    """Class for performing the Example search.

  Example search is a neighborhood search that demonstrates
  how to construct and query a spatial database from a URL
  search string, extract geometries from the result, associate
  styles with them, and return the response to the client.

  Valid Inputs are:
  q=pacific heights
  neighborhood=pacific heights
  """
    def __init__(self):
        """Inits ExampleSearch.

    Initializes the logger "ge_search".
    Initializes the KML, JSON, and placemark templates used for
    the KML/JSONP output.
    Initializes parameters for establishing a connection to the database.
    """

        self.utils = utils.SearchUtils()
        constants = geconstants.Constants()

        configs = self.utils.GetConfigs(
            os.path.join(geconstants.SEARCH_CONFIGS_DIR, "ExampleSearch.conf"))

        style_template = self.utils.style_template
        self._jsonp_call = self.utils.jsonp_functioncall
        self._geom = """
            <name>%s</name>
            <styleUrl>%s</styleUrl>
            <Snippet>%s</Snippet>
            <description>%s</description>
            %s\
    """
        self._json_geom = """
         {
            "name": "%s",
            "Snippet": "%s",
            "description": "%s",
            %s
         }\
    """

        self._placemark_template = self.utils.placemark_template
        self._kml_template = self.utils.kml_template

        self._json_template = self.utils.json_template
        self._json_placemark_template = self.utils.json_placemark_template

        self._example_query_template = (Template(constants.example_query))

        self.logger = self.utils.logger

        self._user = configs.get("user")
        self._hostname = configs.get("host")
        self._port = configs.get("port")

        self._database = configs.get("databasename")
        if not self._database:
            self._database = constants.defaults.get("example.database")

        self._pool = ThreadedConnectionPool(
            int(configs.get("minimumconnectionpoolsize")),
            int(configs.get("maximumconnectionpoolsize")),
            database=self._database,
            user=self._user,
            host=self._hostname,
            port=int(self._port))

        self._style = style_template.substitute(
            balloonBgColor=configs.get("balloonstyle.bgcolor"),
            balloonTextColor=configs.get("balloonstyle.textcolor"),
            balloonText=configs.get("balloonstyle.text"),
            iconStyleScale=configs.get("iconstyle.scale"),
            iconStyleHref=configs.get("iconstyle.href"),
            lineStyleColor=configs.get("linestyle.color"),
            lineStyleWidth=configs.get("linestyle.width"),
            polyStyleColor=configs.get("polystyle.color"),
            polyStyleColorMode=configs.get("polystyle.colormode"),
            polyStyleFill=configs.get("polystyle.fill"),
            polyStyleOutline=configs.get("polystyle.outline"),
            listStyleHref=configs.get("iconstyle.href"))

    def RunPGSQLQuery(self, query, params):
        """Submits the query to the database and returns tuples.

    Note: variable placeholders must always be %s in the query.
    Warning: NEVER use Python string concatenation (+) or string parameter
    interpolation (%) to pass variables to a SQL query string.
    e.g.
      SELECT vs_url FROM vs_table WHERE vs_name = 'default_ge';
      query = "SELECT vs_url FROM vs_table WHERE vs_name = %s"
      parameters = ["default_ge"]

    Args:
      query: SQL SELECT statement.
      params: sequence of parameters to populate into placeholders.
    Returns:
      Results as list of tuples (rows of fields).
    Raises:
      psycopg2.Error/Warning in case of error.
    """
        con = None
        cursor = None

        query_results = []
        query_status = False

        self.logger.debug(
            "Querying the database %s, at port %s, as user %s on"
            "hostname %s" %
            (self._database, self._port, self._user, self._hostname))
        try:
            con = self._pool.getconn()
            if con:
                cursor = con.cursor()
                cursor.execute(query, params)

                for row in cursor:
                    if len(row) == 1:
                        query_results.append(row[0])
                    else:
                        query_results.append(row)
                query_status = True
        except psycopg2.pool.PoolError as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
            raise exceptions.PoolConnectionException(
                "Pool Error - Unable to get a connection from the pool.")
        except psycopg2.Error as e:
            self.logger.error("Exception while querying the database %s, %s",
                              self._database, e)
        finally:
            if con:
                self._pool.putconn(con)

        return query_status, query_results

    def RunExampleSearch(self, search_query, response_type):
        """Performs a query search on the 'san_francisco_neighborhoods' table.

    Args:
      search_query: the query to be searched, in lowercase.
      response_type: Response type can be KML or JSONP, depending on the client
    Returns:
      tuple containing
      total_example_results: Total number of rows returned from
       querying the database.
      example_results: Query results as a list
    """
        example_results = []

        params = ["%" + entry + "%" for entry in search_query.split(",")]

        accum_func = self.utils.GetAccumFunc(response_type)

        example_query = self._example_query_template.substitute(
            FUNC=accum_func)
        query_status, query_results = self.RunPGSQLQuery(example_query, params)

        total_example_results = len(query_results)

        if query_status:
            for entry in xrange(total_example_results):
                results = {}

                name = query_results[entry][4]
                snippet = query_results[entry][3]
                styleurl = "#placemark_label"
                description = ("The total area in decimal degrees of " +
                               query_results[entry][4] + " is: " +
                               str(query_results[entry][1]) +
                               "<![CDATA[<br/>]]>")
                description += ("The total perimeter in decimal degrees of " +
                                query_results[entry][4] + " is: " +
                                str(query_results[entry][2]))
                geom = str(query_results[entry][0])

                results["name"] = name
                results["snippet"] = snippet
                results["styleurl"] = styleurl
                results["description"] = description
                results["geom"] = geom
                results["geom_type"] = str(query_results[entry][5])

                example_results.append(results)

        return total_example_results, example_results

    def ConstructKMLResponse(self, search_results, original_query):
        """Prepares KML response.

    KML response has the below format:
      <kml>
       <Folder>
       <name/>
       <StyleURL>
             ---
       </StyleURL>
       <Point>
              <coordinates/>
       </Point>
       </Folder>
      </kml>

    Args:
     search_results: Query results from the searchexample database
     original_query: Search query as entered by the user
    Returns:
     kml_response: KML formatted response
    """
        search_placemarks = ""
        kml_response = ""
        lookat_info = ""
        set_first_element_lookat = True

        # folder name should include the query parameter(q) if 'displayKeys'
        # is present in the URL otherwise not.
        if self.display_keys_string:
            folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)" %
                           (original_query, str(len(search_results))))
        else:
            folder_name = ("Grouped results:<![CDATA[<br/>]]> (%s)" %
                           (str(len(search_results))))

        fly_to_first_element = str(self.fly_to_first_element).lower() == "true"

        for result in search_results:
            geom = ""
            placemark = ""

            geom = self._geom % (result["name"], result["styleurl"],
                                 result["snippet"], result["description"],
                                 result["geom"])

            # Add <LookAt> for POINT geometric types only.
            # TODO: Check if <LookAt> can be added for
            # LINESTRING and POLYGON types.
            if result["geom_type"] != "POINT":
                set_first_element_lookat = False

            if fly_to_first_element and set_first_element_lookat:
                lookat_info = self.utils.GetLookAtInfo(result["geom"])
                set_first_element_lookat = False

            placemark = self._placemark_template.substitute(geom=geom)
            search_placemarks += placemark

        kml_response = self._kml_template.substitute(
            foldername=folder_name,
            style=self._style,
            lookat=lookat_info,
            placemark=search_placemarks)

        self.logger.info("KML response successfully formatted")

        return kml_response

    def ConstructJSONPResponse(self, search_results, original_query):
        """Prepares JSONP response.

      {
               "Folder": {
                 "name": "Latitude X Longitude Y",
                 "Placemark": {
                    "Point": {
                      "coordinates": "X,Y" } }
                 }
       }
    Args:
     search_results: Query results from the searchexample table
     original_query: Search query as entered by the user
    Returns:
     jsonp_response: JSONP formatted response
    """
        search_placemarks = ""
        search_geoms = ""
        geoms = ""
        json_response = ""
        jsonp_response = ""

        folder_name = ("Grouped results:<![CDATA[<br/>]]>%s (%s)" %
                       (original_query, str(len(search_results))))

        for count, result in enumerate(search_results):
            geom = ""
            geom = self._json_geom % (result["name"], result["snippet"],
                                      result["description"],
                                      result["geom"][1:-1])

            if count < (len(search_results) - 1):
                geom += ","
            geoms += geom

        if len(search_results) == 1:
            search_geoms = geoms
        else:
            search_geoms = "[" + geoms + "]"

        search_placemarks = self._json_placemark_template.substitute(
            geom=search_geoms)

        json_response = self._json_template.substitute(
            foldername=folder_name, json_placemark=search_placemarks)

        # Escape single quotes from json_response.
        json_response = json_response.replace("'", "\\'")

        jsonp_response = self._jsonp_call % (self.f_callback, json_response)

        self.logger.info("JSONP response successfully formatted")

        return jsonp_response

    def HandleSearchRequest(self, environ):
        """Fetches the search tokens from form and performs the example search.

    Args:
     environ: A list of environment variables as supplied by the
      WSGI interface to the example search application interface.
    Returns:
     search_results: A KML/JSONP formatted string which contains search results.
    """
        search_results = ""
        search_status = False

        # Fetch all the attributes provided by the user.
        parameters = self.utils.GetParameters(environ)
        response_type = self.utils.GetResponseType(environ)

        # Retrieve the function call back name for JSONP response.
        self.f_callback = self.utils.GetCallback(parameters)

        original_query = self.utils.GetValue(parameters, "q")

        # Fetch additional query parameters 'flyToFirstElement' and
        # 'displayKeys' from URL.
        self.fly_to_first_element = self.utils.GetValue(
            parameters, "flyToFirstElement")
        self.display_keys_string = self.utils.GetValue(parameters,
                                                       "displayKeys")

        if not original_query:
            # Extract 'neighborhood' parameter from URL
            try:
                form = cgi.FieldStorage(fp=environ["wsgi.input"],
                                        environ=environ)
                original_query = form.getvalue("neighborhood")
            except AttributeError as e:
                self.logger.debug("Error in neighborhood query %s" % e)

        if original_query:
            (search_status,
             search_results) = self.DoSearch(original_query, response_type)
        else:
            self.logger.debug("Empty or incorrect search query received")

        if not search_status:
            folder_name = "No results were returned."
            search_results = self.utils.NoSearchResults(
                folder_name, self._style, response_type, self.f_callback)

        return (search_results, response_type)

    def DoSearch(self, original_query, response_type):
        """Performs the example search and returns the results.

    Args:
     original_query: A string containing the search query as
      entered by the user.
     response_type: Response type can be KML or JSONP, depending on the client.
    Returns:
     tuple containing
     search_status: Whether search could be performed.
     search_results: A KML/JSONP formatted string which contains search results.
    """
        search_status = False

        search_results = ""
        query_results = ""

        total_results = 0

        search_query = original_query.strip().lower()

        if len(search_query.split(",")) > 2:
            self.logger.warning("Extra search parameters ignored:%s" %
                                (",".join(search_query.split(",")[2:])))
            search_query = ",".join(search_query.split(",")[:2])
            original_query = ",".join(original_query.split(",")[:2])

        total_results, query_results = self.RunExampleSearch(
            search_query, response_type)

        self.logger.info("example search returned %s results" % total_results)

        if total_results > 0:
            if response_type == "KML":
                search_results = self.ConstructKMLResponse(
                    query_results, original_query)
                search_status = True
            elif response_type == "JSONP":
                search_results = self.ConstructJSONPResponse(
                    query_results, original_query)
                search_status = True
            else:
                # This condition may not occur,
                # as response_type is either KML or JSONP
                self.logger.debug("Invalid response type %s" % response_type)

        return search_status, search_results

    def __del__(self):
        """Closes the connection pool created in __init__.
    """
        self._pool.closeall()
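
For context, a hedged sketch of how this class is typically driven from WSGI; only HandleSearchRequest and its (search_results, response_type) return value come from the code above, while the entry point and content types are assumptions:

# Hypothetical WSGI glue; in practice the instance (and its pool) would be
# created once per process, not per request.
searcher = ExampleSearch()

def application(environ, start_response):
    results, response_type = searcher.HandleSearchRequest(environ)
    content_type = ("application/vnd.google-earth.kml+xml"
                    if response_type == "KML" else "application/javascript")
    start_response("200 OK", [("Content-type", content_type)])
    return [results.encode("utf-8")]
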
Beispiel #21
0
class Database:

    def __init__(self, host, port, dbname, dbuser, dbpass, minconn=1, maxconn=1):
        # Thread pool
        self.pool = ThreadedConnectionPool(
            minconn=minconn,
            maxconn=maxconn,
            host=host,
            database=dbname,
            user=dbuser,
            password=dbpass,
            port=port
        )
        # Base connection for initialization
        self.conn = psycopg2.connect(
            host=host,
            database=dbname,
            user=dbuser,
            password=dbpass,
            port=port
        )
        self.curs = self.conn.cursor()

    def initialize(self):
        # Initialize Database, Recreate Tables
        try:
            self.curs.execute("""CREATE TABLE Users (
                                user_id text, balance bigint)""")
        except psycopg2.Error:
            self.conn.rollback()
            self.curs.execute("""DROP TABLE Users""")
            self.curs.execute("""CREATE TABLE Users (
                                    user_id text, balance bigint)""")
        self.conn.commit()

        try:
            self.curs.execute("""CREATE TABLE Stock (
                                    stock_id text, user_id text, amount bigint)""")
        except psycopg2.Error:
            self.conn.rollback()
            self.curs.execute("""DROP TABLE Stock""")
            self.curs.execute("""CREATE TABLE Stock (
                                    stock_id text, user_id text, amount bigint)""")
        self.conn.commit()

        try:
            self.curs.execute("""CREATE TABLE PendingTrans (
                                    type text, user_id text, stock_id text, amount bigint, timestamp bigint)""")
        except psycopg2.Error:
            self.conn.rollback()
            self.curs.execute("""DROP TABLE PendingTrans""")
            self.curs.execute("""CREATE TABLE PendingTrans (
                                    type text, user_id text, stock_id text, amount bigint, timestamp bigint)""")
        self.conn.commit()

        try:
            self.curs.execute("""CREATE TABLE Trigger (
                                    type text, user_id text, stock_id text, amount bigint, trigger bigint)""")
        except psycopg2.Error:
            self.conn.rollback()
            self.curs.execute("""DROP TABLE Trigger""")
            self.curs.execute("""CREATE TABLE Trigger (
                                    type text, user_id text, stock_id text, amount bigint, trigger bigint)""")
        self.conn.commit()

        print "DB Initialized"

    # Return a Database Connection from the pool
    def get_connection(self):
        connection = self.pool.getconn()
        cursor = connection.cursor()
        return connection, cursor

    def close_connection(self, connection):
        self.pool.putconn(connection)

    # call like: select_record("id,balance", "Users", "id='jim' AND balance=200")
    def select_record(self, values, table, constraints):
        connection, cursor = self.get_connection()

        try:
            command = """SELECT %s FROM %s WHERE %s""" % (values, table, constraints)
            cursor.execute(command)
            connection.commit()
        except Exception as e:
            print('PG Select error - ' + str(e))

        result = cursor.fetchall()
        self.close_connection(connection)

        # Format to always return a tuple of the single record, with each value.
        if len(result) > 1:
            print('PG Select returned more than one value.')
            return (None,None)
        elif len(result) == 0:
            return (None,None)
        else:
            return result[0]

    def filter_records(self, values, table, constraints):
        connection, cursor = self.get_connection()

        try:
            command = """SELECT %s FROM %s WHERE %s""" % (values, table, constraints)
            cursor.execute(command)
            connection.commit()
        except Exception as e:
            print('PG Select error - ' + str(e))

        result = cursor.fetchall()
        self.close_connection(connection)

        # Return array of tuples
        return result

    def insert_record(self, table, columns, values):
        connection, cursor = self.get_connection()

        try:
            command = """INSERT INTO %s (%s) VALUES (%s)""" % (table, columns, values)
            cursor.execute(command)
            connection.commit()
        except Exception as e:
            print('PG Insert error - ' + str(e))

        self.close_connection(connection)

    def update_record(self, table, values, constraints):
        connection, cursor = self.get_connection()

        try:
            command = """UPDATE %s SET %s WHERE %s""" % (table, values, constraints)
            cursor.execute(command)
            connection.commit()
        except Exception as e:
            print('PG Update error %s \n table=%s values=%s constraints=%s command=%s' % (str(e), table, values, constraints, command))

        self.close_connection(connection)

    def delete_record(self, table, constraints):
        connection, cursor = self.get_connection()

        try:
            command = """DELETE FROM %s WHERE %s""" % (table, constraints)
            cursor.execute(command)
            connection.commit()
        except Exception as e:
            print('PG Delete error - ' + str(e))

        self.close_connection(connection)
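
Note that the helpers above interpolate values straight into SQL with %, which is unsafe for untrusted input. A sketch of the placeholder-based alternative for the Users table, written against the Database class from this example (the function name is illustrative):

def select_user_balance(db, user_id):
    connection, cursor = db.get_connection()
    try:
        # The driver fills %s placeholders itself, so values are quoted and
        # escaped safely instead of being pasted into the SQL string.
        cursor.execute("SELECT user_id, balance FROM Users WHERE user_id = %s",
                       (user_id,))
        return cursor.fetchone()
    finally:
        db.close_connection(connection)
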
Beispiel #22
0
class DBUtils:
    __singleton = None
    __lock = Lock()

    QUOTES_UPSERT_SQL = f"""
            INSERT INTO quotes (NAME, CODE, OPENING_PRICE, LATEST_PRICE, QUOTE_CHANGE,
            CHANGE, VOLUME, TURNOVER, AMPLITUDE, TURNOVER_RATE, "PE_ratio", VOLUME_RATIO,
            MAX_PRICE, MIN_PRICE, CLOSING_PRICE, "PB_ratio", MARKET, TIME)
            VALUES 
            ({','.join(['%s'] * 18)})
            ON CONFLICT ON CONSTRAINT quotes_pkey
            DO UPDATE SET latest_price = excluded.latest_price,
                          change = excluded.change,
                          volume = excluded.volume,
                          turnover = excluded.turnover,
                          amplitude = excluded.amplitude,
                          turnover_rate = excluded.turnover_rate,
                          volume_ratio = excluded.volume_ratio,
                          max_price = excluded.max_price,
                          min_price = excluded.min_price
        """

    COMPANIES_UPSERT_SQL = f"""
            INSERT INTO companies (code, name, intro, manage, ssrq, clrq, fxl, fxfy, mgfxj, fxzsz,
            srkpj, srspj, srhsl, srzgj, djzql, wxpszql, mjzjje, zczb)
            VALUES 
            ({','.join(['%s'] * 18)})
            ON CONFLICT  ON CONSTRAINT companies_pkey
            DO UPDATE SET fxl = excluded.fxl,
                            fxfy = excluded.fxfy,
                            mgfxj = excluded.mgfxj,
                            fxzsz = excluded.fxzsz,
                            djzql = excluded.djzql,
                            wxpszql = excluded.wxpszql,
                            zczb = excluded.zczb
        """

    CODES_QUERY_SQL = """
        SELECT DISTINCT code FROM quotes
    """

    MANAGE_QUERY_SQL = """
        SELECT manage FROM companies order by code
    """

    POS_VEC_UPDATE_SQL = """
        UPDATE companies SET pos_vec = %s WHERE code = %s
    """

    CODES_MANAGE_QUERY_SQL = """
        select code, manage from companies 
    """

    POS_VEC_QUERY_SQL = """
        SELECT pos_vec FROM companies WHERE code = %s
    """

    NEED_UPDATE_CODE_SQL = """
        SELECT DISTINCT code FROM quotes EXCEPT SELECT code FROM companies
    """

    MAIN_TARGET_INSERT_SQL = f"""
        INSERT INTO main_target (code, date, jbmgsy, kfmgsy, xsmgsy, mgjzc, mggjj, mgwfply, 
        mgjyxjl, yyzsr, mlr, gsjlr, kfjlr, yyzsrtbzz, gsjlrtbzz, kfjlrtbzz, yyzsrgdhbzz,
        gsjlrgdhbzz, kfjlrgdhbzz, jqjzcsyl, tbjzcsyl, tbzzcsyl, mll, jll, sjsl, yskyysr, jyxjlyysr,
        xsxjlyysr, zzczzl, yszkzzts, chzzts, zcfzl, ldzczfz, ldbl, sdbl) 
        VALUES 
        ({','.join(['%s'] * 35)})
        ON CONFLICT DO NOTHING 
    """

    def __init__(self, database_config):
        self.database_config = database_config
        pool_config = self.database_config['pool']
        conn_config = self.database_config['conn']
        self.conn_pool = ThreadedConnectionPool(minconn=pool_config['min'],
                                                maxconn=pool_config['max'],
                                                **conn_config)

    def closeall(self):
        self.conn_pool.closeall()

    def upsert_quotes(self, quotes):
        """更新行情"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                try:
                    cur.executemany(self.QUOTES_UPSERT_SQL, quotes)
                    conn.commit()
                except errors.UniqueViolation as e:
                    print("当前记录已存在")
                except errors.NotNullViolation as e:
                    print("违反非空约束")
        finally:
            self.conn_pool.putconn(conn)

    def upsert_company(self, company):
        """添加公司信息"""
        conn = self.conn_pool.getconn()
        try:
            value = [
                company['code'], company['name'], company['intro'],
                company['manage'], company['ssrq'], company['clrq'],
                company['fxl'], company['fxfy'], company['mgfxj'],
                company['fxzsz'], company['srkpj'], company['srspj'],
                company['srhsl'], company['srzgj'], company['djzql'],
                company['wxpszql'], company['mjzjje'], company['zczb']
            ]
            with conn.cursor() as cur:
                cur.execute(self.COMPANIES_UPSERT_SQL, value)
                conn.commit()
        except errors.UniqueViolation as e:
            print("当前记录已存在")

        except errors.NotNullViolation as e:
            print("违反非空约束")
        finally:
            self.conn_pool.putconn(conn)

    def get_all_codes(self):
        """获得所有公司代码"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(self.CODES_QUERY_SQL)
                codes = cur.fetchall()
            return codes
        except Exception as e:
            print(e)
        finally:
            self.conn_pool.putconn(conn)

    def get_all_manage(self):
        """获得所有公司经营范围"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(self.MANAGE_QUERY_SQL)
                manage = cur.fetchall()
            return manage
        except Exception as e:
            print(e)
        finally:
            self.conn_pool.putconn(conn)

    def get_all_codes_manage(self):
        """获取公司所有代码和经营范围"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(self.CODES_MANAGE_QUERY_SQL)
                result = cur.fetchall()
            return result
        except Exception as e:
            print(e)
        finally:
            self.conn_pool.putconn(conn)

    def update_company_pos_vec(self, code, pos_vec):
        """设置公司在行业语义空间中的位置"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(self.POS_VEC_UPDATE_SQL, (pos_vec, code))
                conn.commit()
        except Exception as e:
            print(e)
        finally:
            self.conn_pool.putconn(conn)

    def get_pos_vec(self, code):
        """获得语义空间位置向量"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(self.POS_VEC_QUERY_SQL, (code, ))
                pos_vec = cur.fetchone()
            return pos_vec

        finally:
            self.conn_pool.putconn(conn)

    def get_need_update_codes(self):
        """获得所有需要更新的公司代码"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(self.NEED_UPDATE_CODE_SQL)
                result = cur.fetchall()
            return result
        finally:
            self.conn_pool.putconn(conn)

    def insert_main_target(self, targets):
        """保存主要指标"""
        conn = self.conn_pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.executemany(self.MAIN_TARGET_INSERT_SQL, targets)
                conn.commit()
        except errors.UniqueViolation as e:
            print("当前记录已存在")
        except errors.NotNullViolation as e:
            print("违反非空约束")
        finally:
            self.conn_pool.putconn(conn)

    @staticmethod
    def init(database_config):
        DBUtils.__lock.acquire()
        try:
            if DBUtils.__singleton is None:
                DBUtils.__singleton = DBUtils(database_config)
        except Exception:
            traceback.print_exc()
        finally:
            DBUtils.__lock.release()
        return DBUtils.__singleton
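
A sketch of obtaining and using the singleton; the config dict shape follows the pool/conn keys read in __init__, and every value is illustrative:

database_config = {
    "pool": {"min": 1, "max": 8},
    "conn": {"database": "stocks", "user": "postgres", "password": "secret",
             "host": "127.0.0.1", "port": 5432},
}

db = DBUtils.init(database_config)           # first call builds the pool
assert db is DBUtils.init(database_config)   # later calls reuse it
codes = db.get_all_codes()
db.closeall()
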
Beispiel #23
0
class Database(rigor.database.Database):
	""" Container for a database connection pool """

	def __init__(self, database):
		super(Database, self).__init__(database)
		register_type(psycopg2.extensions.UNICODE)
		register_uuid()
		dsn = Database.build_dsn(database)
		self._pool = ThreadedConnectionPool(config.get('database', 'min_database_connections'), config.get('database', 'max_database_connections'), dsn)

	@staticmethod
	def build_dsn(database):
		""" Builds the database connection string from config values """
		dsn = "dbname='{0}' host='{1}'".format(database, config.get('database', 'host'))
		try:
			ssl = config.getboolean('database', 'ssl')
			if ssl:
				dsn += " sslmode='require'"
		except ConfigParser.Error:
			pass
		try:
			username = config.get('database', 'username')
			dsn += " user='******'".format(username)
		except ConfigParser.Error:
			pass
		try:
			password = config.get('database', 'password')
			dsn += " password='******'".format(password)
		except ConfigParser.Error:
			pass
		return dsn

	@staticmethod
	@template
	def create(name):
		""" Creates a new database with the given name """
		return "CREATE DATABASE {0};".format(name)

	@staticmethod
	@template
	def drop(name):
		""" Drops the database with the given name """
		return "DROP DATABASE {0};".format(name)

	@staticmethod
	@template
	def clone(source, destination):
		"""
		Copies the source database to a new destination database.  This may fail if
		the source database is in active use.
		"""
		return "CREATE DATABASE {0} WITH TEMPLATE {1};".format(destination, source)

	@contextmanager
	def get_cursor(self, commit=True):
		""" Gets a cursor from a connection in the pool """
		connection = self._pool.getconn()
		cursor = connection.cursor(cursor_factory=RigorCursor)
		try:
			yield cursor
		except psycopg2.IntegrityError as error:
			exc_info = sys.exc_info()
			self.rollback(cursor)
			raise rigor.database.IntegrityError, exc_info[1], exc_info[2]
		except psycopg2.DatabaseError as error:
			exc_info = sys.exc_info()
			self.rollback(cursor)
			raise rigor.database.DatabaseError, exc_info[1], exc_info[2]
		except:
			exc_info = sys.exc_info()
			self.rollback(cursor)
			raise exc_info[0], exc_info[1], exc_info[2]
		else:
			if commit:
				self.commit(cursor)
			else:
				self.rollback(cursor)

	def _close_cursor(self, cursor):
		""" Closes a cursor and releases the connection to the pool """
		cursor.close()
		self._pool.putconn(cursor.connection)

	def commit(self, cursor):
		""" Commits the transaction, then closes the cursor """
		cursor.connection.commit()
		self._close_cursor(cursor)

	def rollback(self, cursor):
		""" Rolls back the transaction, then closes the cursor """
		cursor.connection.rollback()
		self._close_cursor(cursor)

	def __del__(self):
		self._pool.closeall()
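
Typical use of the get_cursor context manager (note this snippet is Python 2, as the three-argument raise syntax shows); the database name and queries are illustrative:

db = Database("rigor_example")  # host, ssl, and credentials come from `config`

# Commits on clean exit; wraps and re-raises driver errors after rollback.
with db.get_cursor() as cursor:
    cursor.execute("SELECT 1")
    rows = cursor.fetchall()

# commit=False rolls the transaction back at the end, for read-only work.
with db.get_cursor(commit=False) as cursor:
    cursor.execute("SELECT now()")
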
Beispiel #24
0
class PostgresqlUtil(DBInterface):
    def __init__(self):
        super(PostgresqlUtil, self).__init__()
        self.conn_pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=32,
            database=config.POSTGRESQL_DATABASE,
            user=config.POSTGRESQL_USER_NAME,
            password=config.POSTGRESQL_PASSWORD,
            host=config.POSTGRESQL_HOST,
            port=config.POSTGRESQL_PORT)

    def get_conn(self):
        conn = self.conn_pool.getconn()
        return conn

    def put_conn(self, conn):
        self.conn_pool.putconn(conn)

    def close(self):
        self.conn_pool.closeall()

    # Insert
    def insert(self, table, data, **kwargs):
        insert_sql = generate_sql_string("INSERT", table, data)
        conn = self.get_conn()
        cur = conn.cursor()

        try:
            cur.execute(insert_sql)
        except Exception as e:
            conn.rollback()
            raise e
        finally:
            cur.close()
            conn.commit()
            self.put_conn(conn)

    # Select
    def select(self, table, condition, **kwargs):
        select_sql = generate_sql_string("SELECT",
                                         table,
                                         dict_condition=condition)
        conn = self.get_conn()
        cur = conn.cursor()

        res = []
        try:
            cur.execute(select_sql)
            res = cur.fetchall()
        except Exception as e:
            conn.rollback()
            raise e
        finally:
            cur.close()
            conn.commit()
            self.put_conn(conn)
        return res

    # Delete
    def delete(self, table, condition, **kwargs):
        delete_sql = generate_sql_string("DELETE",
                                         table,
                                         dict_condition=condition)
        conn = self.get_conn()
        cur = conn.cursor()

        try:
            cur.execute(delete_sql)
        except Exception as e:
            conn.rollback()
            raise e
        finally:
            cur.close()
            conn.commit()
            self.put_conn(conn)

    # Update
    def update(self, table, condition, data, upsert=False, **kwargs):
        item = self.select(table, condition)
        if len(item) == 0 and upsert is False:
            raise ValueError("no matching record to update and upsert is False")
        elif len(item) == 0 and upsert is True:
            self.insert(table, data)
        else:  # len(item) > 0
            update_sql = generate_sql_string("UPDATE", table, data, condition)
            conn = self.get_conn()
            cur = conn.cursor()

            try:
                cur.execute(update_sql)
            except Exception as e:
                conn.rollback()
                raise e
            finally:
                cur.close()
                conn.commit()
                self.put_conn(conn)
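
A sketch of how the CRUD helpers compose; the table, columns, and the dict shapes consumed by the external generate_sql_string helper are all assumptions:

db = PostgresqlUtil()

db.insert("users", {"name": "alice", "age": 30})
rows = db.select("users", {"name": "alice"})

# upsert=True falls back to insert() when no row matches the condition.
db.update("users", {"name": "bob"}, {"name": "bob", "age": 41}, upsert=True)
db.close()
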
Beispiel #25
0
class Scraper(object):
    ''' Parent class of all scrapers '''
    def __init__(self):
        self.NOW = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.connect_db()
        proxies = self.load_proxies()
        headers = self.load_user_headers()
        self.proxy_pool = cycle(proxies)
        self.header_pool = cycle(headers)

    # +  -  -  - PROXIES & HEADERS -  -  - +

    def load_proxies(self):
        ''' Load proxies from csv file and return a set of proxies'''
        proxies = set()
        try:
            df = pd.read_csv("./src/utils/proxy_files/proxies.csv")
            for i, r in df.iterrows():
                proxy = ':'.join([r['IP Address'], str(r['Port'])[:-2]])
                proxies.add(proxy)
        except Exception as e:
            print(e)

        return proxies

    def load_user_headers(self):
        ''' Load headers from csv file and return a set of headers'''
        headers = set()
        df = pd.read_csv("./src/utils/proxy_files/user_agents.csv")
        for i, r in df.iterrows():
            headers.add(r['User agent'])
        return headers

    def get_proxies(self):
        ''' Get the next proxy in proxy pool'''
        proxy = next(self.proxy_pool)
        return {"http": proxy, "https": proxy}

    def get_headers(self):
        ''' Get the next header in header pool'''
        headers = next(self.header_pool)
        return {"User-Agent": headers}

    def reset_proxy_pool(self):
        ''' Download new proxies, save to csv and load csv'''
        download_free_proxies()
        proxies = self.load_proxies()
        self.proxy_pool = cycle(proxies)

    # +  -  -  - DATABASE -  -  - +

    def connect_db(self):
        ''' Connect to db on cloud '''

        host_name = "bitko.czflnkl5jbgy.ap-southeast-2.rds.amazonaws.com"

        self._connpool = ThreadedConnectionPool(1,
                                                2,
                                                database="bitko",
                                                user="******",
                                                password="******",
                                                host=host_name,
                                                port="5432")

    @contextmanager
    def cursor(self):
        ''' Get a cursor from the conn pool '''

        # Get available connection from pool
        conn = self._connpool.getconn()
        conn.autocommit = True
        try:
            # Return a generator cursor() created on the fly
            yield conn.cursor()
        finally:
            # Return the connection back to connection pool
            self._connpool.putconn(conn)

    def query_list(self, sql, values):
        with self.cursor() as cur:
            cur.execute(sql, values)
            results = [i for i in cur.fetchall()]
        return results

    def query_one(self, sql, values):
        with self.cursor() as cur:
            cur.execute(sql, values)
            result = cur.fetchone()
        return result

    def execute(self, sql, values):
        ''' Execute sql command'''
        with self.cursor() as cur:
            cur.execute(sql, values)
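
Scraper is a base class, so a concrete child pairs the rotating proxies and headers with the pooled query helpers. A hypothetical subclass (the URL handling, table, and columns are invented):

import requests

class PriceScraper(Scraper):
    def fetch(self, url):
        # Rotate proxy and User-Agent on every request.
        resp = requests.get(url, proxies=self.get_proxies(),
                            headers=self.get_headers(), timeout=10)
        resp.raise_for_status()
        return resp.text

    def save_price(self, symbol, price):
        self.execute(
            "INSERT INTO prices (symbol, price, scraped_at) VALUES (%s, %s, %s)",
            (symbol, price, self.NOW))
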
Beispiel #26
0
class PostGISCatalog(Catalog):
    def __init__(self,
                 table="footprints",
                 database_url=os.getenv("DATABASE_URL"),
                 geometry_column="geom"):
        if database_url is None:
            raise Exception("Database URL must be provided.")
        urlparse.uses_netloc.append('postgis')
        urlparse.uses_netloc.append('postgres')
        url = urlparse.urlparse(database_url)

        self._pool = ThreadedConnectionPool(1,
                                            16,
                                            database=url.path[1:],
                                            user=url.username,
                                            password=url.password,
                                            host=url.hostname,
                                            port=url.port)

        self._log = logging.getLogger(__name__)
        self.table = table
        self.geometry_column = geometry_column

    def _candidates(self,
                    bounds,
                    resolution,
                    min_zoom,
                    max_zoom,
                    include_geometries=False):
        self._log.info("Resolution: %s; zoom range: %d-%d", resolution,
                       min_zoom, max_zoom)

        # TODO get sources in native CRS of the target
        query = """
            WITH bbox AS (
              SELECT ST_SetSRID(
                'BOX(%(minx)s %(miny)s, %(maxx)s %(maxy)s)'::box2d,
                4326) geom
            ),
            sources AS (
              SELECT
                 url,
                 source,
                 resolution,
                 coalesce(bands, '{{}}'::jsonb) bands,
                 coalesce(meta, '{{}}'::jsonb) meta,
                 coalesce(recipes, '{{}}'::jsonb) recipes,
                 acquired_at,
                 priority,
                 ST_Multi(footprints.geom) geom,
                 filename,
                 min_zoom,
                 max_zoom
               FROM {table} footprints
               JOIN bbox ON footprints.geom && bbox.geom
               WHERE numrange(min_zoom, max_zoom, '[]') && numrange(%(min_zoom)s, %(max_zoom)s, '[]')
                 AND footprints.enabled = true
            )
            SELECT
              url,
              source,
              resolution,
              bands,
              meta,
              recipes,
              acquired_at,
              null band,
              priority,
              null coverage,
              CASE WHEN {include_geometries}
                  THEN ST_AsGeoJSON(geom)
                  ELSE 'null'
              END geom,
              filename,
              min_zoom,
              max_zoom
            FROM sources
        """.format(table=self.table,
                   geometry_column=self.geometry_column,
                   include_geometries=bool(include_geometries))

        if bounds.crs == WGS84_CRS:
            left, bottom, right, top = bounds.bounds
        else:
            left, bottom, right, top = warp.transform_bounds(
                bounds.crs, WGS84_CRS, *bounds.bounds)

        connection = self._pool.getconn()
        try:
            with connection as conn, conn.cursor() as cur:
                cur.execute(
                    query, {
                        "minx": left if left != Infinity else -180,
                        "miny": bottom if bottom != Infinity else -90,
                        "maxx": right if right != Infinity else 180,
                        "maxy": top if top != Infinity else 90,
                        "min_zoom": min_zoom,
                        "max_zoom": max_zoom,
                        "resolution": min(resolution),
                    })

                for record in cur:
                    yield Source(*record[:-4],
                                 geom=json.loads(record[-4]),
                                 filename=record[-3],
                                 min_zoom=record[-2],
                                 max_zoom=record[-1])
        except Exception as e:
            self._log.error(e)
        finally:
            self._pool.putconn(connection)

    def _fill_bounds(self, bounds, resolution, include_geometries=False):
        zoom = get_zoom(max(resolution))
        query = """
            WITH RECURSIVE bbox AS (
              SELECT ST_SetSRID(
                    'BOX(%(minx)s %(miny)s, %(maxx)s %(maxy)s)'::box2d,
                    4326) geom
            ),
            date_range AS (
              SELECT
                COALESCE(min(acquired_at), '1970-01-01') min,
                COALESCE(max(acquired_at), '1970-01-01') max,
                age(COALESCE(max(acquired_at), '1970-01-01'),
                    COALESCE(min(acquired_at), '1970-01-01')) "interval"
              FROM {table}
            ),
            sources AS (
              SELECT * FROM (
                SELECT
                  1 iterations,
                  ARRAY[url] urls,
                  ARRAY[source] sources,
                  ARRAY[resolution] resolutions,
                  ARRAY[coalesce(bands, '{{}}'::jsonb)] bands,
                  ARRAY[coalesce(meta, '{{}}'::jsonb)] metas,
                  ARRAY[coalesce(recipes, '{{}}'::jsonb)] recipes,
                  ARRAY[acquired_at] acquisition_dates,
                  ARRAY[priority] priorities,
                  ARRAY[ST_Area(ST_Intersection(bbox.geom, footprints.geom)) /
                    ST_Area(bbox.geom)] coverages,
                  ARRAY[ST_Multi(footprints.geom)] geometries,
                  ST_Multi(footprints.geom) geom,
                  ST_Difference(bbox.geom, footprints.geom) uncovered
                FROM date_range, {table} footprints
                JOIN bbox ON footprints.geom && bbox.geom
                WHERE %(zoom)s BETWEEN min_zoom AND max_zoom
                  AND footprints.enabled = true
                ORDER BY
                  10 * coalesce(footprints.priority, 0.5) *
                    .1 * (1 - (extract(
                      EPOCH FROM (current_timestamp - COALESCE(
                        acquired_at, '2000-01-01'))) /
                        extract(
                          EPOCH FROM (current_timestamp - date_range.min)))) *
                    50 *
                      -- de-prioritize over-zoomed sources
                      CASE WHEN %(resolution)s / footprints.resolution >= 1
                        THEN 1
                        ELSE 1 / footprints.resolution
                      END *
                    ST_Area(
                        ST_Intersection(bbox.geom, footprints.geom)) /
                      ST_Area(bbox.geom) DESC
                LIMIT 1
              ) AS _
              UNION ALL
              SELECT * FROM (
                SELECT
                  sources.iterations + 1,
                  sources.urls || url urls,
                  sources.sources || source sources,
                  sources.resolutions || resolution resolutions,
                  sources.bands || coalesce(
                    footprints.bands, '{{}}'::jsonb) bands,
                  sources.metas || coalesce(meta, '{{}}'::jsonb) metas,
                  sources.recipes || coalesce(
                    footprints.recipes, '{{}}'::jsonb) recipes,
                  sources.acquisition_dates || footprints.acquired_at
                    acquisition_dates,
                  sources.priorities || footprints.priority priorities,
                  sources.coverages || ST_Area(
                    ST_Intersection(sources.uncovered, footprints.geom)) /
                    ST_Area(bbox.geom) coverages,
                  sources.geometries || footprints.geom,
                  ST_Collect(sources.geom, footprints.geom) geom,
                  ST_Difference(sources.uncovered, footprints.geom) uncovered
                FROM bbox, date_range, {table} footprints
                -- use proper intersection to prevent voids from irregular
                -- footprints
                JOIN sources ON ST_Intersects(
                    footprints.geom, sources.uncovered)
                WHERE NOT (footprints.url = ANY(sources.urls))
                  AND %(zoom)s BETWEEN min_zoom AND max_zoom
                  AND footprints.enabled = true
                ORDER BY
                  10 * coalesce(footprints.priority, 0.5) *
                    .1 * (1 - (extract(
                      EPOCH FROM (current_timestamp - COALESCE(
                        acquired_at, '2000-01-01'))) /
                        extract(
                          EPOCH FROM (current_timestamp - date_range.min)))) *
                    50 *
                      -- de-prioritize over-zoomed sources
                      CASE WHEN %(resolution)s / footprints.resolution >= 1
                        THEN 1
                        ELSE 1 / footprints.resolution
                      END *
                    ST_Area(
                        ST_Intersection(sources.uncovered, footprints.geom)) /
                        ST_Area(bbox.geom) DESC
                LIMIT 1
              ) AS _
            ),
            candidates AS (
                SELECT *
                FROM sources
                ORDER BY iterations DESC
                LIMIT 1
            ), candidate_rows AS (
                SELECT
                  unnest(urls) url,
                  unnest(sources) source,
                  unnest(resolutions) resolution,
                  unnest(bands) bands,
                  unnest(metas) meta,
                  unnest(recipes) recipes,
                  unnest(acquisition_dates) acquired_at,
                  unnest(priorities) priority,
                  unnest(coverages) coverage,
                  unnest(geometries) geom
                FROM candidates
            )
            SELECT
              url,
              source,
              resolution,
              bands,
              meta,
              recipes,
              acquired_at,
              null band,
              priority,
              coverage,
              CASE WHEN {include_geometries}
                  THEN ST_AsGeoJSON(geom)
                  ELSE 'null'
              END geom
            FROM candidate_rows
        """.format(table=self.table,
                   geometry_column=self.geometry_column,
                   include_geometries=bool(include_geometries))

        if bounds.crs == WGS84_CRS:
            left, bottom, right, top = bounds.bounds
        else:
            left, bottom, right, top = warp.transform_bounds(
                bounds.crs, WGS84_CRS, *bounds.bounds)

        connection = self._pool.getconn()
        try:
            with connection as conn, conn.cursor() as cur:
                cur.execute(
                    query, {
                        "minx": left if left != Infinity else -180,
                        "miny": bottom if bottom != Infinity else -90,
                        "maxx": right if right != Infinity else 180,
                        "maxy": top if top != Infinity else 90,
                        "zoom": zoom,
                        "resolution": min(resolution),
                    })

                for record in cur:
                    yield Source(*record[:-1], geom=json.loads(record[-1]))
        except Exception as e:
            self._log.error(e)
        finally:
            self._pool.putconn(connection)

    def get_sources(self,
                    bounds,
                    resolution,
                    min_zoom=None,
                    max_zoom=None,
                    include_geometries=False):
        if min_zoom is None or max_zoom is None:
            return self._fill_bounds(bounds,
                                     resolution,
                                     include_geometries=include_geometries)

        return self._candidates(bounds,
                                resolution,
                                min_zoom,
                                max_zoom,
                                include_geometries=include_geometries)
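
A sketch of querying the catalog; Source and WGS84_CRS come from the surrounding codebase, the database URL is illustrative, and the Bounds namedtuple is only a minimal stand-in for whatever object supplies .bounds and .crs:

from collections import namedtuple

Bounds = namedtuple("Bounds", ["bounds", "crs"])

catalog = PostGISCatalog(
    database_url="postgres://user:secret@localhost/footprints")
bbox = Bounds((-122.52, 37.70, -122.35, 37.83), WGS84_CRS)

# With an explicit zoom range, get_sources runs the simple candidate query;
# without one, it runs the recursive coverage-filling query instead.
for source in catalog.get_sources(bbox, resolution=(20.0, 20.0),
                                  min_zoom=10, max_zoom=14):
    print(source)
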
Beispiel #27
0
class PgPool:
    def __init__(self):
        logger.debug('initializing postgres threaded pool')
        self.host, self.port = None, None
        self.database, self.pool = None, None
        self.user, self.passwd = None, None
        self.pool = None

        logger.debug('Server Addr: {host}:{port}; Database: {db}; User: {user}; Password: {passwd}'.format(
            host=self.host, port=self.port,
            db=self.database, user=self.user, passwd=self.passwd
        ))

    def create_pool(self, conn_dict, limits):
        """
        Create a connection pool

        :param conn_dict: connection params dictionary
        :type conn_dict: dict
        :param limits: pool size limits ("Min"/"Max") dictionary
        :type limits: dict
        """
        if conn_dict["Host"] is None:
            self.host = 'localhost'
        else:
            self.host = conn_dict["Host"]
        if conn_dict["Port"] is None:
            self.port = '5432'
        else:
            self.port = conn_dict["Port"]

        self.database = conn_dict["Database"]
        self.user = conn_dict["User"]
        self.passwd = conn_dict["Password"]

        conn_params = "host='{host}' dbname='{db}' user='******' password='******' port='{port}'".format(
            host=self.host, db=self.database, user=self.user, passwd=self.passwd, port=self.port
        )

        try:
            logger.debug('creating pool')
            self.pool = ThreadedConnectionPool(int(limits["Min"]), int(limits["Max"]), conn_params)
        except Exception as e:
            logger.exception(str(e))

    def get_conn(self):
        """
        Get a connection from pool and return connection and cursor
        :return: conn, cursor
        """
        logger.debug('getting connection from pool')
        try:
            conn = self.pool.getconn()
            cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            return conn, cursor
        except Exception as e:
            logger.exception(str(e))
            return None, None

    @staticmethod
    def execute_query(cursor, query, params):
        """
        Execute a query on database

        :param cursor: cursor object
        :param query: database query
        :type query: str
        :param params: query parameters
        :type params: tuple
        :return: query results or bool
        """
        logger.info('executing query')
        logger.debug('Cursor: {cursor}, Query: {query}'.format(
            cursor=cursor, query=query))

        try:
            if query.split()[0].lower() == 'select':
                cursor.execute(query, params)
                return cursor.fetchall()
            else:
                return cursor.execute(query, params)
        except Exception as e:
            logger.exception(str(e))
            return False

    # commit changes to db permanently
    @staticmethod
    def commit_changes(conn):
        """
        Commit changes to the database permanently

        :param conn: connection object
        :return: bool
        """
        logger.debug('committing changes to database')
        try:
            return conn.commit()
        except Exception as e:
            logger.exception(str(e))
            return False

    def put_conn(self, conn):
        """
        Put connection back to the pool

        :param conn: connection object
        :return: bool
        """
        logger.debug('putting connection {conn} back to pool'.format(conn=conn))
        try:
            return self.pool.putconn(conn)
        except Exception as e:
            logger.exception(str(e))
            return False

    def close_pool(self):
        """
        Closes connection pool
        :return: bool
        """
        logger.debug('closing connections pool')
        try:
            return self.pool.closeall()
        except Exception as e:
            logger.exception(str(e))
            return False
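
A sketch of the full round trip; the connection dict mirrors the keys create_pool reads, the module-level logger is assumed to be configured, and the table is invented:

pool = PgPool()
pool.create_pool(
    {"Host": "localhost", "Port": "5432", "Database": "mydb",
     "User": "postgres", "Password": "secret"},
    {"Min": 1, "Max": 10})

conn, cursor = pool.get_conn()
rows = PgPool.execute_query(cursor, "SELECT * FROM items WHERE id = %s", (1,))
PgPool.commit_changes(conn)
pool.put_conn(conn)
pool.close_pool()
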
Beispiel #28
0
class PgUtil():
    # __metaclass__ = Singleton

    def __init__(self):
        self.conn_pool = ThreadedConnectionPool(minconn=4,
                                                maxconn=100,
                                                database='coda',
                                                user='******',
                                                password='******',
                                                host='127.0.0.1',
                                                port=5432)

    def get_conn(self):
        conn = self.conn_pool.getconn()
        return conn

    def put_conn(self, conn):
        self.conn_pool.putconn(conn)

    def execute_insert_sql(self, sql, values):
        conn = self.get_conn()
        cur = conn.cursor(cursor_factory=DictCursor)
        cur.execute(sql, values)
        cur.close()
        conn.commit()
        self.put_conn(conn)

    def query_all_sql(self, sql):
        conn = self.get_conn()
        cur = conn.cursor(cursor_factory=RealDictCursor)
        cur.execute(sql)
        result_list = []
        for row in cur.fetchall():
            result_list.append(row)
        conn.commit()
        self.put_conn(conn)
        return result_list

    def query_one_sql(self, sql):
        conn = self.get_conn()
        cur = conn.cursor(cursor_factory=DictCursor)
        cur.execute(sql)
        row = cur.fetchone()
        cur.close()
        conn.commit()
        self.put_conn(conn)
        return dict(row) if row else {}

    def execute_sql(self, sql):
        conn = self.get_conn()
        cur = conn.cursor(cursor_factory=DictCursor)
        cur.execute(sql)
        cur.close()
        conn.commit()
        self.put_conn(conn)

    def select_sql(self, sql, values):
        conn = self.get_conn()
        cur = conn.cursor(cursor_factory=DictCursor)
        cur.execute(sql, values)
        res = cur.fetchone()
        cur.close()
        conn.commit()
        self.put_conn(conn)
        return res

    def select_all_sql(self, sql, values):
        conn = self.get_conn()
        cur = conn.cursor(cursor_factory=DictCursor)
        cur.execute(sql, values)
        res = cur.fetchall()
        cur.close()
        conn.commit()
        self.put_conn(conn)
        return res

    def execute_update_sql(self, sql, values):
        conn = self.get_conn()
        cur = conn.cursor(cursor_factory=DictCursor)
        cur.execute(sql, values)
        cur.close()
        conn.commit()
        self.put_conn(conn)
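
A sketch of the helpers above against a throwaway table; only the hard-coded coda connection parameters come from the class, the table itself is invented:

db = PgUtil()  # pools 4-100 connections to the local 'coda' database

db.execute_sql("CREATE TABLE IF NOT EXISTS notes (id serial, body text)")
db.execute_insert_sql("INSERT INTO notes (body) VALUES (%s)", ("hello",))
row = db.select_sql("SELECT * FROM notes WHERE body = %s", ("hello",))
rows = db.query_all_sql("SELECT * FROM notes")  # list of dicts (RealDictCursor)
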
Beispiel #29
0
class Pool:

    _pools = {}

    def __init__(self, cfg):
        db_uri = cfg.get("db_uri", DEFAULT_DB_URI)
        self.cfg = cfg
        uri = urlparse(db_uri)
        dbname = uri.path[1:]
        self.flavor = uri.scheme
        self.pg_schema = None
        if self.flavor == "sqlite":
            self.conn_args = [dbname]
            self.conn_kwargs = {
                "check_same_thread": False,
                "detect_types": sqlite3.PARSE_DECLTYPES,
                "isolation_level": "DEFERRED",
            }
            sqlite3.register_converter("JSONB", json.loads)
            sqlite3.register_converter("INTEGER[]", convert_array(int))
            sqlite3.register_converter("VARCHAR[]", convert_array(str))
            sqlite3.register_converter("FLOAT[]", convert_array(float))
            sqlite3.register_converter(
                "BOOL[]", convert_array(lambda x: x == "True")
            )

        elif self.flavor == "postgresql":
            self.pg_schema = uri.fragment
            if psycopg2 is None:
                raise ImportError(
                    'Cannot connect to "%s" without psycopg2 package '
                    "installed" % db_uri
                )

            con_info = "dbname='%s' " % dbname
            if uri.hostname:
                con_info += "host='%s' " % uri.hostname
            if uri.username:
                con_info += "user='******' " % uri.username
            if uri.password:
                con_info += "password='******' " % uri.password
            if uri.port:
                con_info += "port='%s' " % uri.port

            self.pg_pool = ThreadedConnectionPool(
                cfg.get("pg_min_pool_size", 1),
                cfg.get("pg_max_pool_size", 10),
                con_info,
            )
        elif self.flavor == "crdb":
            if psycopg2 is None:
                raise ImportError(
                    'Cannot connect to "%s" without psycopg2 package '
                    "installed" % db_uri
                )
            # transform crdb into postgresql in the uri scheme to please
            # psycopg2
            uri_parts = list(uri)
            uri_parts[0] = "postgresql"
            self.db_uri = urlunparse(uri_parts)

        else:
            raise ValueError(
                'Unsupported scheme "%s" in uri "%s"' % (uri.scheme, uri)
            )

    def enter(self):
        if self.flavor == "sqlite":
            connection = sqlite3.connect(*self.conn_args, **self.conn_kwargs)
            connection.text_factory = str
            connection.execute("PRAGMA foreign_keys=ON")
            connection.execute("PRAGMA journal_mode=wal")
        elif self.flavor == "crdb":
            connection = psycopg2.connect(self.db_uri)
        elif self.flavor == "postgresql":
            connection = self.pg_pool.getconn()
            if self.pg_schema:
                qr = "SET search_path TO %s" % self.pg_schema
                connection.cursor().execute(qr)

        else:
            raise ValueError('Unexpected flavor "%s"' % self.flavor)
        return connection

    def leave(self, connection, exc=None):
        if exc:
            logger.debug("ROLLBACK")
            connection.rollback()
        else:
            logger.debug("COMMIT")
            connection.commit()
        if self.flavor == "postgresql":
            self.pg_pool.putconn(connection)
        else:
            connection.close()

    @classmethod
    def disconnect(cls):
        for pool in cls._pools.values():
            if pool.flavor == "postgresql":
                pool.pg_pool.closeall()
        cls.clear()

    @classmethod
    def clear(cls):
        cls._pools = {}

    @classmethod
    def get_pool(cls, cfg):
        db_uri = cfg.get("db_uri", DEFAULT_DB_URI)
        pool = cls._pools.get(db_uri)
        if pool:
            # Return existing pool for current db if any
            return pool

        pool = Pool(cfg)
        cls._pools[db_uri] = pool
        return pool
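
A sketch of the intended lifecycle under an illustrative cfg: get_pool caches
one Pool per db_uri, enter checks a connection out, and leave commits or rolls
back before returning it.

cfg = {"db_uri": "postgresql://scott:tiger@localhost:5432/mydb"}  # illustrative
pool = Pool.get_pool(cfg)
conn = pool.enter()  # sqlite3.connect(), pg pool checkout, or psycopg2.connect()
try:
    cur = conn.cursor()
    cur.execute("SELECT 1")
except Exception as exc:
    pool.leave(conn, exc=exc)  # rolls back, then returns/closes the connection
    raise
else:
    pool.leave(conn)           # commits, then returns/closes the connection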
Beispiel #30
class PgPool:
    def __init__(self):
        self.host, self.port = None, None
        self.database, self.pool = None, None
        self.user, self.passwd = None, None
        self.pool = None

    def create_pool(self, conn_dict):
        """
        Create a connection pool

        :param conn_dict: connection params dictionary
        :type conn_dict: dict
        """
        if conn_dict["Host"] is None:
            self.host = 'localhost'
        else:
            self.host = conn_dict["Host"]
        if conn_dict["Port"] is None:
            self.port = '5432'
        else:
            self.port = conn_dict["Port"]

        self.database = conn_dict["Database"]
        self.user = conn_dict["User"]
        self.passwd = conn_dict["Password"]

        conn_params = "host='{host}' dbname='{db}' user='******' password='******' port='{port}'".format(
            host=self.host,
            db=self.database,
            user=self.user,
            passwd=self.passwd,
            port=self.port)

        try:
            self.pool = ThreadedConnectionPool(1, 50, conn_params)
        except Exception as e:
            print(e)

    def get_conn(self):
        """
        Get a connection from pool and return connection and cursor
        :return: conn, cursor
        """
        try:
            conn = self.pool.getconn()
            cursor = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
            return conn, cursor
        except Exception as e:
            print(e)
            return None, None

    @staticmethod
    def execute_query(cursor, query, params):
        """
        Execute a query on database

        :param cursor: cursor object
        :param query: database query
        :type query: str
        :param params: query parameters
        :type params: tuple
        :return: query results or bool
        """

        try:
            if query.split()[0].lower() == 'select':
                cursor.execute(query, params)
                return cursor.fetchall()
            else:
                return cursor.execute(query, params)
        except Exception as e:
            print(e)
            return False

    # commit changes to db permanently
    @staticmethod
    def commit_changes(conn):
        """
        Commit changes to the database permanently

        :param conn: connection object
        :return: bool
        """
        try:
            return conn.commit()
        except Exception as e:
            print(e)
            return False

    def put_conn(self, conn):
        """
        Put connection back to the pool

        :param conn: connection object
        :return: bool
        """
        try:
            return self.pool.putconn(conn)
        except Exception as e:
            print(e)
            return False

    def close_pool(self):
        """
        Closes connection pool
        :return: bool
        """
        try:
            return self.pool.closeall()
        except Exception as e:
            print(e)
            return False
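
A hedged usage sketch for PgPool; all connection values below are placeholders.

pg = PgPool()
pg.create_pool({"Host": None, "Port": None,        # defaults: localhost:5432
                "Database": "mydb", "User": "scott",
                "Password": "tiger"})              # placeholder credentials

conn, cursor = pg.get_conn()
if cursor is not None:
    # SELECTs return fetchall() results; other statements return None
    rows = PgPool.execute_query(cursor,
                                "SELECT * FROM items WHERE id = %s", (1,))
    PgPool.commit_changes(conn)  # needed to persist writes
    pg.put_conn(conn)
pg.close_pool()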
Beispiel #31
class PgConnManager:
    """
    Manage postgresql database connections.
    """

    # dict to keep track of PgConnManager instances
    _instances = {}

    class DatabaseError(psycopg2.Error):
        """Raised when database connection error occurs."""

        pass

    def __new__(cls, *args, **kwargs):
        db_opts = args[0]

        if "logger" in kwargs:
            logger = kwargs.pop("logger")
        else:
            logger = logging.getLogger("sjutils.pgconnmanager")

        # We can either accept 'database' or 'dbname' as an input
        if "database" in db_opts and "dbname" not in db_opts:
            db_opts["dbname"] = db_opts["database"]

        # This is ugly but since dict types cannot be used
        # as keys in another dict, we need to transform it.
        # This transformation was chosen as it is human
        # readable; it could be changed to a more optimised one.
        db_str = (
            "host=%(host)s port=%(port)s user=%(user)s password=%(password)s dbname=%(dbname)s"
            % db_opts)

        if db_str not in cls._instances:
            cls._instances[db_str] = super(PgConnManager, cls).__new__(cls)
            cls._instances[db_str].lock = threading.Lock()
            cls._instances[db_str].log = logger

        return cls._instances[db_str]

    def __init__(self, db_opts):
        self.__params__ = db_opts
        if "isolation_level" not in self.__params__:
            self.__params__[
                "isolation_level"] = psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED
        if not hasattr(self, "__conn_pool__"):
            self.__conn_pool__ = None

    def decorate(self, func, *args, **kw):
        try:
            ctx_list = self.connect()
            try:
                try:
                    ret = func(self, ctx_list, *args, **kw)
                except psycopg2.OperationalError as _error:
                    # We got a database disconnection that the user did not catch; wiping all connections
                    # because psycopg2 does not reliably set the connection's 'closed' attribute on disconnection
                    self.release_all(ctx_list, close=True)
                    raise

                # Connections may not have been released by the user, so release them here
                self.release_all(ctx_list, rollback=True)

                return ret
            except psycopg2.Error as _error:
                self.release_all(ctx_list, rollback=True)
                raise
            except Exception:
                self.release_all(ctx_list, rollback=True)
                raise
        except psycopg2.Error as _error:
            # We do not want our users to have to 'import psycopg2' to
            # handle the module's underlying database errors
            _, value, traceback = sys.exc_info()
            raise self.DatabaseError(value).with_traceback(traceback)

    def _new_ctx(self, mark=None):
        """Create a new context object."""
        if not mark:
            mark = "none"
        ret = {"conn": None, "cursor": None, "mark": mark}
        ret["conn"] = self.__conn_pool__.getconn()
        self.log.debug(
            "ctx:(" + mark + ", " + str(id(ret)) +
            ") Creating context to database: %(dbname)s as %(user)s" %
            self.__params__)
        return ret

    def connect(self, ctx_list=None, mark=None):
        """Connect to database."""
        self.log.debug("Connecting to database: %(dbname)s as %(user)s" %
                       self.__params__)
        try:
            self.lock.acquire()
            try:
                if not self.__conn_pool__:
                    connector = (
                        "host=%(host)s port=%(port)s user=%(user)s password=%(password)s dbname=%(dbname)s"
                        % self.__params__)
                    self.__conn_pool__ = ThreadedConnectionPool(
                        1, 20, connector)
            finally:
                self.lock.release()
            if not ctx_list:
                ctx_list = []
            ctx_list.append(self._new_ctx(mark))
            self.log.debug(
                "ctx:(%(mark)s, %(ctxid)s) Connected to database: %(dbname)s as %(user)s"
                % {
                    "mark": ctx_list[-1]["mark"],
                    "ctxid": str(id(ctx_list[-1])),
                    "dbname": self.__params__["dbname"],
                    "user": self.__params__["user"],
                })
            return ctx_list

        except psycopg2.Error as _error:
            # We do not want our users to have to 'import psycopg2' to
            # handle the module's underlying database errors
            _, value, traceback = sys.exc_info()
            raise self.DatabaseError(value).with_traceback(traceback)

    def execute(self, ctx, query, options=None):
        """Execute an SQL query."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) Executing query on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })
        try:
            if not ctx["conn"]:
                ctx["conn"] = self.__conn_pool__.getconn()
            if ("isolation_level" in self.__params__
                    and self.__params__["isolation_level"] !=
                    ctx["conn"].isolation_level):
                ctx["conn"].set_isolation_level(
                    self.__params__["isolation_level"])
            if not ctx["cursor"]:
                ctx["cursor"] = ctx["conn"].cursor(
                    cursor_factory=psycopg2.extras.DictCursor)
            try:
                if options:
                    ctx["cursor"].execute(query, options)
                else:
                    ctx["cursor"].execute(query)
            except psycopg2.OperationalError as _error:
                # We got a database disconnection, wiping connection
                self.release(ctx, close=True)
                raise
        except psycopg2.Error as _error:
            # We do not want our users to have to 'import psycopg2' to
            # handle the module's underlying database errors
            _, value, traceback = sys.exc_info()
            raise self.DatabaseError(value).with_traceback(traceback)

    def commit(self, ctx):
        """Commit changes to the database."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) Committing changes on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })
        try:
            try:
                ctx["conn"].commit()
            except psycopg2.OperationalError as _error:
                # We got a database disconnection, wiping connection
                self.release(ctx, close=True)
                raise
        except psycopg2.Error as _error:
            # We do not want our users to have to 'import psycopg2' to
            # handle the module's underlying database errors
            _, value, traceback = sys.exc_info()
            raise self.DatabaseError(value).with_traceback(traceback)

    def rollback(self, ctx):
        """Rollback changes to database."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) Reverting changes on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })
        # FIXME: ctx['conn'] is None if execute failed, see bug #3085
        if ctx["conn"] and not ctx["conn"].closed:
            ctx["conn"].rollback()

    def release(self, ctx, rollback=False, close=False):
        """Release database connection."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) Releasing connection on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })

        if rollback:
            self.rollback(ctx)

        ctx["cursor"] = None
        if ctx["conn"]:
            self.log.debug(
                "ctx:(%(mark)s, %(ctxid)s) Disposing connection on database: %(dbname)s as %(user)s"
                % {
                    "mark": ctx["mark"],
                    "ctxid": str(id(ctx)),
                    "dbname": self.__params__["dbname"],
                    "user": self.__params__["user"],
                })
            closeit = close or (ctx["conn"].closed > 0)
            self.__conn_pool__.putconn(ctx["conn"], close=closeit)
        ctx["conn"] = None

    def release_all(self, ctx_list, rollback=False, close=False):
        """Release all database connections from a context list."""
        self.log.debug(
            "Releasing ALL connections on database: %(dbname)s as %(user)s" %
            self.__params__)

        for ctx in ctx_list:
            self.release(ctx, rollback=rollback, close=close)
        ctx_list = None

    def fetchall(self, ctx):
        """Return all rows of current request."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) fetchall on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })
        return ctx["cursor"].fetchall()

    def fetchmany(self, ctx, arraysize=1000):
        """Return @arraysize rows of current request."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) fetchmany on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })
        return ctx["cursor"].fetchmany(arraysize)

    def fetchgenerator(self, ctx):
        """A basic generator to iterate through the current request's
        result rows."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) fetchgenerator on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })
        while True:
            results = self.fetchmany(ctx)
            if not results:
                return
            for result in results:
                yield result

    def get_rowcount(self, ctx):
        """Get the cursor's rowcount attribute of the given context @ctx."""
        return ctx["cursor"].rowcount

    def fetchone(self, ctx):
        """Return one row of current request."""
        self.log.debug(
            "ctx:(%(mark)s, %(ctxid)s) fetchone on database: %(dbname)s as %(user)s"
            % {
                "mark": ctx["mark"],
                "ctxid": str(id(ctx)),
                "dbname": self.__params__["dbname"],
                "user": self.__params__["user"],
            })
        return ctx["cursor"].fetchone()

    def set_isolation_level(self, isolation_level):
        """Set the isolation level."""
        self.__params__["isolation_level"] = isolation_level

    def get_isolation_level(self):
        """Set the isolation level."""
        return self.__params__["isolation_level"]
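
A hedged sketch of the decorate() protocol with placeholder connection options:
decorate opens a context list, passes it to the callback, and releases every
context afterwards, rolling back anything left uncommitted (so writes must call
commit(ctx) inside the callback).

db_opts = {"host": "localhost", "port": 5432, "user": "scott",
           "password": "tiger", "dbname": "mydb"}  # placeholders
manager = PgConnManager(db_opts)

def list_users(mgr, ctx_list):
    ctx = ctx_list[0]                # context opened by decorate()
    mgr.execute(ctx, "SELECT id, name FROM users")
    return mgr.fetchall(ctx)

rows = manager.decorate(list_users)  # connect, run, release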
Beispiel #32
class Database:
    """Jumbo's abstract PostgreSQL client manager.

    Allows managing multiple independent connections to a PostgreSQL database
    and handling arbitrary transactions between clients and the database.

    Usage:

        .. code-block:: python

            # Initialize abstract database manager
            database = jumbo.database.Database()

            # Open a connection pool to the PostgreSQL database:
            with database.open() as pool:

                # Use a connection from the pool
                with pool.connect():

                    # Execute SQL query on the database
                    pool.send(SQL_query)

            # context managers ensure connections are properly returned to
            # the pool, and that the pool is properly closed.
    """

    def __init__(self, env_path: Optional[str] = None) -> None:
        """Initializes manager to handle connections to a given PostgreSQL
        database.

        Args:
            env_path:   path where to look for the jumbo.env configuration
                        file. If not provided, looks for configuration file
                        in the working directory of the script invoking this
                        constructor.

        Attributes:
            config (jumbo.config.Config):   jumbo's database configuration
                                            settings.
            pool (psycopg2.pool):           connection pool to the PostgreSQL
                                            database. Initially closed.
        """

        # Database configuration settings
        self.config = Config(env_path)
        # initialize a placeholder connection pool
        self.pool = ThreadedConnectionPool(0, 0)
        self.pool.closed = True  # keep it closed on construction

        # Log configuration settings
        logger.info(f"Jumbo Connection Manager created:\n{self.config}")

    @contextmanager
    def open(self, minconns: int = 1, maxconns: Optional[int] = None) -> None:
        """Context manager opening a connection pool to the PostgreSQL
        database. The pool is automatically closed on exit and all
        connections are properly handled.

        Example:

            .. code-block:: python

                with self.open() as pool:

                    # do something with the pool

                # pool is automatically closed here

        Args:
            minconns:   minimum amount of available connections in the pool,
                        created on startup.
            maxconns:   maximum amount of available connections supported by
                        the pool. Defaults to minconns.
        """
        try:

            # Create connection pool
            self.open_pool(minconns, maxconns)

            yield self

        except OperationalError as e:
            logger.error(f"Error while opening connection pool: {e}")

        finally:
            # Close all connections in the pool
            self.close_pool()

    def open_pool(self, minconns: int = 1,
                  maxconns: Optional[int] = None) -> None:
        """Initialises and opens a connection pool to the PostgreSQL database
        (psycopg2.pool.ThreadedConnectionPool).

        'minconns' new connections are created immediately. The connection
        pool will support a maximum of about 'maxconns' connections.

        Args:
            minconns:   minimum amount of available connections in the pool,
                        created on startup.
            maxconns:   maximum amount of available connections supported by
                        the pool. Defaults to 'minconns'.
        """

        # If the pool hasn't been opened yet
        if self.pool.closed:

            # initialize max number of supported connections
            maxconns = maxconns if maxconns is not None else minconns

            # create a connection pool based on jumbo's configuration settings
            self.pool = ThreadedConnectionPool(
                minconns, maxconns,
                host=self.config.DATABASE_HOST,
                user=self.config.DATABASE_USERNAME,
                password=self.config.DATABASE_PASSWORD,
                port=self.config.DATABASE_PORT,
                dbname=self.config.DATABASE_NAME,
                sslmode='disable')

            logger.info(f"Connection pool created to PostgreSQL database: "
                        f"{maxconns} connections available.")

    def close_pool(self) -> None:
        """Closes all connections in the pool, making it unusable by clients.
        """

        # If the pool hasn't been closed yet
        if not self.pool.closed:
            self.pool.closeall()
            logger.info("All connections in the pool have been closed "
                        "successfully.")

    @contextmanager
    def connect(self, key: int = 1) -> None:
        """Context manager opening a connection to the PostgreSQL database
        using a connection [key] from the pool. The connection is
        automatically closed on exit and all transactions are properly handled.

        Example:

            .. code-block:: python

                # check-out connection from the pool
                with self.connect(key):

                    # do something with the connection e.g.
                    # self.send(sql_query, key)

                # connection is automatically returned to the pool here

        Args:
            key (optional): key to identify the connection being opened.
                            Required for proper book keeping.
        """

        # check-out an available connection from the pool
        self.get_connection(key=key)

        try:

            yield self

        except (Exception, KeyboardInterrupt) as e:
            logger.error(f"Error raised during connection [{key}] "
                         f"transactions: {e}")

        finally:
            # return the connection to the pool
            self.put_back_connection(key=key)

    def get_connection(self, key: int = 1) -> None:
        """Connect to a Postgres database using an available connection from
        pool. The connection is assigned to 'key' on checkout.

        Args:
            key (optional): key to assign to the connection being opened.
                            Required for proper book keeping.
        """

        # If a pool has been opened
        if not self.pool.closed:

            try:

                # If the specific connection hasn't been already opened
                if key not in self.pool._used:

                    # Connect to PostgreSQL database
                    self.pool.getconn(key)
                    logger.info(f"Connection retrieved successfully: pool "
                                f"connection [{key}] now in use.")

                    # perform handshake
                    self.on_connection(key)

                else:
                    logger.warning(f"Pool connection [{key}] is already in "
                                   f"use by another client. Try a different "
                                   f"key.")

            except PoolError as error:
                logger.error(f"Error while retrieving connection from "
                             f"pool:\t{error}")
                sys.exit()

        else:
            logger.warning(f"No pool to the PostgreSQL database: cannot "
                           f"retrieve a connection. Try to .open() a pool.")

    def on_connection(self, key: int = 1) -> None:
        """Client-database handshaking script to perform on retrieval of a
        PostgreSQL connection from the pool.

        Args:
            key (optional): key of the pool connection being used in
                            the transaction. Defaults to [1].
        """

        # return database information
        info = self.connection_info(key=key)
        logger.info(f"You are connected to - {info}")

    def put_back_connection(self, key: int = 1) -> None:
        """Puts back a connection in the connection pool.

        Args:
            key (optional): key of the pool connection being used in the
                            transaction. Defaults to [1].
        """

        # If this specific connection is under use
        if key in self.pool._used:

            # Reset connection to neutral state
            self.pool._used[key].reset()
            # Put back connection in the pool
            self.pool.putconn(self.pool._used[key], key)

            logger.info(f"Connection returned successfully: pool connection "
                        f"[{key}] now available again.")

        else:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"cannot put it back in the pool.")

    def send(self,
             query: Union[str, sql.Composed],
             subs: Optional[Tuple[str, ...]] = None,
             cur_method: int = 0,
             file: Optional[IO] = None,
             fetch_method: int = 2,
             key: int = 1) -> Union[DictRow, None]:
        """Sends an arbitrary PostgreSQL query to the PostgreSQL database.
        Transactions are auto-committed on execution.

        Example:

            .. code-block:: python

                # A simple query with no substitutions
                query = 'SELECT * FROM table_name;'
                results = self.send(query)

                # A more complex query with dynamic substitutions
                query = 'INSERT INTO table_name (column_name, another_column_name) VALUES (%s, %s);'
                subs = (value, another_value)
                results = self.send(query, subs)

        Args:
            query:                      PostgreSQL command string (can be
                                        template with psycopg2 %s fields).
            subs (optional):            tuple of values to substitute in SQL
                                        query  template (cf. psycopg2 %s
                                        formatting)
            cur_method (optional):      code to select which psycopg2 cursor
                                        execution method to use for the SQL
                                        query:
                                        0:  cursor.execute()
                                        1:  cursor.copy_expert()
            file (optional):            file-like object to read or write to
                                        (only relevant if cur_method:1).
            fetch_method (optional):    code to select which psycopg2 result
                                        retrieval method to use (fetch*()):
                                        0: cur.fetchone()
                                        2: cur.fetchall()
            key (optional):             key of the pool connection being used
                                        in the transaction. Defaults to [1].

        Returns:
            list of query results (if any). Can be accessed as dictionaries.
        """

        # If this specific connection has already been opened
        if key in self.pool._used:

            try:  # try running a transaction

                with self.pool._used[key].cursor(
                        cursor_factory=DictCursor) as cur:

                    # Bind arguments to query string (if present)
                    if subs is not None:
                        query = cur.mogrify(query, subs)
                    else:
                        query = cur.mogrify(query)

                    # Execute query
                    if cur_method == 0:
                        cur.execute(query)
                    elif cur_method == 1:
                        cur.copy_expert(sql=query, file=file)

                    # Fetch query results
                    try:
                        if fetch_method == 0:
                            records = cur.fetchone()
                        elif fetch_method == 2:
                            records = cur.fetchall()
                    # Handle SQL queries that don't return any results
                    # (INSERT, UPDATE, etc...)
                    except ProgrammingError:
                        records = []

                    # Commit transaction
                    self.pool._used[key].commit()

                    # Display success message (and shorten query if too long)
                    s_query = query
                    if len(query) > 78:
                        s_query = (str(query[:75]) + '...')
                    success_msg = f"Successfully sent: {s_query} "
                    if cur.rowcount >= 0:
                        success_msg += f": {cur.rowcount} rows affected."
                    logger.info(success_msg)

                    return records  # dictionaries

            except (Exception, Error, DatabaseError) as e:
                # Rollback transaction if any problem
                self.pool._used[key].rollback()
                logger.error(f"Error while sending query {query}:{e}. "
                             f"Transaction rolled-back.")

        else:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"not available for transactions.")

    def listen_on_channel(self, channel_name: str, key: int = 1) -> None:
        """Subscribes to a PostgreSQL notification channel by listening for
        NOTIFYs.

        .. code-block:: postgresql

            -- Command executed:
            LISTEN channel_name;

        Args:
            channel_name:   channel on which to LISTEN. PostgreSQL database
                            should be configured to send NOTIFYs on this
                            channel.
            key (optional): key of the pool connection being used in the
                            transaction. Defaults to [1].
        """

        query = "LISTEN " + channel_name + ";"
        self.send(query, key=key)

    def connection_info(self, key: int = 1) -> DictRow:
        """Fetches PostgreSQL database version.

        .. code-block:: postgresql

            -- Command executed:
            SELECT version();

        Args:
            key (optional): key of the pool connection being used in the
                            transaction. Defaults to [1].

        Returns:
            query result. Contains PostgreSQL database version information.
        """

        query = "SELECT version();"
        info = self.send(query, fetch_method=0, key=key)  # fetchone()
        return info

    def copy_to_table(self, query: sql.Composed, file: IO, db_table: str,
                      schema: str = None, replace: bool = True,
                      key: int = 1) -> None:
        """Utility wrapper to send a SQL query to copy data to database table.
        Allows to replace table if it already exists in the database.

        .. code-block:: postgresql

            -- Command type expected:
            COPY table_name [ ( column_name [, ...] ) ]
                FROM STDIN
                [ [ WITH ] ( option [, ...] ) ]

            -- Ancillary command pre-executed:
            TRUNCATE table_name;

        Example:

            .. code-block:: python

                # Copy csv data from file to a table in the database
                query = "COPY table_name FROM STDIN WITH CSV DELIMITER '\\t'"
                results = self.copy_to_table(query, file="C:\\data.csv",
                                             db_table='table_name')


        Args:
            query:              PostgreSQL COPY command string.
            file:               absolute path to file-like object to read data
                                from.
            db_table:           the name (not schema-qualified) of an
                                existing database table.
            schema (optional):  schema to which ``db_table`` belongs. If
                                ``None``, use default schema.
            replace (optional): replaces table contents if True. Appends data
                                to table contents otherwise.
            key (optional):     key of the pool connection being used in the
                                transaction. Defaults to [1].
        """

        # Replace the table already existing in the database
        if replace:
            # schema-qualify table name if needed
            identifier = sql.Identifier(schema, db_table) if schema is not None else sql.Identifier(db_table)
            # pass table name dynamically to query
            query_tmp = sql.SQL("TRUNCATE {};").format(identifier)
            self.send(query_tmp, key=key)

        # Copy the table from file (cur_method:1 = cur.copy_expert)
        self.send(query, cur_method=1, file=file, key=key)

    def copy_df(self, df: pd.DataFrame, db_table: str, schema: str = None,
                replace: bool = False, key: int = 1) -> None:
        """Utility wrapper to efficiently copy a pandas.DataFrame to a
        PostgreSQL database table.

        This method is faster than pandas' native *.to_sql()* method and
        exploits the PostgreSQL COPY command. It provides a useful means of
        saving results from a pandas-centred data analysis pipeline directly
        to the database.

        Args:
            df:                 dataframe to be copied.
            db_table:           the name (not schema-qualified) of the
                                table to write to.
            schema (optional):  schema to which ``db_table`` belongs. If
                                ``None``, use default schema.
            replace (optional): replaces table contents if True. Appends data
                                to table contents otherwise.
            key (optional):     key of the pool connection being used in the
                                transaction. Defaults to [1].
        """

        if key in self.pool._used:

            try:
                # Create headless csv from pandas dataframe
                io_file = io.StringIO()
                df.to_csv(io_file, sep='\t', header=False, index=False)
                io_file.seek(0)

                # Quickly create a table with the correct number of columns and
                # data types. Unfortunately we need to build a throwaway
                # sqlalchemy engine for this hack to work
                replacement_method = 'replace' if replace else 'append'
                engine = create_engine('postgresql+psycopg2://',
                                       creator=lambda: self.pool._used[key])
                df.head(0).to_sql(name=db_table, con=engine, schema=schema,
                                  if_exists=replacement_method, index=False)

                # But then exploit the PostgreSQL COPY command instead of the
                # slow pandas .to_sql(). Note that replace is set to False in
                # copy_to_table, as we want to preserve the header table
                # created above

                # schema-qualify table name if needed
                identifier = sql.Identifier(schema, db_table) if schema is not None else sql.Identifier(db_table)

                sql_copy_expert = sql.SQL(
                    "COPY {} FROM STDIN WITH CSV DELIMITER '\t'").format(
                    identifier)
                self.copy_to_table(sql_copy_expert, file=io_file,
                                   db_table=db_table, schema=schema,
                                   replace=False, key=key)

                logger.info(f"DataFrame copied successfully to PostgreSQL "
                            f"table.")

            except (Exception, DatabaseError) as error:
                logger.error(f"Error while copying DataFrame to PostgreSQL "
                             f"table: {error}")

        else:
            logger.warning(f"Pool connection [{key}] has never been opened: "
                           f"cannot use it to copy Dataframe to database.")
Beispiel #33
class DatabaseManager:
    """
    This class provides an abstraction over the underlying database.
    """

    def __init__(self, db_name="test_db", db_pass="", host="127.0.0.1", port="5432"):
        self.connection_pool = ThreadedConnectionPool(10, 50, database=db_name,
                                                      user="******",
                                                      password=db_pass,
                                                      host=host, port=port)
        self.logger = get_logger()
        

    def __execute_query(self, query):
        connection = self.connection_pool.getconn()
        cursor = connection.cursor()
        self.logger.debug("Going to execute query {}".format(query))
        try:
            cursor.execute(query)
        except ProgrammingError:
            self.logger.error("Error occurred while executing query {}".format(query))
        except IntegrityError:
            self.logger.error("Query failed. Duplicate row for query {}".format(query))
        finally:
            connection.commit()
            self.connection_pool.putconn(connection)

    """
    Inserts multiple rows in table_name. column_headers contain tuple of table headers.
    rows contain the list of tuples where each tuple has values for each rows. The values in 
    tuple are ordered according to column_headers tuple.
    """
    def insert_batch(self, table_name, column_headers, rows):
        query = "INSERT INTO {} {} VALUES {}".format(table_name, '(' + ','.join(column_headers) + ')', str(rows)[1:-1])
        self.__execute_query(query)

    """
    Updates a row(uid) with new values from column_vs_value dict.
    """
    def update(self, table_name, column_vs_value, uid):
        update_str = ''.join('{}={},'.format(key, val) for key, val in column_vs_value.items())[:-1]
        query = "UPDATE {} SET {} WHERE id = {} ".format(table_name, update_str, uid) 
        self.__execute_query(query)
    
    """
    Deletes all rows from table_name with uids. uids is a tuple.
    """
    def delete_batch(self, table_name , uids, uid_column_name='id'):
        query = "DELETE from {} WHERE {} in {}".format(table_name, uid_column_name, str(uids))
        self.__execute_query(query)
    
    """
    Returns the dict a row by uid.
    """
    def get_row(self, table_name, uid, uid_column_name='id'):
        query = "Select * from {} where {} = {}".format(table_name, uid_column_name, uid)
        connection = self.connection_pool.getconn()
        cursor = connection.cursor()
        cursor.execute(query)
        column_names = [desc[0] for desc in cursor.description]
        values = cursor.fetchall()
        result = {}
        if len(values) > 0:
            for x, y in zip(column_names, values[0]):
                result[x] = y
        self.connection_pool.putconn(connection)
        return result
    
    """
    Returns all distinct values of column_name from table_name.
    """
    def get_all_values_for_attr(self, table_name, column_name):
        query = "Select distinct {} from {}".format(column_name, table_name)
        connection = self.connection_pool.getconn()
        cursor = connection.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        uids = [row[0] for row in rows]
        self.connection_pool.putconn(connection)
        return uids
    
    """
    Returns all rows from table_name satisfying where_clause. The number of returned rows are limited to 
    limit.
    """
    def get_all_rows(self, table_name, where_clause='1=1', limit=20, order_by=None):
        query = "Select * from {} where {} ".format(table_name, where_clause)
        if order_by:
            query = '{} order by {} desc'.format(query, order_by)
        query = '{} limit {}'.format(query, limit)
        connection = self.connection_pool.getconn()
        cursor = connection.cursor()
        cursor.execute(query)
        column_names = [desc[0] for desc in cursor.description]
        rows = cursor.fetchall()
        result = []
        for row in rows:
            result_row = {}
            for x, y in zip(column_names, row):
                result_row[x] = str(y)
            result.append(result_row)
        self.connection_pool.putconn(connection)
        return result
    
    """
    Gets a new connection from the pool and returns the connection object.
    """
    def get_connection(self):
        return self.connection_pool.getconn()
        
    """
    Releases the connection back to pool.
    """
    def release_connection(self, connection):
        self.connection_pool.putconn(connection)
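
A hedged sketch for DatabaseManager; note that this class interpolates values
straight into SQL strings, so string values must carry their own quotes and the
input must be trusted.

manager = DatabaseManager(db_name="test_db")
manager.insert_batch("users", ("id", "name"), [(1, "ada"), (2, "grace")])
manager.update("users", {"name": "'ada l.'"}, uid=1)  # note the embedded quotes
row = manager.get_row("users", 1)                     # e.g. {'id': 1, 'name': "ada l."}
names = manager.get_all_values_for_attr("users", "name")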
Beispiel #34
class DB(object):
    connections = 0

    def __init__(self, config):
        self.debug = bool(DEBUG and Debug_Level >= 3)
        if isinstance(config, DB):
            self.DataName = config.DataName
            self.pool = config.pool
            self.cnx = self.cur = None
            return
        try:
            self.DataName = config['datatype']
            del config['datatype']
        except KeyError:
            self.DataName = 'POSTGRESQL'
        if self.DataName == 'MYSQL':
            try:
                self.pool = mysql.connector.pooling.MySQLConnectionPool(
                    **config)
                self.cnx = self.cur = None
            except mysql.connector.Error as err:
                # the operation log should be recorded here
                logging.debug(err.msg)
                self.cnx = None
                raise BaseError(701)  # database connection error
        elif self.DataName == 'POSTGRESQL':
            try:
                self.pool = ThreadedConnectionPool(**config)
            except:
                raise BaseError(701)  # database connection error
        elif self.DataName == 'ORACLE':
            try:
                if config['NLS_LANG']:
                    os.environ['NLS_LANG'] = config['NLS_LANG']
                del config['NLS_LANG']
            except:
                pass
            try:
                self.pool = cx_Oracle.SessionPool(**config)
            except:
                raise BaseError(701)  # database connection error
        # restore the config entry deleted above
        config['datatype'] = self.DataName

    def setDebug(self, debug):
        self.debug = debug

    def clone(self):
        db = DB(self)
        return db

    def open(self, auto=False):
        try:

            DB.connections += 1
            # print("===================================db.open, " + str(DB.connections))

            if self.DataName == 'ORACLE':
                self.__conn = self.pool.acquire()
                self.__cursor = self.__conn.cursor()
            elif self.DataName == 'POSTGRESQL':
                self.__conn = self.pool.getconn()
                self.__cursor = self.__conn.cursor()
            else:  # default to MySQL
                self.__conn = self.pool.get_connection()
                self.__cursor = self.__conn.cursor(buffered=True)

            #self.__conn.autocommit=True
            self.__conn.autocommit = auto
            self.cnx = self.__conn
            self.cur = self.__cursor
        except:
            raise BaseError(702)  # failed to obtain a connection from the pool

    def close(self):

        DB.connections -= 1
        # print("===================================db.close, " + str(DB.connections))

        # close the cursor and the database connection
        self.__conn.commit()
        if self.__cursor is not None:
            self.__cursor.close()

        if self.DataName == 'POSTGRESQL':
            self.pool.putconn(self.__conn)  # return the connection to the pool
        else:
            self.__conn.close()

        # print("===================================db.close end ")

    def begin(self):
        self.__conn.autocommit = False

    def commit(self):
        self.__conn.commit()

    def rollback(self):
        self.__conn.rollback()

#---------------------------------------------------------------------------

    def findBySql(self, sql, params={}, limit=0, join='AND', lock=False):
        """
			自定义sql语句查找
			limit = 是否需要返回多少行
			params = dict(field=value)
			join = 'AND | OR'
		"""
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            cursor.execute(sql, tuple(params.values()))
            rows = cursor.fetchmany(
                size=limit) if limit > 0 else cursor.fetchall()
            result = [dict(zip(cursor.column_names, row))
                      for row in rows] if rows else None
            return result
        except:
            raise BaseError(706)

    def countBySql(self, sql, params={}, join='AND'):
        # count affected rows with a custom SQL statement
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            cursor.execute(sql, tuple(params.values()))
            result = cursor.fetchone()
            return result[0] if result else 0
        except:
            raise BaseError(707)

    def deleteByCond(self, table, cond):
        # delete rows
        try:
            sql = "DELETE FROM %s where %s" % (table, cond)
            cursor = self.__getCursor()
            self.__display_Debug_IO(sql, ())  #DEBUG
            cursor.execute(sql)

            #self.__conn.commit() #
            return cursor.rowcount

        #except:
        #	raise BaseError(704)
        except Exception as err:
            raise BaseError(704, getattr(err, '_full_msg', str(err)))

    #def updateByPk(self,table,data,id,pk='id'):
    #   # update by primary key (defaults to 'id')
    #	return self.updateByAttr(table,data,{pk:id})

    def deleteByAttr(self, table, params={}, join='AND'):
        # delete rows
        try:
            fields = ','.join(k + '=%s' for k in params.keys())
            sql = "DELETE FROM %s " % table
            sql = self.__joinWhere(sql, params, join)
            cursor = self.__getCursor()
            self.__display_Debug_IO(sql, tuple(params.values()))  #DEBUG
            cursor.execute(sql, tuple(params.values()))

            #self.__conn.commit() #
            return cursor.rowcount

        #except:
        #	raise BaseError(704)
        except Exception as err:
            raise BaseError(704, getattr(err, '_full_msg', str(err)))

    def deleteByPk(self, table, id, pk='id'):
        # delete by primary key (defaults to 'id')
        return self.deleteByAttr(table, {pk: id})

    def findByAttr(self, table, criteria={}):
        # find a single row matching the given criteria
        return self.__query(table, criteria)

    def findByPk(self, table, id, pk='id'):
        return self.findByAttr(table, {'where': pk + '=' + str(id)})

    def findAllByAttr(self, table, criteria={}):
        # find rows matching the given criteria
        return self.__query(table, criteria, True)

    def exist(self, table, params={}, join='AND'):
        # check whether any matching row exists
        return self.count(table, params, join) > 0

# Public methods -------------------------------------------------------------------------------------

    def count(self, table, params={}, join='AND'):
        # count rows matching the given criteria
        try:
            sql = 'SELECT COUNT(*) FROM %s' % table

            where, whereValues = '', []
            if params:
                where, whereValues = self.__contact_where(params)
                sqlWhere = ' WHERE ' + where if where else ''
                sql += sqlWhere

            #sql = self.__joinWhere(sql,params,join)
            cursor = self.__getCursor()

            self.__display_Debug_IO(sql, tuple(whereValues))  #DEBUG

            if self.DataName == 'ORACLE':
                cursor.execute(sql % tuple(whereValues))
            else:
                cursor.execute(sql, tuple(whereValues))
            #cursor.execute(sql,tuple(params.values()))
            result = cursor.fetchone()
            return result[0] if result else 0
        #except:
        #	raise BaseError(707)
        except Exception as err:
            # prefer the driver's detailed message when available
            msg = getattr(err, '_full_msg', None)
            if msg:
                raise BaseError(707, msg)
            raise BaseError(707)

    def getToListByPk(self, table, criteria={}, id=None, pk='id'):
        # find rows by criteria and return them as a list
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria, isDict=False)

    def getAllToList(self, table, criteria={}, id=None, pk='id', join='AND'):
        # find rows by criteria and return them as a list
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria, all=True, isDict=False)

    def getToObjectByPk(self, table, criteria={}, id=None, pk='id'):
        # find rows by criteria and return them as objects
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria)

    def getAllToObject(self, table, criteria={}, id=None, pk='id', join='AND'):
        # 根据条件查找记录返回Object
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria, all=True)

    def insert(self, table, data, commit=True):
        # insert a single row
        try:
            '''
                Split the fields whose values contain SQL functions out of
                data into the funData dict; plain fields go into newData.
            '''
            funData, newData = self.__split_expression(data)

            funFields = ''
            funValues = ''

            # assemble fields and values without SQL functions
            fields = ', '.join(k for k in newData.keys())
            values = ', '.join(("%s", ) * len(newData))

            # assemble fields and values containing SQL functions
            if funData:
                funFields = ','.join(k for k in funData.keys())
                funValues = ','.join(v for v in funData.values())

            # merge all fields and values
            fields += ', ' + funFields if funFields else ''
            values += ', ' + funValues if funValues else ''
            sql = 'INSERT INTO %s (%s) VALUES (%s)' % (table, fields, values)
            cursor = self.__getCursor()

            for (k, v) in newData.items():
                try:
                    if isinstance(v, str):
                        newData[k] = "'%s'" % (v, )
                    if v is None:
                        newData[k] = "null"
                except:
                    pass

            self.__display_Debug_IO(sql, tuple(newData.values()))  #DEBUG
            sql = sql % tuple(newData.values())

            if self.DataName == 'POSTGRESQL':
                sql += ' RETURNING id'

            cursor.execute(sql)

            #if self.DataName=='ORACLE':
            #sql= sql % tuple(newData.values())
            #cursor.execute(sql)
            #else :
            #cursor.execute(sql,tuple(newData.values()))

            if self.DataName == 'ORACLE':
                # 1. commit must be False here
                # 2. the Oracle sequence naming convention is: [user.]SEQ_<table>_ID
                # 3. every main table should have an ID column
                t_list = table.split('.')
                if len(t_list) > 1:
                    SEQ_Name = t_list[0] + '.SEQ_' + t_list[1] + '_ID'
                else:
                    SEQ_Name = 'SEQ_' + t_list[0] + '_ID'

                cursor.execute('SELECT %s.CURRVAL FROM dual' %
                               SEQ_Name.upper())

                result = cursor.fetchone()
                insert_id = result[0] if result else 0
                #insert_id=cursor.rowcount
            elif self.DataName == 'MYSQL':
                insert_id = cursor.lastrowid
            elif self.DataName == 'POSTGRESQL':
                item = cursor.fetchone()
                insert_id = item[0]

            if commit: self.commit()
            return insert_id

        except Exception as err:
            # prefer the driver's detailed message when available
            raise BaseError(705, getattr(err, '_full_msg', err.args))

    def update(self,
               table,
               data,
               params={},
               join='AND',
               commit=True,
               lock=True):
        # update rows
        try:
            fields, values = self.__contact_fields(data)
            where, whereValues = '', []
            if params:
                where, whereValues = self.__contact_where(params)

            if whereValues:
                values.extend(whereValues)

            sqlWhere = ' WHERE ' + where if where else ''

            cursor = self.__getCursor()

            if commit: self.begin()

            if lock:
                sqlSelect = "SELECT %s From %s %s for update" % (','.join(
                    tuple(list(params.keys()))), table, sqlWhere)
                sqlSelect = sqlSelect % tuple(whereValues)
                cursor.execute(sqlSelect)  # acquire a row lock
                #cursor.execute(sqlSelect,tuple(whereValues))  # acquire a row lock

            sqlUpdate = "UPDATE %s SET %s " % (table, fields) + sqlWhere

            for index, val in enumerate(values):
                if isinstance(val, str):
                    values[index] = "'" + val + "'"
                if val is None:
                    values[index] = "null"

            self.__display_Debug_IO(sqlUpdate, tuple(values))  #DEBUG
            sqlUpdate = sqlUpdate % tuple(values)
            cursor.execute(sqlUpdate)

            #cursor.execute(sqlUpdate,tuple(values))

            if commit: self.commit()

            return cursor.rowcount

        except Exception as err:
            # prefer the driver's detailed message when available
            raise BaseError(705, getattr(err, '_full_msg', err.args))

    def updateByPk(self, table, data, id, pk='id', commit=True, lock=True):
        # update by primary key (defaults to 'id')
        return self.update(table, data, {pk: id}, commit=commit, lock=lock)

    def delete(self, table, params={}, join='AND', commit=True, lock=True):
        # delete rows
        try:
            data = {}
            fields, values = self.__contact_fields(data)
            where, whereValues = '', []
            if params:
                where, whereValues = self.__contact_where(params)

            if whereValues:
                values.extend(whereValues)

            sqlWhere = ' WHERE ' + where if where else ''

            cursor = self.__getCursor()

            if commit: self.begin()

            #if lock :
            #sqlSelect="SELECT %s From %s %s for update" % (','.join(tuple(list(params.keys()))),table,sqlWhere)
            #sqlSelect=sqlSelect % tuple(whereValues)
            #cursor.execute(sqlSelect)  # acquire a row lock
            ##cursor.execute(sqlSelect,tuple(whereValues))  # acquire a row lock

            sqlDelete = "DELETE FROM %s %s" % (table, sqlWhere)

            for index, val in enumerate(values):
                if isinstance(val, str):
                    values[index] = "'" + val + "'"

            self.__display_Debug_IO(sqlDelete, tuple(values))  #DEBUG
            sqlDelete = sqlDelete % tuple(values)
            cursor.execute(sqlDelete)

            #cursor.execute(sqlUpdate,tuple(values))

            if commit: self.commit()

            return cursor.rowcount

        except Exception as err:
            # prefer the driver's detailed message when available
            raise BaseError(705, getattr(err, '_full_msg', err.args))

    def deleteByPk(self, table, id, pk='id', commit=True, lock=True):
        # delete by primary key (defaults to 'id')
        return self.delete(table, {pk: id}, commit=commit, lock=lock)

# Internal private methods -------------------------------------------------------------------------------------

    def __display_Debug_IO(self, sql, params):
        # short-circuit: do not print SQL statements, to reduce screen output
        return
        if self.debug:
            debug_now_time = datetime.datetime.now().strftime(
                '%Y-%m-%d %H:%M:%S')
            print('[S ' + debug_now_time + ' SQL:] ' +
                  ((sql % params) if params else sql))

    def __get_connection(self):
        return self.pool.get_connection()

    def __getCursor(self):
        """Get a cursor"""
        if self.__cursor is None:
            self.__cursor = self.__conn.cursor()
        return self.__cursor

    def getCursor(self):
        """Get a cursor (public variant of __getCursor)"""
        if self.__cursor is None:
            self.__cursor = self.__conn.cursor()
        return self.__cursor

    def __joinWhere(self, sql, params, join):
        # Convert params into a WHERE clause
        if params:

            funParams = {}
            newParams = {}
            newWhere = ''
            funWhere = ''

            # Move fields whose values contain SQL expressions out of params
            # into the funParams dict
            for (k, v) in params.items():
                if 'str' in str(type(v)) and '{{' == v[:2] and '}}' == v[-2:]:
                    funParams[k] = v[2:-2]
                else:
                    newParams[k] = v

            # Build the conditions from newParams
            keys, _keys = self.__tParams(newParams)
            newWhere = ' AND '.join(k + '=' + _k for k, _k in zip(
                keys, _keys)) if join == 'AND' else ' OR '.join(
                    k + '=' + _k for k, _k in zip(keys, _keys))

            # Build the conditions from funParams
            if funParams:
                funWhere = ' AND '.join(
                    k + '=' + v for k, v in
                    funParams.items()) if join == 'AND' else ' OR '.join(
                        k + '=' + v for k, v in funParams.items())

            # Assemble the final WHERE
            where = (
                (newWhere + ' AND ' if newWhere else '') +
                funWhere if funWhere else newWhere) if join == 'AND' else (
                    (newWhere + ' OR ' if newWhere else '') +
                    funWhere if funWhere else newWhere)

            #--------------------------------------
            #keys,_keys = self.__tParams(params)
            #where = ' AND '.join(k+'='+_k for k,_k in zip(keys,_keys)) if join == 'AND' else ' OR '.join(k+'='+_k for k,_k in zip(keys,_keys))
            sql += ' WHERE ' + where
        return sql

    def __tParams(self, params):
        keys = [k if k[:2] != '{{' else k[2:-2] for k in params.keys()]
        _keys = ['%s' for k in params.keys()]
        return keys, _keys

    def __query(self, table, criteria, all=False, isDict=True, join='AND'):
        '''
           table    : table name
           criteria : query conditions (dict)
           all      : whether to return all rows; defaults to False, returning
                      a single row; returns all matching rows when True
           isDict   : whether rows are returned as dicts (default True) or as
                      tuples
        '''
        try:
            if all is not True:
                criteria['limit'] = 1  # return a single row only
            sql, params = self.__contact_sql(table, criteria,
                                             join)  # build sql and params
            '''
            # When where holds multiple conditions, bind the values of each key
            if 'where' in criteria and 'dict' in str(type(criteria['where'])) :
                params = criteria['where']
                #params = tuple(params.values())
                where ,whereValues   = self.__contact_where(params)
                sql+= ' WHERE '+where if where else ''
                params=tuple(whereValues)
            else :
                params = None
            '''
            #__contact_where(params,join='AND')
            cursor = self.__getCursor()

            self.__display_Debug_IO(sql, params)  #DEBUG

            #if self.DataName=="ORACLE":
            #sql="select * from(select * from(select t.*,row_number() over(order by %s) as rownumber from(%s) t) p where p.rownumber>%s) where rownum<=%s" % ()
            #pass

            cursor.execute(sql, params if params else ())

            rows = cursor.fetchall() if all else cursor.fetchone()

            if isDict:
                result = [dict(zip(cursor.column_names, row))
                          for row in rows] if all else dict(
                              zip(cursor.column_names, rows)) if rows else {}
            else:
                result = [row for row in rows] if all else rows if rows else []
            return result
        except Exception as err:
            try:
                raise BaseError(706, err._full_msg)
            except:
                raise BaseError(706)

    def __contact_sql(self, table, criteria, join='AND'):
        sql = 'SELECT '
        whereValues = None
        if criteria and type(criteria) is dict:
            #select fields
            if 'select' in criteria:
                fields = criteria['select'].split(',')
                sql += ','.join(
                    field.strip()[2:-2] if '{{' == field.strip()[:2]
                    and '}}' == field.strip()[-2:] else field.strip()
                    for field in fields)
            else:
                sql += ' * '
            #table
            sql += ' FROM %s' % table

            #where
            if 'where' in criteria:
                if 'str' in str(type(criteria['where'])):  # value is a string, i.e. a single ready-made condition
                    sql += ' WHERE ' + criteria['where']
                else:  # value is a dict, i.e. a set of conditions
                    params = criteria['where']
                    #sql+= self.__joinWhere('',params,join)
                    #sql+=self.__contact_where(params,join)
                    where, whereValues = self.__contact_where(params)
                    sql += ' WHERE ' + where if where else ''
                    #sql=sql % tuple(whereValues)

            #group by
            if 'group' in criteria:
                sql += ' GROUP BY ' + criteria['group']
            #having
            if 'having' in criteria:
                sql += ' HAVING ' + criteria['having']

            if self.DataName == 'MYSQL':
                #order by
                if 'order' in criteria:
                    sql += ' ORDER BY ' + criteria['order']
                #limit
                if 'limit' in criteria:
                    sql += ' LIMIT ' + str(criteria['limit'])
                #offset
                if 'offset' in criteria:
                    sql += ' OFFSET ' + str(criteria['offset'])
            elif (self.DataName == 'POSTGRESQL'):
                #order by
                if 'order' in criteria:
                    sql += ' ORDER BY ' + criteria['order']
                if 'limit' in criteria:
                    # extract offset and rowcount
                    arrLimit = (str(
                        criteria['limit']).split('limit ').pop()).split(',')
                    strOffset = arrLimit[0]
                    try:
                        strRowcount = arrLimit[1]
                    except:
                        strOffset = '0'
                        strRowcount = '1'
                    sql += '  LIMIT %s OFFSET %s' % (strRowcount, strOffset)

            elif (self.DataName == 'ORACLE') and ('limit' in criteria):
                # extract offset and rowcount
                arrLimit = (str(
                    criteria['limit']).split('limit ').pop()).split(',')
                strOffset = arrLimit[0]
                try:
                    strRowcount = arrLimit[1]
                except:
                    strOffset = '0'
                    strRowcount = '1'

                # handle ORDER BY
                if 'order' in criteria:
                    strOrder = criteria['order']
                else:
                    strOrder = 'ROWNUM'
                # the SQL below targets efficient paging of large result sets on Oracle
                sql = "select * from(select * from(select t.*,row_number() over(order by %s) as rownumber from(%s) t) p where p.rownumber>%s) where rownum<=%s" % (
                    strOrder, sql, strOffset, strRowcount)
            elif (self.DataName == 'ORACLE') and ('order' in criteria):
                sql += ' ORDER BY ' + criteria['order']

        else:
            sql += ' * FROM %s' % table

        return sql, whereValues

    # Separate plain values from SQL expressions
    def __split_expression(self, data):
        funData = {}
        newData = {}

        # Move fields whose values contain SQL expressions (wrapped in {{ }})
        # out of data into the funData dict
        for (k, v) in data.items():
            if 'str' in str(type(v)) and '{{' == v[:2] and '}}' == v[-2:]:
                funData[k] = v[2:-2]
            else:
                newData[k] = v

        return (funData, newData)

    # Build the UPDATE field assignments
    def __contact_fields(self, data):

        funData, newData = self.__split_expression(data)
        if funData:
            funFields = ','.join(k + '=%s' % (v) for k, v in funData.items())
        fields = ','.join(k + '=%s' for k in newData.keys())

        # merge fields with funFields
        if funData:
            fields = ','.join([fields, funFields]) if fields else funFields

        values = list(newData.values())

        return (fields, values)

    def __hasKeyword(self, key):
        if '{{}}' in key: return True
        if 'in (' in key: return True
        if 'like ' in key: return True
        if '>' in key: return True
        if '<' in key: return True
        return False

    # Build the WHERE condition
    def __contact_where(self, params, join='AND'):
        funParams, newParams = self.__split_expression(params)

        # Build the conditions from newParams (placeholder-bound values)
        keys, _keys = self.__tParams(newParams)
        newWhere = ' AND '.join(
            k + '=' + _k
            for k, _k in zip(keys, _keys)) if join == 'AND' else ' OR '.join(
                k + '=' + _k for k, _k in zip(keys, _keys))
        values = list(newParams.values())

        # Build the conditions from funParams (raw SQL expressions)
        #funWhere = ' AND '.join(('`' if k else '') +k+('`' if k else '')+ (' ' if self.__hasKeyword(v) else '=') +v for k,v in funParams.items()) if join == 'AND' else ' OR '.join('`'+k+'`'+(' ' if self.__hasKeyword(v) else '=')+v for k,v in funParams.items())

        funWhere = ' AND '.join(
            k + (' ' if self.__hasKeyword(v) else '=' if k else '') + v
            for k, v in funParams.items()) if join == 'AND' else ' OR '.join(
                k + (' ' if self.__hasKeyword(v) else '=' if k else '') + v
                for k, v in funParams.items())

        # Assemble the final WHERE
        where = ((newWhere + ' AND ' if newWhere else '') +
                 funWhere if funWhere else newWhere) if join == 'AND' else (
                     (newWhere + ' OR ' if newWhere else '') +
                     funWhere if funWhere else newWhere)
        return (where, values)

    def get_ids(self, rows):  # Extract ids from a getAllToList result
        try:
            test = rows[0][0]
            dimension = 2
        except:
            dimension = 1

        ids = []
        if dimension > 1:
            for i in range(len(rows)):
                ids.append(str(rows[i][0]))
        else:
            for i in range(len(rows)):
                ids.append(str(rows[i]))

        return ','.join(ids)
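
The `{{ }}` convention used throughout this class lets callers mix raw SQL expressions with placeholder-bound values. The following is a minimal, self-contained sketch of that convention; `split_expression` and `contact_where` are hypothetical stand-ins for the private `__split_expression`/`__contact_where` methods above, and the keyword check is simplified:

# Standalone sketch of the {{ }} WHERE-clause convention (assumed names).
def split_expression(params):
    exprs, plain = {}, {}
    for k, v in params.items():
        if isinstance(v, str) and v[:2] == '{{' and v[-2:] == '}}':
            exprs[k] = v[2:-2]   # raw SQL expression, braces stripped
        else:
            plain[k] = v         # bound later via a %s placeholder
    return exprs, plain

def contact_where(params, join='AND'):
    exprs, plain = split_expression(params)
    parts = [k + '=%s' for k in plain]
    # expressions that already carry an operator are joined with a space
    parts += [k + ' ' + v if any(t in v for t in ('>', '<', 'in (', 'like '))
              else k + '=' + v for k, v in exprs.items()]
    return (' ' + join + ' ').join(parts), list(plain.values())

where, values = contact_where({'name': 'bob', 'age': '{{> 18}}'})
print(where, values)   # name=%s AND age > 18 ['bob']
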
Beispiel #35
0
class Psycopg2Backend(Backend):
    """Backend for accessing data stored in a Postgres database
    """

    display_name = "PostgreSQL"
    connection_pool = None
    auto_create_extensions = True

    def __init__(self, connection_params):
        super().__init__(connection_params)

        if self.connection_pool is None:
            self._create_connection_pool()

        if self.auto_create_extensions:
            self._create_extensions()

    def _create_connection_pool(self):
        try:
            self.connection_pool = ThreadedConnectionPool(
                1, 16, **self.connection_params)
        except Error as ex:
            raise BackendError(str(ex)) from ex

    def _create_extensions(self):
        for ext in EXTENSIONS:
            try:
                query = "CREATE EXTENSION IF NOT EXISTS {}".format(ext)
                with self.execute_sql_query(query):
                    pass
            except OperationalError:
                warnings.warn("Database is missing extension {}".format(ext))

    def create_sql_query(self, table_name, fields, filters=(),
                         group_by=None, order_by=None,
                         offset=None, limit=None,
                         use_time_sample=None):
        sql = ["SELECT", ', '.join(fields),
               "FROM", table_name]
        if use_time_sample is not None:
            sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
        if filters:
            sql.extend(["WHERE", " AND ".join(filters)])
        if group_by is not None:
            sql.extend(["GROUP BY", ", ".join(group_by)])
        if order_by is not None:
            sql.extend(["ORDER BY", ",".join(order_by)])
        if offset is not None:
            sql.extend(["OFFSET", str(offset)])
        if limit is not None:
            sql.extend(["LIMIT", str(limit)])
        return " ".join(sql)

    @contextmanager
    def execute_sql_query(self, query, params=None):
        connection = self.connection_pool.getconn()
        cur = connection.cursor()
        try:
            utfquery = cur.mogrify(query, params).decode('utf-8')
            log.debug("Executing: %s", utfquery)
            t = time()
            cur.execute(query, params)
            yield cur
            log.info("%.2f ms: %s", 1000 * (time() - t), utfquery)
        finally:
            connection.commit()
            self.connection_pool.putconn(connection)

    def quote_identifier(self, name):
        return '"%s"' % name

    def unquote_identifier(self, quoted_name):
        if quoted_name.startswith('"'):
            return quoted_name[1:len(quoted_name) - 1]
        else:
            return quoted_name

    def list_tables_query(self, schema=None):
        if schema:
            schema_clause = "AND n.nspname = '{}'".format(schema)
        else:
            schema_clause = "AND pg_catalog.pg_table_is_visible(c.oid)"
        return """SELECT n.nspname as "Schema",
                          c.relname AS "Name"
                       FROM pg_catalog.pg_class c
                  LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                      WHERE c.relkind IN ('r','v','m','S','f','')
                        AND n.nspname <> 'pg_catalog'
                        AND n.nspname <> 'information_schema'
                        AND n.nspname !~ '^pg_toast'
                        {}
                        AND NOT c.relname LIKE '\\_\\_%'
                   ORDER BY 1;""".format(schema_clause)

    def create_variable(self, field_name, field_metadata,
                        type_hints, inspect_table=None):
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(field_name, field_metadata,
                                       inspect_table)

        field_name_q = self.quote_identifier(field_name)
        if var.is_continuous:
            if isinstance(var, TimeVariable):
                var.to_sql = ToSql("extract(epoch from {})"
                                   .format(field_name_q))
            else:
                var.to_sql = ToSql("({})::double precision"
                                   .format(field_name_q))
        else:  # discrete or string
            var.to_sql = ToSql("({})::text"
                               .format(field_name_q))
        return var

    def _guess_variable(self, field_name, field_metadata, inspect_table):
        type_code = field_metadata[0]

        FLOATISH_TYPES = (700, 701, 1700)  # real, float8, numeric
        INT_TYPES = (20, 21, 23)  # bigint, int, smallint
        CHAR_TYPES = (25, 1042, 1043,)  # text, char, varchar
        BOOLEAN_TYPES = (16,)  # bool
        DATE_TYPES = (1082, 1114, 1184, )  # date, timestamp, timestamptz
        # time, timestamp, timestamptz, timetz
        TIME_TYPES = (1083, 1114, 1184, 1266,)

        if type_code in FLOATISH_TYPES:
            return ContinuousVariable(field_name)

        if type_code in TIME_TYPES + DATE_TYPES:
            tv = TimeVariable(field_name)
            tv.have_date |= type_code in DATE_TYPES
            tv.have_time |= type_code in TIME_TYPES
            return tv

        if type_code in INT_TYPES:  # bigint, int, smallint
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                if values:
                    return DiscreteVariable(field_name, values)
            return ContinuousVariable(field_name)

        if type_code in BOOLEAN_TYPES:
            return DiscreteVariable(field_name, ['false', 'true'])

        if type_code in CHAR_TYPES:
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                if values:
                    return DiscreteVariable(field_name, values)

        return StringVariable(field_name)

    def count_approx(self, query):
        sql = "EXPLAIN " + query
        with self.execute_sql_query(sql) as cur:
            s = ''.join(row[0] for row in cur.fetchall())
        return int(re.findall(r'rows=(\d*)', s)[0])

    def __getstate__(self):
        # Drop connection_pool from state as it cannot be pickled
        state = dict(self.__dict__)
        state.pop('connection_pool', None)
        return state

    def __setstate__(self, state):
        # Create a new connection pool if none exists
        self.__dict__.update(state)
        if self.connection_pool is None:
            self._create_connection_pool()
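
`create_sql_query` above is pure string assembly, so its clause ordering can be exercised without a database. A standalone sketch mirroring the method (a free function for illustration, not Orange's API):

# Sketch: same clause ordering as Psycopg2Backend.create_sql_query.
def build_query(table, fields, filters=(), group_by=None, order_by=None,
                offset=None, limit=None):
    sql = ["SELECT", ", ".join(fields), "FROM", table]
    if filters:
        sql.extend(["WHERE", " AND ".join(filters)])
    if group_by is not None:
        sql.extend(["GROUP BY", ", ".join(group_by)])
    if order_by is not None:
        sql.extend(["ORDER BY", ", ".join(order_by)])
    if offset is not None:
        sql.extend(["OFFSET", str(offset)])
    if limit is not None:
        sql.extend(["LIMIT", str(limit)])
    return " ".join(sql)

print(build_query('"tracks"', ['"id"', '"url"'],
                  filters=['"type" = \'audio\''], limit=10))
# SELECT "id", "url" FROM "tracks" WHERE "type" = 'audio' LIMIT 10
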
Beispiel #36
0
class Psycopg2Backend(Backend):
    """Backend for accessing data stored in a Postgres database
    """

    display_name = "PostgreSQL"
    connection_pool = None
    auto_create_extensions = True

    def __init__(self, connection_params):
        super().__init__(connection_params)

        if self.connection_pool is None:
            self._create_connection_pool()

        self.missing_extension = []
        if self.auto_create_extensions:
            self._create_extensions()

    def _create_connection_pool(self):
        try:
            self.connection_pool = ThreadedConnectionPool(
                1, 16, **self.connection_params)
        except Error as ex:
            raise BackendError(str(ex)) from ex

    def _create_extensions(self):
        for ext in EXTENSIONS:
            try:
                query = "CREATE EXTENSION IF NOT EXISTS {}".format(ext)
                with self.execute_sql_query(query):
                    pass
            except BackendError:
                warnings.warn("Database is missing extension {}".format(ext))
                self.missing_extension.append(ext)

    def create_sql_query(self, table_name, fields, filters=(),
                         group_by=None, order_by=None,
                         offset=None, limit=None,
                         use_time_sample=None):
        sql = ["SELECT", ', '.join(fields),
               "FROM", table_name]
        if use_time_sample is not None:
            sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
        if filters:
            sql.extend(["WHERE", " AND ".join(filters)])
        if group_by is not None:
            sql.extend(["GROUP BY", ", ".join(group_by)])
        if order_by is not None:
            sql.extend(["ORDER BY", ",".join(order_by)])
        if offset is not None:
            sql.extend(["OFFSET", str(offset)])
        if limit is not None:
            sql.extend(["LIMIT", str(limit)])
        return " ".join(sql)

    @contextmanager
    def execute_sql_query(self, query, params=None):
        connection = self.connection_pool.getconn()
        cur = connection.cursor()
        try:
            utfquery = cur.mogrify(query, params).decode('utf-8')
            log.debug("Executing: %s", utfquery)
            t = time()
            cur.execute(query, params)
            yield cur
            log.info("%.2f ms: %s", 1000 * (time() - t), utfquery)
        except (Error, ProgrammingError) as ex:
            raise BackendError(str(ex)) from ex
        finally:
            connection.commit()
            self.connection_pool.putconn(connection)

    def quote_identifier(self, name):
        return '"%s"' % name

    def unquote_identifier(self, quoted_name):
        if quoted_name.startswith('"'):
            return quoted_name[1:len(quoted_name) - 1]
        else:
            return quoted_name

    def list_tables_query(self, schema=None):
        if schema:
            schema_clause = "AND n.nspname = '{}'".format(schema)
        else:
            schema_clause = "AND pg_catalog.pg_table_is_visible(c.oid)"
        return """SELECT n.nspname as "Schema",
                          c.relname AS "Name"
                       FROM pg_catalog.pg_class c
                  LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                      WHERE c.relkind IN ('r','v','m','S','f','')
                        AND n.nspname <> 'pg_catalog'
                        AND n.nspname <> 'information_schema'
                        AND n.nspname !~ '^pg_toast'
                        {}
                        AND NOT c.relname LIKE '\\_\\_%'
                   ORDER BY 1;""".format(schema_clause)

    def create_variable(self, field_name, field_metadata,
                        type_hints, inspect_table=None):
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(field_name, field_metadata,
                                       inspect_table)

        field_name_q = self.quote_identifier(field_name)
        if var.is_continuous:
            if isinstance(var, TimeVariable):
                var.to_sql = ToSql("extract(epoch from {})"
                                   .format(field_name_q))
            else:
                var.to_sql = ToSql("({})::double precision"
                                   .format(field_name_q))
        else:  # discrete or string
            var.to_sql = ToSql("({})::text"
                               .format(field_name_q))
        return var

    def _guess_variable(self, field_name, field_metadata, inspect_table):
        type_code = field_metadata[0]

        FLOATISH_TYPES = (700, 701, 1700)  # real, float8, numeric
        INT_TYPES = (20, 21, 23)  # bigint, int, smallint
        CHAR_TYPES = (25, 1042, 1043,)  # text, char, varchar
        BOOLEAN_TYPES = (16,)  # bool
        DATE_TYPES = (1082, 1114, 1184, )  # date, timestamp, timestamptz
        # time, timestamp, timestamptz, timetz
        TIME_TYPES = (1083, 1114, 1184, 1266,)

        if type_code in FLOATISH_TYPES:
            return ContinuousVariable.make(field_name)

        if type_code in TIME_TYPES + DATE_TYPES:
            tv = TimeVariable.make(field_name)
            tv.have_date |= type_code in DATE_TYPES
            tv.have_time |= type_code in TIME_TYPES
            return tv

        if type_code in INT_TYPES:  # bigint, int, smallint
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                if values:
                    return DiscreteVariable.make(field_name, values)
            return ContinuousVariable.make(field_name)

        if type_code in BOOLEAN_TYPES:
            return DiscreteVariable.make(field_name, ['false', 'true'])

        if type_code in CHAR_TYPES:
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                # remove trailing spaces
                values = [v.rstrip() for v in values]
                if values:
                    return DiscreteVariable.make(field_name, values)

        return StringVariable.make(field_name)

    def count_approx(self, query):
        sql = "EXPLAIN " + query
        with self.execute_sql_query(sql) as cur:
            s = ''.join(row[0] for row in cur.fetchall())
        return int(re.findall(r'rows=(\d*)', s)[0])

    def __getstate__(self):
        # Drop connection_pool from state as it cannot be pickled
        state = dict(self.__dict__)
        state.pop('connection_pool', None)
        return state

    def __setstate__(self, state):
        # Create a new connection pool if none exists
        self.__dict__.update(state)
        if self.connection_pool is None:
            self._create_connection_pool()
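
The `__getstate__`/`__setstate__` pair above is the standard recipe for pickling an object whose connection pool cannot be serialized: drop the pool from the pickled state, then rebuild it after unpickling. A minimal illustration with a hypothetical class and a deliberately unpicklable attribute (no database involved):

import pickle

# Minimal sketch: drop an unpicklable attribute from pickled state and
# recreate it on unpickling, as Psycopg2Backend does with its pool.
class PooledThing:
    def __init__(self):
        self.params = {'host': 'localhost'}
        self.pool = lambda: None        # stand-in for an unpicklable pool

    def __getstate__(self):
        state = dict(self.__dict__)
        state.pop('pool', None)         # the pool cannot be pickled; drop it
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.pool = lambda: None        # recreate the "pool" after unpickling

thing = pickle.loads(pickle.dumps(PooledThing()))
print(thing.params)                     # {'host': 'localhost'}
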
Beispiel #37
0
class Database():

    def __init__(self, config, verbose_start=False):
        self.config = config

        if verbose_start:
            print('Connecting to database...'.ljust(5), end='')

        # get DB parameters
        try:
            self.database = config.getProperty('Database', 'name').lower()
            self.host = config.getProperty('Database', 'host')
            self.port = config.getProperty('Database', 'port')
            self.user = config.getProperty('Database', 'user').lower()
            self.password = config.getProperty('Database', 'password')
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(f'Incomplete database credentials provided in configuration file (message: "{str(e)}").')

        try:
            self._createConnectionPool()
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(f'Could not connect to database (message: "{str(e)}").')

        if verbose_start:
            LogDecorator.print_status('ok')


    def _createConnectionPool(self):
        self.connectionPool = ThreadedConnectionPool(
            1,
            self.config.getProperty('Database', 'max_num_connections', type=int, fallback=20),
            host=self.host,
            database=self.database,
            port=self.port,
            user=self.user,
            password=self.password,
            connect_timeout=2
        )


    def runServer(self):
        ''' Dummy function for compatibility reasons '''
        return



    @contextmanager
    def _get_connection(self):
        conn = self.connectionPool.getconn()
        conn.autocommit = True
        try:
            yield conn
        finally:
            self.connectionPool.putconn(conn, close=False)


    def execute(self, query, arguments, numReturn=None):
        with self._get_connection() as conn:
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            # execute statement
            try:
                cursor.execute(query, arguments)
                conn.commit()
            except Exception as e:
                if not conn.closed:
                    conn.rollback()
                # self.connectionPool.putconn(conn, close=False)    #TODO: this still causes connection to close
                conn = self.connectionPool.getconn()

                # retry execution
                try:
                    cursor = conn.cursor(cursor_factory=RealDictCursor)
                    cursor.execute(query, arguments)
                    conn.commit()
                except:
                    if not conn.closed:
                        conn.rollback()
                    print(e)

            # get results
            try:
                returnValues = []
                if numReturn is None:
                    # cursor.close()
                    return
                
                elif numReturn == 'all':
                    returnValues = cursor.fetchall()
                    # cursor.close()
                    return returnValues

                else:
                    for _ in range(numReturn):
                        rv = cursor.fetchone()
                        if rv is None:
                            return returnValues
                        returnValues.append(rv)
        
                    # cursor.close()
                    return returnValues
            except Exception as e:
                print(e)
    

    def execute_cursor(self, query, arguments):
        with self._get_connection() as conn:
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            try:
                cursor.execute(query, arguments)
                conn.commit()
                return cursor
            except:
                if not conn.closed:
                    conn.rollback()
                # cursor.close()

                # retry execution
                conn = self.connectionPool.getconn()
                try:
                    cursor = conn.cursor(cursor_factory=RealDictCursor)
                    cursor.execute(query, arguments)
                    conn.commit()
                    return cursor
                except Exception as e:
                    if not conn.closed:
                        conn.rollback()
                    print(e)


    def insert(self, query, values):
        with self._get_connection() as conn:
            cursor = conn.cursor()
            try:
                execute_values(cursor, query, values)
                conn.commit()
            except Exception as e:
                if not conn.closed:
                    conn.rollback()
                # cursor.close()

                # retry execution
                conn = self.connectionPool.getconn()
                try:
                    cursor = conn.cursor(cursor_factory=RealDictCursor)
                    execute_values(cursor, query, values)
                    conn.commit()
                except Exception as e:
                    if not conn.closed:
                        conn.rollback()
                    print(e)
Beispiel #38
0
        rows = cur.fetchall()  # all rows in table
        if len(rows) < 1:
            return render_template('main.html',
                                   model=json.dumps(active_orders))
        phone = rows[0][0]

        return render_template('main.html',
                               model=json.dumps(active_orders),
                               phone=phone)

    except Exception as e:
        print('Exception:', e)
        return render_template('main.html', model="")
    finally:
        cur.close()
        pg_pool.putconn(conn)

        # return 'Hello!' + user_id


# API
@app.route('/login', methods=['POST'])
def login():
    phone = request.form['phone'].encode(encoding='utf-8')
    password_user = request.form['password'].encode(encoding='utf-8')

    try:
        conn = pg_pool.getconn()

        if conn is None:
            return '["false","db_connect_error"]'
Beispiel #39
0
	VALUES (%s, %s, %s, %s)
'''


def save_review(conn, user_id, data):
    try:
        cur = conn.cursor()
        cur.execute(insert_review_sql,
                    (user_id, data['book_id'], data['rating'], data['review']))
        conn.commit()

    except Exception as err:
        print('Error Saving review:', err, user_id)


if __name__ == "__main__":
    global tcp
    params = config()
    tcp = ThreadedConnectionPool(1, 10, **params)
    conn = tcp.getconn()
    review_info = [{
        'book_id': 24,
        'user_id': 'AG3HZ5KORQIXZHUBLNMZNU35SCEA',
        'name': '123',
        'rating': 4.0,
        'review': "This"
    }]

    for review in review_info:
        save_review(conn, review['user_id'], review)
    tcp.putconn(conn)
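
One weakness of the script above is that an exception raised between `getconn()` and `putconn()` would leak the connection. A sketch of the usual guard, using the same psycopg2 pool API (the commented usage assumes the `tcp` pool created above):

from contextlib import contextmanager

# Sketch: guarantee the connection returns to the pool even on error.
@contextmanager
def pooled_connection(pool):
    conn = pool.getconn()
    try:
        yield conn
    finally:
        pool.putconn(conn)

# Usage (assuming the pool and data from the script above):
# with pooled_connection(tcp) as conn:
#     save_review(conn, review_info[0]['user_id'], review_info[0])
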
Beispiel #40
0
class Database(DatabaseInterface):

    _databases = {}
    _connpool = None
    _list_cache = None
    _list_cache_timestamp = None
    _version_cache = {}
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        if name in cls._databases:
            return cls._databases[name]
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name='template1'):
        super(Database, self).__init__(name=name)
        self._databases.setdefault(name, self)
        self._search_path = None
        self._current_user = None

    @classmethod
    def dsn(cls, name):
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = uri.hostname and "host=%s" % uri.hostname or ''
        port = uri.port and "port=%s" % uri.port or ''
        name = "dbname=%s" % name
        user = uri.username and "user=%s" % uri.username or ''
        password = ("password=%s" %
                    urllib.unquote_plus(uri.password) if uri.password else '')
        return '%s %s %s %s %s' % (host, port, name, user, password)

    def connect(self):
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.name)
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        self._connpool = ThreadedConnectionPool(minconn, maxconn,
                                                self.dsn(self.name))
        return self

    def get_connection(self, autocommit=False, readonly=False):
        if self._connpool is None:
            self.connect()
        for count in range(config.getint('database', 'retry'), -1, -1):
            try:
                conn = self._connpool.getconn()
                break
            except PoolError:
                if count and not self._connpool.closed:
                    logger.info('waiting a connection')
                    time.sleep(1)
                    continue
                raise
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        if readonly:
            cursor = conn.cursor()
            cursor.execute('SET TRANSACTION READ ONLY')
        return conn

    def put_connection(self, connection, close=False):
        self._connpool.putconn(connection, close=close)

    def close(self):
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, connection, database_name):
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
                       'TEMPLATE template0 ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, connection):
        if self.name not in self._version_cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.name] = tuple(
                map(int,
                    RE_VERSION.search(version).groups()))
        return self._version_cache[self.name]

    @staticmethod
    def dump(database_name):
        from trytond.tools import exec_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        connection = database.get_connection(autocommit=True)
        database.create(connection, database_name)
        database.close()

        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_command_pipe(*args2, env=env)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.get_connection().cursor()
        if not database.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    def list(self):
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        cursor = connection.cursor()
        cursor.execute('SELECT datname FROM pg_database '
                       'WHERE datistemplate = false ORDER BY datname')
        res = []
        for db_name, in cursor:
            try:
                with connect(self.dsn(db_name)) as conn:
                    if self._test(conn):
                        res.append(db_name)
            except Exception:
                continue
        self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        from trytond.modules import get_module_info

        connection = self.get_connection()
        cursor = connection.cursor()
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for line in fp.read().split(';'):
                if (len(line) > 0) and (not line.isspace()):
                    cursor.execute(line)

        for module in ('ir', 'res'):
            state = 'uninstalled'
            if module in ('ir', 'res'):
                state = 'to install'
            info = get_module_info(module)
            cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
            module_id = cursor.fetchone()[0]
            cursor.execute(
                'INSERT INTO ir_module '
                '(id, create_uid, create_date, name, state) '
                'VALUES (%s, %s, now(), %s, %s)',
                (module_id, 0, module, state))
            for dependency in info.get('depends', []):
                cursor.execute(
                    'INSERT INTO ir_module_dependency '
                    '(create_uid, create_date, module, name) '
                    'VALUES (%s, now(), %s, %s)', (0, module_id, dependency))

        connection.commit()
        self.put_connection(connection)

    def test(self):
        connection = self.get_connection()
        is_tryton_database = self._test(connection)
        self.put_connection(connection)
        return is_tryton_database

    @classmethod
    def _test(cls, connection):
        cursor = connection.cursor()
        cursor.execute(
            'SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
              'res_user', 'res_group', 'ir_module', 'ir_module_dependency',
              'ir_translation', 'ir_lang'), ))
        return len(cursor.fetchall()) != 0

    def nextid(self, connection, table):
        cursor = connection.cursor()
        cursor.execute("SELECT NEXTVAL('" + table + "_id_seq')")
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        cursor = connection.cursor()
        cursor.execute("SELECT SETVAL('" + table + "_id_seq', %d)" % value)

    def currid(self, connection, table):
        cursor = connection.cursor()
        cursor.execute('SELECT last_value FROM "' + table + '_id_seq"')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        cursor = connection.cursor()
        cursor.execute('LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table)

    def has_constraint(self):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute(
                'SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema

    @property
    def current_user(self):
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(
                        replace_special_values(p.strip(), **special_values))
                    for p in path.split(',')
                ]
            finally:
                self.put_connection(connection)
        return self._search_path
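
`Database.dsn` above maps a `postgresql://` URI onto a libpq key/value DSN string. A self-contained sketch of the same mapping using only the standard library (simplified: it omits the empty placeholders the original keeps for absent fields):

from urllib.parse import urlparse, unquote_plus

# Sketch: postgresql:// URI -> libpq "key=value" DSN, as Database.dsn does.
def make_dsn(uri_string, dbname):
    uri = urlparse(uri_string)
    assert uri.scheme == 'postgresql'
    parts = []
    if uri.hostname:
        parts.append('host=%s' % uri.hostname)
    if uri.port:
        parts.append('port=%s' % uri.port)
    parts.append('dbname=%s' % dbname)
    if uri.username:
        parts.append('user=%s' % uri.username)
    if uri.password:
        parts.append('password=%s' % unquote_plus(uri.password))
    return ' '.join(parts)

print(make_dsn('postgresql://tryton:s%20cret@db.example:5433/', 'mydb'))
# host=db.example port=5433 dbname=mydb user=tryton password=s cret
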
Beispiel #41
0
class Database(object):
    def __init__(self,
                 db_connection,
                 connection_number,
                 app,
                 min_db_connections=1):
        self.db_connection = db_connection
        self.min_connections = min_db_connections
        self.connection_number = connection_number
        self.pool = ThreadedConnectionPool(min_db_connections,
                                           connection_number, db_connection)
        self.app = app
        self.last_seen_process_id = os.getpid()
        self.needs_change = True

    def dispose(self):
        self.pool.closeall()

    def get_connection(self):
        if self.needs_change:
            current_pid = os.getpid()
            if not (current_pid == self.last_seen_process_id):
                self.last_seen_process_id = current_pid
                self.pool.closeall()
                self.pool = ThreadedConnectionPool(self.min_connections,
                                                   self.connection_number,
                                                   self.db_connection)
            self.needs_change = False
        return self.pool.getconn()

    def run_script(self, script):
        from app import app
        success = True
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(script)
            conn.commit()
            self.pool.putconn(conn)
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(script))
            conn.rollback()
            self.pool.putconn(conn)
            success = False
        return success

    def add_tracks(self, dicts, table):
        conn = self.get_connection()
        cursor = conn.cursor()
        for track in dicts:
            columns = track.keys()
            values = [track[column] for column in columns]
            sql = "INSERT INTO {} (%s) VALUES %s".format(table)
            cursor.execute(sql, (AsIs(','.join(columns)), tuple(values)))

        conn.commit()
        self.pool.putconn(conn)

    def execute_update(self, sql, vars=None):
        conn = self.get_connection()
        success = True
        try:
            cursor = conn.cursor()
            cursor.execute(sql, vars)
            conn.commit()
            self.pool.putconn(conn)
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            self.pool.putconn(conn)
            success = False
        return success

    def value_exists(self, table, field, value):
        exists = False
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            sql = "SELECT * FROM %s WHERE %s = %s "
            vars = (AsIs(table), AsIs(field), value)
            cursor.execute(sql, vars)
            results = cursor.fetchall()
            self.pool.putconn(conn)
            if len(results) > 0:
                exists = True
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            self.pool.putconn(conn)

        return exists

    def delete_table(self, table):
        '''Deletes the table from the database
        table:
            The name of the table to delete
        Returns:
            True - if the table was deleted, otherwise False.
        '''
        conn = self.get_connection()
        sql = "DROP TABLE public.{}".format(table)
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            conn.commit()
            self.pool.putconn(conn)
            return True
        except Exception as e:
            self.app.logger.exception(
                "The table {}  could not be deleted using the SQL:\n{}".format(
                    table, sql))
            conn.rollback()
            self.pool.putconn(conn)
            return False

    def remove_columns(self, table, columns):
        '''Deletes columns from the specified table
        
        Args:
            columns(list[str]): A list of column names to delete
                
        Returns:
            True - if the columns were deleted, otherwise False
        '''
        conn = self.get_connection()
        results = []
        sql = ""
        try:
            cursor = conn.cursor()
            for col in columns:

                sql = "ALTER TABLE public.{} DROP COLUMN {}".format(table, col)
                cursor.execute(sql)
            conn.commit()
            self.pool.putconn(conn)
            return True
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            self.pool.putconn(conn)
            return False

    def add_columns(self, table, columns):
        '''Adds columns to the specified table
        
        Args:
            columns: A list of dictionaries with 'name' and 'datatype' (text,integer or double)
                and optionally default (specifies default value). There is no check on the column
                name, so it has to be compatible with postgresql
                
        Returns:
            True - if the columns were added, otherwise False
        '''

        conn = self.get_connection()
        results = []
        sql = ""
        try:
            cursor = conn.cursor()
            for col in columns:
                if col['datatype'] == "double":
                    col['datatype'] = "double precision"
                sql = "ALTER TABLE public.{} ADD COLUMN {} {}".format(
                    table, col['name'], col['datatype'])
                if col.get("default"):
                    val = col.get("default")
                    if col['datatype'] == "text":
                        val = "'" + val + "'"
                    sql += " default " + val
                cursor.execute(sql)
            conn.commit()
            self.pool.putconn(conn)
            return True
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            self.pool.putconn(conn)
            return False

    def execute_query(self, sql, vars=None):
        conn = self.get_connection()
        results = []
        try:
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            cursor.execute(sql, vars)
            results = cursor.fetchall()
            self.pool.putconn(conn)
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            self.pool.putconn(conn)

        return results

    def get_sql(self, sql, vars=None):
        conn = self.get_connection()
        try:
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            result = cursor.mogrify(sql, vars)
            self.pool.putconn(conn)
            return result
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            self.pool.putconn(conn)

    def delete_by_id(self, table, ids):
        conn = self.get_connection()
        try:
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            sql = "DELETE FROM %s WHERE id = ANY (%s)"
            cursor.execute(sql, (AsIs(table), ids))
            conn.commit()
            self.pool.putconn(conn)
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            self.pool.putconn(conn)

    def execute_delete(self, sql, vars=None):

        conn = self.get_connection()

        try:
            cursor = conn.cursor()
            cursor.execute(sql, vars)
            conn.commit()
            self.pool.putconn(conn)
        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            self.pool.putconn(conn)
            return False
        return True

    def execute_insert(self, sql, vars=None, ret_id=True):
        if ret_id:
            sql += " RETURNING id"
        conn = self.get_connection()
        new_id = -1

        try:
            cursor = conn.cursor()
            cursor.execute(sql, vars)
            conn.commit()
            if ret_id:
                new_id = cursor.fetchone()[0]
            self.pool.putconn(conn)

        except Exception as e:
            self.app.logger.exception(
                "The SQL could not be run:\n{}".format(sql))
            conn.rollback()
            self.pool.putconn(conn)
        return new_id

    def get_tracks(self, track_ids=[], field="track_id", proxy=True):
        '''Get track info for all the supplied track ids
        Args:
            track_ids(list): A list of track_ids (or ids) to retrieve
            field(Optional[str]): The field to query the database (track_id
                by default)
            proxy(Optional[boolean]): If True (default), the track's
              proxy url as defined in the config will be returned
        '''
        from app import app
        conn = self.get_connection()
        cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        sql = "SELECT id,track_id,url,type,short_label,color FROM tracks WHERE %s = ANY(%s)"
        cursor.execute(sql, (AsIs(field), track_ids))
        results = cursor.fetchall()
        self.pool.putconn(conn)
        if not results:
            return []
        t_p = app.config.get("TRACK_PROXIES")
        if t_p and proxy:
            for item in results:
                for p in t_p:
                    item['url'] = item['url'].replace(p, t_p[p])
        return results

    def insert_dicts_into_table(self, dicts, table):

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            for item in dicts:
                sql = "INSERT INTO {} (%s) VALUES %s".format(table)
                columns = item.keys()
                values = [item[column] for column in columns]
                cursor.execute(sql, (AsIs(','.join(columns)), tuple(values)))
            conn.commit()
            cursor.close()
            self.pool.putconn(conn)
        except Exception as e:
            conn.rollback()
            self.pool.putconn(conn)
            raise

    def update_table_with_dicts(self, dicts, table):
        '''Updates the table with the values in the supplied
        dictionaries (the keys being the column names)

        Args:
            dicts (list[dict]): A list of dictionaries mapping column name
            to new value. Each dictionary should also contain an id key
            with the id of the row.

        Returns:
            True if the update was successful, otherwise False.
        '''
        conn = self.get_connection()
        sql = ""
        complete = True
        try:
            cursor = conn.cursor()
            for item in dicts:
                id = item['id']
                del item['id']
                sql = "UPDATE {} SET (%s) = %s WHERE id= %s".format(table)
                columns = item.keys()
                values = [item[column] for column in columns]
                cursor.execute(sql,
                               (AsIs(','.join(columns)), tuple(values), id))
            conn.commit()
            cursor.close()
        except:
            self.app.logger.exception("update could not be run")
            conn.rollback()
            complete = False
        self.pool.putconn(conn)
        return complete

    def insert_dict_into_table(self, dic, table):
        conn = self.get_connection()
        try:
            cursor = conn.cursor()

            sql = "INSERT INTO {} (%s) VALUES %s RETURNING id".format(table)
            columns = dic.keys()
            values = [dic[column] for column in columns]
            cursor.execute(sql, (AsIs(','.join(columns)), tuple(values)))
            conn.commit()
            new_id = cursor.fetchone()[0]
            cursor.close()
            self.pool.putconn(conn)
            return new_id
        except:
            self.app.logger.exception("update could not be run")
            self.pool.putconn(conn)
            return -1
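
The class above leans on `psycopg2.extensions.AsIs` so that table and column names are spliced into the statement verbatim, while values remain properly adapted and escaped. A small sketch of the difference; `getquoted()` works without a live connection for these simple types (requires psycopg2 installed):

from psycopg2.extensions import AsIs, adapt

# AsIs injects the text verbatim (identifiers the caller must trust);
# adapt() quotes and escapes it as a SQL literal (values).
print(AsIs('tracks').getquoted())        # b'tracks'
print(adapt("O'Reilly").getquoted())     # b"'O''Reilly'"

This is why the methods above pass `AsIs(table)` or `AsIs(','.join(columns))` for identifiers but plain Python values for everything else.
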
Beispiel #42
0
class Database():
    def __init__(self, config, verbose_start=False):
        self.config = config

        if verbose_start:
            print('Connecting to database...'.ljust(
                LogDecorator.get_ljust_offset()),
                  end='')

        # get DB parameters
        try:
            self.database = config.getProperty('Database', 'name').lower()
            self.host = config.getProperty('Database', 'host')
            self.port = config.getProperty('Database', 'port')
            self.user = config.getProperty('Database', 'user', fallback=None)
            self.password = config.getProperty('Database',
                                               'password',
                                               fallback=None)

            if self.user is None or self.password is None:
                # load from credentials file instead
                credentials = config.getProperty('Database', 'credentials')
                with open(credentials, 'r') as c:
                    lines = c.readlines()
                    for line in lines:
                        line = line.lstrip().rstrip('\r').rstrip('\n')
                        if line.startswith('#') or line.startswith(';'):
                            continue
                        tokens = line.split('=')
                        if len(tokens) >= 2:
                            idx = line.find('=') + 1
                            field = tokens[0].strip().lower()
                            if field == 'username':
                                self.user = line[idx:]
                            elif field == 'password':
                                self.password = line[idx:]

            self.user = self.user.lower()
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(
                f'Incomplete database credentials provided in configuration file (message: "{str(e)}").'
            )

        try:
            self._createConnectionPool()
        except Exception as e:
            if verbose_start:
                LogDecorator.print_status('fail')
            raise Exception(
                f'Could not connect to database (message: "{str(e)}").')

        if verbose_start:
            LogDecorator.print_status('ok')

    def _createConnectionPool(self):
        self.connectionPool = ThreadedConnectionPool(
            0,
            max(
                2,
                self.config.getProperty('Database',
                                        'max_num_connections',
                                        type=int,
                                        fallback=20)
            ),  # 2 connections are needed as minimum for retrying of execution
            host=self.host,
            database=self.database,
            port=self.port,
            user=self.user,
            password=self.password,
            connect_timeout=10)

    def canConnect(self):
        with self._get_connection() as conn:
            return conn is not None and not conn.closed

    @contextmanager
    def _get_connection(self):
        conn = self.connectionPool.getconn()
        conn.autocommit = True
        try:
            yield conn
        finally:
            self.connectionPool.putconn(conn, close=False)

    def execute(self, query, arguments, numReturn=None):
        with self._get_connection() as conn:
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            retryConn = None

            # execute statement
            try:
                cursor.execute(query, arguments)
                conn.commit()
            except Exception as e:
                if not conn.closed:
                    conn.rollback()

                # retry on a fresh connection; the context manager only
                # returns the original connection to the pool, so the retry
                # connection is put back explicitly in the finally below
                retryConn = self.connectionPool.getconn()
                try:
                    cursor = retryConn.cursor(cursor_factory=RealDictCursor)
                    cursor.execute(query, arguments)
                    retryConn.commit()
                except Exception:
                    if not retryConn.closed:
                        retryConn.rollback()
                    print(e)

            # get results
            try:
                returnValues = []
                if numReturn is None:
                    return

                elif numReturn == 'all':
                    returnValues = cursor.fetchall()
                    return returnValues

                else:
                    for _ in range(numReturn):
                        rv = cursor.fetchone()
                        if rv is None:
                            return returnValues
                        returnValues.append(rv)
                    return returnValues
            except Exception as e:
                print(e)
            finally:
                if retryConn is not None:
                    self.connectionPool.putconn(retryConn, close=False)

    def insert(self, query, values, numReturn=None):
        with self._get_connection() as conn:
            cursor = conn.cursor()
            retryConn = None
            try:
                execute_values(cursor, query, values)
                conn.commit()
            except Exception as e:
                if not conn.closed:
                    conn.rollback()

                # retry on a fresh connection (put back explicitly in the
                # finally below, since the context manager only returns the
                # original connection)
                retryConn = self.connectionPool.getconn()
                try:
                    cursor = retryConn.cursor()
                    execute_values(cursor, query, values)
                    retryConn.commit()
                except Exception:
                    if not retryConn.closed:
                        retryConn.rollback()
                    print(e)

            # get results
            try:
                returnValues = []
                if numReturn is None:
                    return

                elif numReturn == 'all':
                    returnValues = cursor.fetchall()
                    return returnValues

                else:
                    for _ in range(numReturn):
                        rv = cursor.fetchone()
                        if rv is None:
                            return returnValues
                        returnValues.append(rv)
                    return returnValues
            except Exception as e:
                print(e)
            finally:
                if retryConn is not None:
                    self.connectionPool.putconn(retryConn, close=False)
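
A hedged usage sketch for the class above; the config object, table, and queries are assumptions, and execute_values (from psycopg2.extras) expands the single %s placeholder into one tuple per row:

db = Database(config)  # config is assumed to provide the 'Database' section

# read: RealDictCursor returns each row as a dict keyed by column name
rows = db.execute('SELECT id, name FROM users WHERE age > %s;',
                  (21,), numReturn='all')

# bulk write: the VALUES %s placeholder is expanded into all row tuples
db.insert('INSERT INTO users (name, age) VALUES %s;',
          [('alice', 30), ('bob', 25)])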
Beispiel #43
0
class Database(DatabaseInterface):

    _databases = {}
    _connpool = None
    _list_cache = None
    _list_cache_timestamp = None
    _version_cache = {}
    flavor = Flavor(ilike=True)

    def __new__(cls, name='template1'):
        if name in cls._databases:
            return cls._databases[name]
        return DatabaseInterface.__new__(cls, name=name)

    def __init__(self, name='template1'):
        super(Database, self).__init__(name=name)
        self._databases.setdefault(name, self)
        self._search_path = None
        self._current_user = None

    @classmethod
    def dsn(cls, name):
        uri = parse_uri(config.get('database', 'uri'))
        assert uri.scheme == 'postgresql'
        host = uri.hostname and "host=%s" % uri.hostname or ''
        port = uri.port and "port=%s" % uri.port or ''
        name = "dbname=%s" % name
        user = uri.username and "user=%s" % uri.username or ''
        password = ("password=%s" % urllib.unquote_plus(uri.password)
            if uri.password else '')
        return '%s %s %s %s %s' % (host, port, name, user, password)

    def connect(self):
        if self._connpool is not None:
            return self
        logger.info('connect to "%s"', self.name)
        minconn = config.getint('database', 'minconn', default=1)
        maxconn = config.getint('database', 'maxconn', default=64)
        self._connpool = ThreadedConnectionPool(
            minconn, maxconn, self.dsn(self.name))
        return self

    def get_connection(self, autocommit=False, readonly=False):
        if self._connpool is None:
            self.connect()
        conn = self._connpool.getconn()
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        else:
            conn.set_isolation_level(ISOLATION_LEVEL_REPEATABLE_READ)
        if readonly:
            cursor = conn.cursor()
            cursor.execute('SET TRANSACTION READ ONLY')
        conn.cursor_factory = PerfCursor
        return conn

    def put_connection(self, connection, close=False):
        self._connpool.putconn(connection, close=close)

    def close(self):
        if self._connpool is None:
            return
        self._connpool.closeall()
        self._connpool = None

    @classmethod
    def create(cls, connection, database_name):
        cursor = connection.cursor()
        cursor.execute('CREATE DATABASE "' + database_name + '" '
            'TEMPLATE template0 ENCODING \'unicode\'')
        connection.commit()
        cls._list_cache = None

    def drop(self, connection, database_name):
        cursor = connection.cursor()
        cursor.execute('DROP DATABASE "' + database_name + '"')
        Database._list_cache = None

    def get_version(self, connection):
        if self.name not in self._version_cache:
            cursor = connection.cursor()
            cursor.execute('SELECT version()')
            version, = cursor.fetchone()
            self._version_cache[self.name] = tuple(map(int,
                RE_VERSION.search(version).groups()))
        return self._version_cache[self.name]

    @staticmethod
    def dump(database_name):
        from trytond.tools import exec_command_pipe

        cmd = ['pg_dump', '--format=c', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            # if db_password is set in configuration we should pass
            # an environment variable PGPASSWORD to our subprocess
            # see libpg documentation
            env['PGPASSWORD'] = uri.password
        cmd.append(database_name)

        pipe = exec_command_pipe(*tuple(cmd), env=env)
        pipe.stdin.close()
        data = pipe.stdout.read()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t dump database!')
        return data

    @staticmethod
    def restore(database_name, data):
        from trytond.tools import exec_command_pipe

        database = Database().connect()
        connection = database.get_connection(autocommit=True)
        database.create(connection, database_name)
        database.close()

        cmd = ['pg_restore', '--no-owner']
        env = {}
        uri = parse_uri(config.get('database', 'uri'))
        if uri.username:
            cmd.append('--username=' + uri.username)
        if uri.hostname:
            cmd.append('--host=' + uri.hostname)
        if uri.port:
            cmd.append('--port=' + str(uri.port))
        if uri.password:
            env['PGPASSWORD'] = uri.password
        cmd.append('--dbname=' + database_name)
        args2 = tuple(cmd)

        if os.name == "nt":
            tmpfile = (os.environ['TMP'] or 'C:\\') + os.tmpnam()
            with open(tmpfile, 'wb') as fp:
                fp.write(data)
            args2 = list(args2)
            args2.append(' ' + tmpfile)
            args2 = tuple(args2)

        pipe = exec_command_pipe(*args2, env=env)
        if not os.name == "nt":
            pipe.stdin.write(data)
        pipe.stdin.close()
        res = pipe.wait()
        if res:
            raise Exception('Couldn\'t restore database')

        database = Database(database_name).connect()
        cursor = database.get_connection().cursor()
        if not database.test():
            cursor.close()
            database.close()
            raise Exception('Couldn\'t restore database!')
        cursor.close()
        database.close()
        Database._list_cache = None
        return True

    def list(self):
        now = time.time()
        timeout = config.getint('session', 'timeout')
        res = Database._list_cache
        if res and abs(Database._list_cache_timestamp - now) < timeout:
            return res

        connection = self.get_connection()
        cursor = connection.cursor()
        cursor.execute('SELECT datname FROM pg_database '
            'WHERE datistemplate = false ORDER BY datname')
        res = []
        for db_name, in cursor:
            try:
                with connect(self.dsn(db_name)) as conn:
                    if self._test(conn):
                        res.append(db_name)
            except Exception:
                continue
        self.put_connection(connection)

        Database._list_cache = res
        Database._list_cache_timestamp = now
        return res

    def init(self):
        from trytond.modules import get_module_info

        connection = self.get_connection()
        cursor = connection.cursor()
        sql_file = os.path.join(os.path.dirname(__file__), 'init.sql')
        with open(sql_file) as fp:
            for line in fp.read().split(';'):
                if (len(line) > 0) and (not line.isspace()):
                    cursor.execute(line)

        for module in ('ir', 'res'):
            state = 'to install'
            info = get_module_info(module)
            cursor.execute('SELECT NEXTVAL(\'ir_module_id_seq\')')
            module_id = cursor.fetchone()[0]
            cursor.execute('INSERT INTO ir_module '
                '(id, create_uid, create_date, name, state) '
                'VALUES (%s, %s, now(), %s, %s)',
                (module_id, 0, module, state))
            for dependency in info.get('depends', []):
                cursor.execute('INSERT INTO ir_module_dependency '
                    '(create_uid, create_date, module, name) '
                    'VALUES (%s, now(), %s, %s)',
                    (0, module_id, dependency))

        connection.commit()
        self.put_connection(connection)

    def test(self):
        connection = self.get_connection()
        is_tryton_database = self._test(connection)
        self.put_connection(connection)
        return is_tryton_database

    @classmethod
    def _test(cls, connection):
        cursor = connection.cursor()
        cursor.execute('SELECT 1 FROM information_schema.tables '
            'WHERE table_name IN %s',
            (('ir_model', 'ir_model_field', 'ir_ui_view', 'ir_ui_menu',
                    'res_user', 'res_group', 'ir_module',
                    'ir_module_dependency', 'ir_translation',
                    'ir_lang'),))
        return len(cursor.fetchall()) != 0

    def nextid(self, connection, table):
        cursor = connection.cursor()
        cursor.execute("SELECT NEXTVAL('" + table + "_id_seq')")
        return cursor.fetchone()[0]

    def setnextid(self, connection, table, value):
        cursor = connection.cursor()
        cursor.execute("SELECT SETVAL('" + table + "_id_seq', %d)" % value)

    def currid(self, connection, table):
        cursor = connection.cursor()
        cursor.execute('SELECT last_value FROM "' + table + '_id_seq"')
        return cursor.fetchone()[0]

    def lock(self, connection, table):
        cursor = connection.cursor()
        cursor.execute('LOCK "%s" IN EXCLUSIVE MODE NOWAIT' % table)

    def has_constraint(self):
        return True

    def has_multirow_insert(self):
        return True

    def get_table_schema(self, connection, table_name):
        cursor = connection.cursor()
        for schema in self.search_path:
            cursor.execute('SELECT 1 '
                'FROM information_schema.tables '
                'WHERE table_name = %s AND table_schema = %s',
                (table_name, schema))
            if cursor.rowcount:
                return schema

    @property
    def current_user(self):
        if self._current_user is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SELECT current_user')
                self._current_user = cursor.fetchone()[0]
            finally:
                self.put_connection(connection)
        return self._current_user

    @property
    def search_path(self):
        if self._search_path is None:
            connection = self.get_connection()
            try:
                cursor = connection.cursor()
                cursor.execute('SHOW search_path')
                path, = cursor.fetchone()
                special_values = {
                    'user': self.current_user,
                }
                self._search_path = [
                    unescape_quote(replace_special_values(
                            p.strip(), **special_values))
                    for p in path.split(',')]
            finally:
                self.put_connection(connection)
        return self._search_path
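
Finally, a minimal sketch of the get/put discipline the Tryton-style backend above expects from callers; the database name is an assumption, and the try/finally mirrors how current_user and search_path return their connections:

database = Database('mydb').connect()  # 'mydb' is an illustrative name
connection = database.get_connection(readonly=True)
try:
    cursor = connection.cursor()
    cursor.execute('SELECT current_database()')
    print(cursor.fetchone()[0])
finally:
    # every connection taken from the pool must be handed back
    database.put_connection(connection)
database.close()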