Beispiel #1
0
class PGDbHelper(object):
    """Small helper around a ``SimpleConnectionPool`` (1..12 connections).

    Every statement borrows a connection for its duration only and always
    hands it back to the pool afterwards.
    """

    def __init__(self, conn_str, echo=False):
        """Create the pool.

        :param conn_str: connection string passed to the pool.
        :param echo: when true, cursors use ``LoggingCursor`` instead of
            ``DictCursor``.
        """
        self.echo = echo
        self.pool = SimpleConnectionPool(1, 12, conn_str)

    def finish(self):
        """Close every connection owned by the pool."""
        self.pool.closeall()

    @contextmanager
    def _get_cursor(self):
        """Yield a cursor on a pooled connection, then commit and release it.

        The connection is returned to the pool *open* so it can be reused;
        closing it before ``putconn`` would force a fresh connection on
        every call and defeat pooling.
        """
        conn = self.pool.getconn()
        try:
            # Isolation level 0 is autocommit mode (cf. the old
            # ``conn.autocommit = True`` note); commit() below then has
            # nothing pending. Done inside ``try`` so a failure here still
            # returns the connection.
            conn.set_isolation_level(0)
            factory = LoggingCursor if self.echo else DictCursor
            cur = conn.cursor(cursor_factory=factory)
            try:
                yield cur
                conn.commit()
            finally:
                cur.close()
        finally:
            self.pool.putconn(conn)

    def insert(self, sql_string, value):
        """Execute one write statement with bound parameters ``value``.

        On failure the statement and traceback are logged at DEBUG level
        and the original exception is re-raised with its traceback intact.
        """
        try:
            with self._get_cursor() as cur:
                cur.execute(sql_string, value)
        except Exception:
            logger.debug("Error while executing %s", sql_string)
            # format_exc() returns the traceback text; print_exc() writes to
            # stderr and returns None, which is what used to be logged.
            logger.debug(traceback.format_exc())
            raise
Beispiel #2
0
def get_conn():
    """Generator lending out one connection from the shared module pool.

    ``POOL`` is built on first use with ``POOL_SIZE`` as both its minimum
    and maximum size. The borrowed connection goes back to ``POOL`` when the
    generator is resumed, closed, or has an exception thrown into it.
    """
    global POOL
    # Lazy construction. The empty DSN presumably defers to libpq defaults /
    # PG* environment variables; confirm for the deployment.
    POOL = POOL or SimpleConnectionPool(POOL_SIZE, POOL_SIZE, "")
    borrowed = POOL.getconn()
    try:
        yield borrowed
    finally:
        POOL.putconn(borrowed)
Beispiel #3
0
class Database:
    """Lazily created Postgres connection pool for a Cloud Functions (GCF) app.

    Credentials come from environment variables. :meth:`connect` first tries
    the Cloud SQL unix socket and falls back to ``localhost``.
    """

    pg_config: Dict = None
    CONNECTION_NAME: str = None

    def __init__(self):
        self.CONNECTION_NAME = getenv('INSTANCE_CONNECTION_NAME')
        self.pg_config = {
            'user': getenv('POSTGRES_USER', '<YOUR DB USER>'),
            'password': getenv('POSTGRES_PASSWORD', '<YOUR DB PASSWORD>'),
            'dbname': getenv('POSTGRES_DATABASE', '<YOUR DB NAME>')
        }

    pg_pool = None

    def __connect(self, host):
        """
        Helper function to connect to Postgres: build a 1-connection pool
        against ``host``.
        """
        self.pg_config['host'] = host
        self.pg_pool = SimpleConnectionPool(1, 1, **self.pg_config)

    def connect(self):
        """Create the pool, preferring Cloud SQL and falling back to localhost."""
        try:
            self.__connect(f'/cloudsql/{self.CONNECTION_NAME}')
        except OperationalError:
            # If production settings fail, use local development ones
            self.__connect('localhost')

    def get_pool(self):
        """Return a connection borrowed from the pool, creating the pool if needed.

        Despite its name this returns a *connection*; give it back with
        :meth:`return_pool`.
        """
        if not self.pg_pool:
            self.connect()
        # With psycopg2 connections, leaving this ``with`` ends the
        # connection's transaction (commit) but does not close it.
        with self.pg_pool.getconn() as conn:
            return conn

    def return_pool(self, pool):
        """Give a connection from :meth:`get_pool` back to the pool.

        :param pool: the *connection* to release (name kept for
            backward compatibility).
        """
        self.pg_pool.putconn(pool)

    def postgres_demo(self):
        """Run ``SELECT NOW()`` and return the server timestamp as a string."""
        # Initialize the pool lazily, in case SQL access isn't needed for this
        # GCF instance. Doing so minimizes the number of active SQL connections,
        # which helps keep your GCF instances under SQL connection limits.
        if not self.pg_pool:
            self.connect()

        # Remember to close SQL resources declared while running this function.
        # Keep any declared in global scope (e.g. pg_pool) for later reuse.
        conn = self.pg_pool.getconn()
        try:
            # ``with conn`` finishes the transaction (commit or rollback);
            # it must complete *before* the connection is handed back.
            with conn:
                with conn.cursor() as cursor:
                    cursor.execute('SELECT NOW() as now')
                    results = cursor.fetchone()
        finally:
            # Release the connection even if the query failed.
            self.pg_pool.putconn(conn)
        return str(results[0])
Beispiel #4
0
class PostgresPoolWrapper:
    """Owns a ``SimpleConnectionPool`` whose connections use ``RealDictCursor``."""

    def __init__(self,
                 postgres_dsn: str,
                 min_connections: Optional[int] = None,
                 max_connections: Optional[int] = None):
        """Store the settings; the pool itself is created by :meth:`init`.

        :param postgres_dsn: DSN for the database.
        :param min_connections: pool minimum; defaults to the
            ``MIN_DB_CONNECTIONS`` environment variable.
        :param max_connections: pool maximum; defaults to the
            ``MAX_DB_CONNECTIONS`` environment variable.
        :raises KeyError: when a size is omitted and its variable is unset.
        """
        # Environment defaults are resolved here, not in the signature:
        # signature defaults are evaluated at import time, so merely
        # importing the module used to fail when the variables were unset.
        if min_connections is None:
            min_connections = int(os.environ["MIN_DB_CONNECTIONS"])
        if max_connections is None:
            max_connections = int(os.environ["MAX_DB_CONNECTIONS"])
        self.postgres_pool: Optional[SimpleConnectionPool] = None
        self.postgres_dsn = postgres_dsn
        self.min_connections = min_connections
        self.max_connections = max_connections

    def init(self):
        """ Connects to the database and initializes connection pool

        A no-op when the pool already exists. A failure is printed and
        leaves the pool unset, so a later :meth:`get_conn` raises.
        """
        if self.postgres_pool is not None:
            return

        try:
            self.postgres_pool = SimpleConnectionPool(
                self.min_connections,
                self.max_connections,
                self.postgres_dsn,
                cursor_factory=RealDictCursor)
        except Exception as e:  # psycopg2.DatabaseError is an Exception subclass
            print(f"Failed to create Postgres connection pool: {e}")

    def get_conn(self) -> "Iterator[RealDictConnection]":
        """ Yields a connection from the connection pool and returns the connection to the pool
            after the yield completes -- including when the consumer raises.
        """
        pool = self.postgres_pool
        if pool is None:
            raise Exception(
                "Cannot get db connection before connecting to database")

        conn: RealDictConnection = pool.getconn()

        if conn is None:
            raise Exception(
                "Failed to get connection from Postgres connection pool")

        try:
            yield conn
        finally:
            pool.putconn(conn)

    def cleanup(self):
        """ Closes all connections in the connection pool and forgets the pool,
            so a later :meth:`init` can build a fresh one.
        """
        if self.postgres_pool is None:
            return

        self.postgres_pool.closeall()
        self.postgres_pool = None
Beispiel #5
0
class Datastore(object):
    """
    the datastore interface

    Wraps a ``SimpleConnectionPool`` (1..100 connections) for user lookups
    and inserts. Every operation returns its connection to the pool, also
    when the query fails.
    """

    # NOTE(review): a real-looking database password is hard-coded as a
    # default argument. It should come from configuration or the
    # environment; it is kept only so existing callers keep working.
    def __init__(self, dbname='web', dbuser='******', dbpassw='Klofcumad1'):
        self.db = dbname
        self.dbuser = dbuser
        self.dbpassw = dbpassw
        self.pool = SimpleConnectionPool(1, 100, database=dbname, user=dbuser, password=dbpassw)

    def start_op(self):
        """Borrow a connection from the pool and open a cursor on it."""
        conn = self.pool.getconn()
        cur = conn.cursor()
        return (conn, cur)

    def close_op(self, conn):
        """Commit the connection's transaction and return it to the pool."""
        conn.commit()
        self.pool.putconn(conn)

    def _abort_op(self, conn):
        """Roll back a failed operation and return its connection to the pool."""
        try:
            conn.rollback()
        finally:
            self.pool.putconn(conn)

    def find_user(self, id=None, email=None):
        """Look up a user by ``id`` or ``email``.

        When both are given both queries run and the ``email`` result is
        the one fetched. Returns a ``User`` or ``None``.
        """
        assert id is not None or email is not None
        conn, cur = self.start_op()
        try:
            if id is not None:
                cur.execute("SELECT * FROM users WHERE id=%s", [id])
            if email is not None:
                cur.execute("SELECT * FROM users WHERE email=%s", [email])
            row = cur.fetchone()
        except Exception:
            self._abort_op(conn)
            raise
        # Previously the connection was never returned here, so the pool
        # ran dry after 100 lookups.
        self.close_op(conn)
        if row is None:
            return None
        # NOTE(review): relies on ``SELECT *`` returning columns in the order
        # (email, id, password); confirm against the table definition.
        return User(id=row[1], email=row[0], passw=row[2])

    def add_user(self, email, password):
        """Insert a user with a hashed password and return the generated id."""
        assert email is not None and password is not None
        conn, cur = self.start_op()
        new_id = uuid()
        try:
            cur.execute("INSERT INTO users (id, email, password) VALUES (%s, %s, %s)", [str(new_id), email, generate_password_hash(password)])
        except Exception:
            self._abort_op(conn)
            raise
        self.close_op(conn)
        return new_id
Beispiel #6
0
    class __SQLCommand:
        """Hands out connections and cursors from a pool of 1..10 connections."""

        def __init__(self):
            # Pool bounded to between 1 and 10 connections to CONNECTION_STRING.
            self.connectionpool = SimpleConnectionPool(1,10,dsn=CONNECTION_STRING)

        @contextmanager
        def getcursor(self):
            """Yield a cursor on an autocommit connection; the connection is always released."""
            with self.getconnection(True) as connection:
                yield connection.cursor()

        @contextmanager
        def getconnection(self, autocommit = True):
            """Yield a pooled connection with ``autocommit`` applied; always release it."""
            borrowed = self.connectionpool.getconn()
            try:
                borrowed.autocommit = autocommit
                yield borrowed
            finally:
                self.connectionpool.putconn(borrowed)
Beispiel #7
0
class DBService:
    """Runs SQL against a ``SimpleConnectionPool`` built from ``dbconfig``.

    ``dbconfig`` must provide ``minconn``, ``maxconn``, ``dbname``,
    ``username``, ``host``, ``port`` and ``password``.
    """

    def __init__(self, dbconfig):
        self.dbconfig = dbconfig
        self.db = SimpleConnectionPool(
            minconn=self.dbconfig['minconn'],
            maxconn=self.dbconfig['maxconn'],
            database=self.dbconfig['dbname'],
            user=self.dbconfig['username'],
            host=self.dbconfig['host'],
            port=self.dbconfig['port'],
            password=self.dbconfig['password'])

    @contextmanager
    def get_cursor(self):
        """Yield a cursor; commit on success, always close it and release the connection."""
        con = self.db.getconn()
        try:
            cursor = con.cursor()
            try:
                yield cursor
                # Not reached if the caller raised: nothing is committed.
                con.commit()
            finally:
                cursor.close()
        finally:
            self.db.putconn(con)

    def update(self, sql, params=None):
        """Execute a write statement and commit it.

        :param params: optional bound parameters; prefer them over formatting
            values into ``sql``.
        """
        with self.get_cursor() as cursor:
            logging.debug(u'Executing sql: [%s]', sql.strip())
            cursor.execute(sql, params)

    def query(self, sql, rowCallback, params=None):
        """Execute a query and call ``rowCallback(row)`` for every result row."""
        with self.get_cursor() as cursor:
            logging.debug(u'Executing sql: [%s]', sql.strip())
            cursor.execute(sql, params)
            for row in cursor:
                rowCallback(row)

    def query_single(self, sql, params=None):
        """Generator yielding the first result row (``None`` if there is none).

        NOTE: this is a generator -- use ``next(service.query_single(sql))``.
        The connection is held until the generator is resumed or closed.
        """
        with self.get_cursor() as cursor:
            logging.debug(u'Executing sql: [%s]', sql.strip())
            cursor.execute(sql, params)
            yield cursor.fetchone()
Beispiel #8
0
class DBService:
    """Runs SQL against a ``SimpleConnectionPool`` built from ``dbconfig``.

    ``dbconfig`` must provide ``minconn``, ``maxconn``, ``dbname``,
    ``username``, ``host``, ``port`` and ``password``.
    """

    def __init__(self, dbconfig):
        self.dbconfig = dbconfig
        self.db = SimpleConnectionPool(minconn=self.dbconfig['minconn'],
                                       maxconn=self.dbconfig['maxconn'],
                                       database=self.dbconfig['dbname'],
                                       user=self.dbconfig['username'],
                                       host=self.dbconfig['host'],
                                       port=self.dbconfig['port'],
                                       password=self.dbconfig['password'])

    @contextmanager
    def get_cursor(self):
        """Yield a cursor; commit on success, always close it and release the connection."""
        con = self.db.getconn()
        try:
            cursor = con.cursor()
            try:
                yield cursor
                # Not reached if the caller raised: nothing is committed.
                con.commit()
            finally:
                cursor.close()
        finally:
            self.db.putconn(con)

    def update(self, sql, params=None):
        """Execute a write statement and commit it.

        :param params: optional bound parameters; prefer them over formatting
            values into ``sql``.
        """
        with self.get_cursor() as cursor:
            logging.debug(u'Executing sql: [%s]', sql.strip())
            cursor.execute(sql, params)

    def query(self, sql, rowCallback, params=None):
        """Execute a query and call ``rowCallback(row)`` for every result row."""
        with self.get_cursor() as cursor:
            logging.debug(u'Executing sql: [%s]', sql.strip())
            cursor.execute(sql, params)
            for row in cursor:
                rowCallback(row)

    def query_single(self, sql, params=None):
        """Generator yielding the first result row (``None`` if there is none).

        NOTE: this is a generator -- use ``next(service.query_single(sql))``.
        The connection is held until the generator is resumed or closed.
        """
        with self.get_cursor() as cursor:
            logging.debug(u'Executing sql: [%s]', sql.strip())
            cursor.execute(sql, params)
            yield cursor.fetchone()
Beispiel #9
0
class PostgreSqlBase:
    """Postgres access through a connection pool plus a SQLAlchemy session.

    The configuration comes from ``get_env_config()['database']``. It is
    passed verbatim to ``SimpleConnectionPool`` and also used to build the
    SQLAlchemy engine URL (keys ``user``, ``password``, ``host``,
    ``database`` and optional ``port``).
    """
    _pool = None

    def __init__(self):
        self._load_config()
        if self.config:
            self._validate_database()
            self._pool = SimpleConnectionPool(**self.config)
            self._db = self.session_factory()

    @property
    def db(self) -> "Session":
        """The SQLAlchemy session created at construction time."""
        return self._db

    def cleanup(self, exception=None):
        """Roll the session back if ``exception`` is set, otherwise commit it."""
        if exception:
            self.db.rollback()
            return
        self.db.commit()

    def ping(self):
        """Return a verified-alive pooled connection.

        NOTE: the caller must return it with ``self._pool.putconn``.
        """
        return self._get_alive_connection()

    def fetch_row(self, query, params=None):
        """Return the first row of ``query`` as ``{column: value}``, or ``None``."""
        if params is None:
            params = {}
        conn = self._get_alive_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(query, params)
                row = cur.fetchone()
                if not row:
                    return None
                col_names = [col[0] for col in cur.description]
                return dict(zip(col_names, row))
        finally:
            if conn:
                self._pool.putconn(conn)

    def fetch_rows(self, query, params=None):
        """Return all rows of ``query`` as a list of dicts, or ``None`` if empty.

        :raises Exception: "Fail to fetch from database", chained to the cause.
        """
        conn = self._get_alive_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                if params:
                    cur.execute(query, params)
                else:
                    cur.execute(query)
                rows = cur.fetchall()
            if rows:
                return [dict(row) for row in rows]
            return None
        except Exception as e:
            raise Exception("Fail to fetch from database") from e
        finally:
            if conn:
                self._pool.putconn(conn)

    def exec_transaction(self, query, param=None, has_return=False):
        """Execute ``query`` and commit.

        :returns: ``True``, or the fetched rows when ``has_return`` is set.
        :raises Exception: "Database transaction fail.", chained to the cause.
        """
        if param is None:
            param = {}
        conn = self._get_alive_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(query, param)
                conn.commit()
                if not has_return:
                    return True
                return cur.fetchall()
        except Exception as e:
            raise Exception("Database transaction fail.") from e
        finally:
            if conn:
                self._pool.putconn(conn)

    def _load_config(self):
        """Load ``self.config`` from the ``database`` section of the env config."""
        try:
            self.config = get_env_config()['database']
        except Exception as e:
            logger.error(e)
            raise Exception(e) from e

    def _validate_database(self):
        """Create the SQLAlchemy engine and session factory, creating the DB if missing."""
        from urllib.parse import quote

        cfg = self.config
        # Percent-encode credentials so '@', ':' or '/' in them cannot
        # corrupt the URL.
        user = quote(str(cfg['user']), safe='')
        password = quote(str(cfg['password']), safe='')
        port = cfg.get('port')
        netloc = f"{cfg['host']}:{port}" if port else f"{cfg['host']}"
        # ``postgresql://`` is the dialect name; SQLAlchemy 1.4+ rejects
        # the old ``postgres://`` alias.
        self.engine = create_engine(
            f"postgresql://{user}:{password}@{netloc}/{cfg['database']}"
        )
        if not database_exists(self.engine.url):
            create_database(self.engine.url)
        self.session_factory = scoped_session(sessionmaker(bind=self.engine))

    def _get_alive_connection(self, max_retry_count=10, retry_delay=1):
        """Borrow a pooled connection that answers ``SELECT 1``.

        A connection that fails the probe is closed and discarded rather
        than recycled. Up to ``max_retry_count`` attempts are made, sleeping
        ``retry_delay`` seconds between attempts (not after the last).

        :raises Exception: "Maximum retries reached" when every attempt fails.
        """
        attempts_left = max_retry_count
        while True:
            conn = None
            try:
                conn = self._pool.getconn()
                with conn.cursor() as cursor:
                    cursor.execute('SELECT 1')
                    cursor.fetchall()
                conn.commit()
                return conn
            except Exception as e:
                logger.error(f"Fail to fetch from database - {e}")
                if conn:
                    # Never hand a connection that failed the probe out again.
                    self._pool.putconn(conn, close=True)
                attempts_left -= 1
                if attempts_left <= 0:
                    raise Exception("Maximum retries reached") from e
                time.sleep(retry_delay)

    def close_connection(self):
        """Check one connection out (running the liveness probe) and return it.

        NOTE(review): despite the name, this does not close anything.
        """
        conn = self._get_alive_connection()
        if conn:
            self._pool.putconn(conn)
Beispiel #10
0
class DB(object):
    def __init__(self, config):
        """Create a connection pool for MySQL (default), PostgreSQL or Oracle.

        :param config: driver keyword arguments, plus an optional
            ``datatype`` key ('MYSQL', 'POSTGRESQL' or 'ORACLE') and, for
            Oracle, an optional ``NLS_LANG`` exported to the environment.
            The caller's dict is not modified.
        :raises BaseError: code 701 when the pool cannot be created.
        """
        # Work on a copy: the control keys are removed before the dict is
        # splatted into the driver. Mutating the caller's dict made a second
        # DB(config) silently fall back to MySQL.
        config = dict(config)
        self.DataName = config.pop('datatype', 'MYSQL')
        self.cnx = self.cur = None
        self.__conn = self.__cursor = None

        if self.DataName == 'MYSQL':
            try:
                self.pool = mysql.connector.pooling.MySQLConnectionPool(
                    **config)
            except mysql.connector.Error as err:
                # An operation-log entry should be recorded here
                logging.debug(err.msg)
                raise BaseError(701) from err  # database connection error
        elif self.DataName == 'POSTGRESQL':
            try:
                self.pool = SimpleConnectionPool(**config)
            except Exception as err:
                raise BaseError(701) from err  # database connection error

        elif self.DataName == 'ORACLE':
            nls_lang = config.pop('NLS_LANG', None)
            if nls_lang:
                os.environ['NLS_LANG'] = nls_lang

            try:
                self.pool = cx_Oracle.SessionPool(**config)
            except Exception as err:
                raise BaseError(701) from err  # database connection error

    def open(self):
        try:
            if self.DataName == 'ORACLE':
                self.__conn = self.pool.acquire()
                self.__cursor = self.__conn.cursor()
            elif self.DataName == 'POSTGRESQL':
                self.__conn = self.pool.getconn()
                self.__cursor = self.__conn.cursor()
            else:  # 默认为Mysql
                self.__conn = self.pool.get_connection()
                self.__cursor = self.__conn.cursor(buffered=True)

            #self.__conn.autocommit=True
            self.__conn.autocommit = False
            self.cnx = self.__conn
            self.cur = self.__cursor
        except:
            raise BaseError(702)  # 无法获得连接池

    def close(self):
        #关闭游标和数据库连接
        self.__conn.commit()
        if self.__cursor is not None:
            self.__cursor.close()

        if self.DataName == 'POSTGRESQL':
            self.pool.putconn(self.__conn)  #将数据库连接放回连接池中
        else:
            self.__conn.close()

    def begin(self):
        self.__conn.autocommit = False

    def commit(self):
        self.__conn.commit()

    def rollback(self):
        self.__conn.rollback()

#---------------------------------------------------------------------------

    def findBySql(self, sql, params={}, limit=0, join='AND', lock=False):
        """
            自定义sql语句查找
            limit = 是否需要返回多少行
            params = dict(field=value)
            join = 'AND | OR'
        """
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            cursor.execute(sql, tuple(params.values()))
            rows = cursor.fetchmany(
                size=limit) if limit > 0 else cursor.fetchall()
            result = [dict(zip(cursor.column_names, row))
                      for row in rows] if rows else None
            return result
        except:
            raise BaseError(706)

    def countBySql(self, sql, params={}, join='AND'):
        # 自定义sql 统计影响行数
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            cursor.execute(sql, tuple(params.values()))
            result = cursor.fetchone()
            return result[0] if result else 0
        except:
            raise BaseError(707)

    #def updateByPk(self,table,data,id,pk='id'):
    #    # 根据主键更新,默认是id为主键
    #    return self.updateByAttr(table,data,{pk:id})

    def deleteByAttr(self, table, params={}, join='AND'):
        # 删除数据
        try:
            fields = ','.join(k + '=%s' for k in params.keys())
            sql = "DELETE FROM `%s` " % table
            sql = self.__joinWhere(sql, params, join)
            cursor = self.__getCursor()
            cursor.execute(sql, tuple(params.values()))
            self.__conn.commit()
            return cursor.rowcount

        #except:
        #    raise BaseError(704)
        except Exception as err:
            raise BaseError(704, err._full_msg)

    def deleteByPk(self, table, id, pk='id'):
        # Delete by primary key (``id`` by default) via deleteByAttr.
        # NOTE(review): dead code -- a second ``deleteByPk`` defined later in
        # this class replaces this one when the class is created, so callers
        # always get the later version (which delegates to ``delete``).
        return self.deleteByAttr(table, {pk: id})

    def findByAttr(self, table, criteria={}):
        # 根据条件查找一条记录
        return self.__query(table, criteria)

    def findByPk(self, table, id, pk='id'):
        return self.findByAttr(table, {'where': pk + '=' + str(id)})

    def findAllByAttr(self, table, criteria={}):
        # 根据条件查找记录
        return self.__query(table, criteria, True)

    def exit(self, table, params={}, join='AND'):
        # 判断是否存在
        return self.count(table, params, join) > 0

# 公共的方法 -------------------------------------------------------------------------------------

    def count(self, table, params={}, join='AND'):
        # 根据条件统计行数
        try:
            sql = 'SELECT COUNT(*) FROM %s' % table

            if params:
                where, whereValues = self.__contact_where(params)
                sqlWhere = ' WHERE ' + where if where else ''
                sql += sqlWhere

            #sql = self.__joinWhere(sql,params,join)
            cursor = self.__getCursor()

            self.__display_Debug_IO(sql, tuple(whereValues))  #DEBUG

            if self.DataName == 'ORACLE':
                cursor.execute(sql % tuple(whereValues))
            else:
                cursor.execute(sql, tuple(whereValues))
            #cursor.execute(sql,tuple(params.values()))
            result = cursor.fetchone()
            return result[0] if result else 0
        #except:
        #    raise BaseError(707)
        except Exception as err:
            try:
                raise BaseError(707, err._full_msg)
            except:
                raise BaseError(707)

    def getToListByPk(self, table, criteria={}, id=None, pk='id'):
        # 根据条件查找记录返回List
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria, isDict=False)

    def getAllToList(self, table, criteria={}, id=None, pk='id', join='AND'):
        # 根据条件查找记录返回List
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria, all=True, isDict=False)

    def getToObjectByPk(self, table, criteria={}, id=None, pk='id'):
        # 根据条件查找记录返回Object
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria)

    def getAllToObject(self, table, criteria={}, id=None, pk='id', join='AND'):
        # 根据条件查找记录返回Object
        if ('where' not in criteria) and (id is not None):
            criteria['where'] = pk + "='" + str(id) + "'"
        return self.__query(table, criteria, all=True)

    def insert(self, table, data, commit=True):
        # 新增一条记录
        try:
            ''' 
                从data中分离含用SQL函数的字字段到funData字典中,
                不含SQL函数的字段到newData
            '''
            funData, newData = self.__split_expression(data)

            funFields = ''
            funValues = ''

            # 拼不含SQL函数的字段及值
            fields = ','.join(k for k in newData.keys())
            values = ','.join(("%s", ) * len(newData))

            # 拼含SQL函数的字段及值
            if funData:
                funFields = ','.join(k for k in funData.keys())
                funValues = ','.join(v for v in funData.values())

            # 合并所有字段及值
            fields += ',' + funFields if funFields else ''
            values += ',' + funValues if funValues else ''
            sql = 'INSERT INTO %s (%s) VALUES (%s)' % (table, fields, values)
            cursor = self.__getCursor()

            for (k, v) in newData.items():
                try:
                    if isinstance(v, str):
                        newData[k] = "'%s'" % (v, )
                except:
                    pass

            self.__display_Debug_IO(sql, tuple(newData.values()))  #DEBUG
            sql = sql % tuple(newData.values())

            if self.DataName == 'POSTGRESQL':
                sql += ' RETURNING id'

            cursor.execute(sql)

            #if self.DataName=='ORACLE':
            #sql= sql % tuple(newData.values())
            #cursor.execute(sql)
            #else :
            #cursor.execute(sql,tuple(newData.values()))

            if self.DataName == 'ORACLE':
                # 1. commit 一定要为假
                # 2. Oracle Sequence 的命名规范为: [用户名.]SEQ_表名_ID
                # 3. 每张主表都应该有ID
                t_list = table.split('.')
                if len(t_list) > 1:
                    SEQ_Name = t_list[0] + '.SEQ_' + t_list[1] + '_ID'
                else:
                    SEQ_Name = 'SEQ_' + t_list[0] + '_ID'

                cursor.execute('SELECT %s.CURRVAL FROM dual' %
                               SEQ_Name.upper())

                result = cursor.fetchone()
                insert_id = result[0] if result else 0
                #insert_id=cursor.rowcount
            elif self.DataName == 'MYSQL':
                insert_id = cursor.lastrowid
            elif self.DataName == 'POSTGRESQL':
                item = cursor.fetchone()
                insert_id = item[0]

            if commit: self.commit()
            return insert_id

        except Exception as err:
            try:
                raise BaseError(705, err._full_msg)
            except:
                raise BaseError(705, err.args)

    def update(self,
               table,
               data,
               params={},
               join='AND',
               commit=True,
               lock=True):
        # 更新数据
        try:
            fields, values = self.__contact_fields(data)
            if params:
                where, whereValues = self.__contact_where(params)

            values.extend(whereValues) if whereValues else values

            sqlWhere = ' WHERE ' + where if where else ''

            cursor = self.__getCursor()

            if commit: self.begin()

            if lock:
                sqlSelect = "SELECT %s From %s %s for update" % (','.join(
                    tuple(list(params.keys()))), table, sqlWhere)
                sqlSelect = sqlSelect % tuple(whereValues)
                cursor.execute(sqlSelect)  # 加行锁
                #cursor.execute(sqlSelect,tuple(whereValues))  # 加行锁

            sqlUpdate = "UPDATE %s SET %s " % (table, fields) + sqlWhere

            for index, val in enumerate(values):
                if isinstance(val, str):
                    values[index] = "'" + val + "'"

            self.__display_Debug_IO(sqlUpdate, tuple(values))  #DEBUG
            sqlUpdate = sqlUpdate % tuple(values)
            cursor.execute(sqlUpdate)

            #cursor.execute(sqlUpdate,tuple(values))

            if commit: self.commit()

            return cursor.rowcount

        except Exception as err:
            try:
                raise BaseError(705, err._full_msg)
            except:
                raise BaseError(705, err.args)

    def updateByPk(self, table, data, id, pk='id', commit=True, lock=True):
        # 根据主键更新,默认是id为主键
        return self.update(table, data, {pk: id}, commit=commit, lock=lock)

    def delete(self, table, params={}, join='AND', commit=True, lock=True):
        # 更新数据
        try:
            data = {}
            fields, values = self.__contact_fields(data)
            if params:
                where, whereValues = self.__contact_where(params)

            values.extend(whereValues) if whereValues else values

            sqlWhere = ' WHERE ' + where if where else ''

            cursor = self.__getCursor()

            if commit: self.begin()

            #if lock :
            #sqlSelect="SELECT %s From %s %s for update" % (','.join(tuple(list(params.keys()))),table,sqlWhere)
            #sqlSelect=sqlSelect % tuple(whereValues)
            #cursor.execute(sqlSelect)  # 加行锁
            ##cursor.execute(sqlSelect,tuple(whereValues))  # 加行锁

            sqlDelete = "DELETE FROM %s %s" % (table, sqlWhere)

            for index, val in enumerate(values):
                if isinstance(val, str):
                    values[index] = "'" + val + "'"

            self.__display_Debug_IO(sqlDelete, tuple(values))  #DEBUG
            sqlDelete = sqlDelete % tuple(values)
            cursor.execute(sqlDelete)

            #cursor.execute(sqlUpdate,tuple(values))

            if commit: self.commit()

            return cursor.rowcount

        except Exception as err:
            try:
                raise BaseError(705, err._full_msg)
            except:
                raise BaseError(705, err.args)

    def deleteByPk(self, table, id, pk='id', commit=True, lock=True):
        # 根据主键更新,默认是id为主键
        return self.delete(table, {pk: id}, commit=commit, lock=lock)

# 内部私有的方法 -------------------------------------------------------------------------------------

    def __display_Debug_IO(self, sql, params):
        if DEBUG:
            debug_now_time = datetime.datetime.now().strftime(
                '%Y-%m-%d %H:%M:%S')
            print('[S ' + debug_now_time + ' SQL:] ' +
                  (sql % params) if params else sql)

    def __get_connection(self):
        return self.pool.get_connection()

    def __getCursor(self):
        """获取游标"""
        if self.__cursor is None:
            self.__cursor = self.__conn.cursor()
        return self.__cursor

    def getCursor(self):
        """获取游标"""
        if self.__cursor is None:
            self.__cursor = self.__conn.cursor()
        return self.__cursor

    def __joinWhere(self, sql, params, join):
        # 转换params为where连接语句
        if params:

            funParams = {}
            newParams = {}
            newWhere = ''
            funWhere = ''

            # 从params中分离含用SQL函数的字字段到Params字典中
            for (k, v) in params.items():
                if 'str' in str(type(v)) and '{{' == v[:2] and '}}' == v[-2:]:
                    funParams[k] = v[2:-2]
                else:
                    newParams[k] = v

            # 拼 newParams 条件
            keys, _keys = self.__tParams(newParams)
            newWhere = ' AND '.join(k + '=' + _k for k, _k in zip(
                keys, _keys)) if join == 'AND' else ' OR '.join(
                    k + '=' + _k for k, _k in zip(keys, _keys))

            # 拼 funParams 条件
            if funParams:
                funWhere = ' AND '.join(
                    k + '=' + v for k, v in
                    funParams.items()) if join == 'AND' else ' OR '.join(
                        k + '=' + v for k, v in funParams.items())

            # 拼最终的 where
            where = (
                (newWhere + ' AND ' if newWhere else '') +
                funWhere if funWhere else newWhere) if join == 'AND' else (
                    (newWhere + ' OR ' if newWhere else '') +
                    funWhere if funWhere else newWhere)

            #--------------------------------------
            #keys,_keys = self.__tParams(params)
            #where = ' AND '.join(k+'='+_k for k,_k in zip(keys,_keys)) if join == 'AND' else ' OR '.join(k+'='+_k for k,_k in zip(keys,_keys))
            sql += ' WHERE ' + where
        return sql

    def __tParams(self, params):
        keys = [k if k[:2] != '{{' else k[2:-2] for k in params.keys()]
        _keys = ['%s' for k in params.keys()]
        return keys, _keys

    def __query(self, table, criteria, all=False, isDict=True, join='AND'):
        '''
           table    : 表名
           criteria : 查询条件dict
           all      : 是否返回所有数据,默认为False只返回一条数据,当为真是返回所有数据
           isDict   : 返回格式是否为字典,默认为True ,即字典否则为数组 
        '''
        try:
            if all is not True:
                criteria['limit'] = 1  # 只输出一条
            sql, params = self.__contact_sql(table, criteria,
                                             join)  #拼sql及params
            '''
            # 当Where为多个查询条件时,拼查询条件 key 的 valuse 值
            if 'where' in criteria and 'dict' in str(type(criteria['where'])) :
                params = criteria['where']
                #params = tuple(params.values())
                where ,whereValues   = self.__contact_where(params)
                sql+= ' WHERE '+where if where else ''
                params=tuple(whereValues)
            else :
                params = None
            '''
            #__contact_where(params,join='AND')
            cursor = self.__getCursor()

            self.__display_Debug_IO(sql, params)  #DEBUG

            #if self.DataName=="ORACLE":
            #sql="select * from(select * from(select t.*,row_number() over(order by %s) as rownumber from(%s) t) p where p.rownumber>%s) where rownum<=%s" % ()
            #pass

            cursor.execute(sql, params if params else ())

            rows = cursor.fetchall() if all else cursor.fetchone()

            if isDict:
                result = [dict(zip(cursor.column_names, row))
                          for row in rows] if all else dict(
                              zip(cursor.column_names, rows)) if rows else {}
            else:
                result = [row for row in rows] if all else rows if rows else []
            return result
        except Exception as err:
            try:
                raise BaseError(706, err._full_msg)
            except:
                raise BaseError(706)

    def __contact_sql(self, table, criteria, join='AND'):
        """Build a SELECT statement for ``table`` from a ``criteria`` dict.

        Recognised ``criteria`` keys: ``select`` (comma-separated field list;
        ``{{expr}}`` entries are emitted verbatim without the braces),
        ``where`` (a raw SQL string, or a dict expanded by
        ``__contact_where``), ``group``, ``having``, ``order``, ``limit`` and
        ``offset``. How ORDER BY / LIMIT are rendered depends on
        ``self.DataName`` (``'MYSQL'``, ``'POSTGRESQL'`` or ``'ORACLE'``).

        :param table: table name, interpolated into the SQL as-is.
        :param criteria: dict of clauses; anything falsy or not a dict yields
            ``SELECT * FROM table``.
        :param join: ``'AND'`` or ``'OR'``, used to combine the conditions of
            a dict ``where`` clause.
        :returns: ``(sql, whereValues)`` where ``whereValues`` holds the bind
            values of a dict ``where`` clause, or ``None`` when there are none.
        """
        sql = 'SELECT '
        # Bound up front so every return path can hand it back.
        whereValues = None
        if not (criteria and isinstance(criteria, dict)):
            sql += ' * FROM %s' % table
            return sql, whereValues

        def split_limit(limit):
            # MySQL-style "offset,rowcount" (an optional "limit " prefix is
            # dropped). A single number is a row count with offset 0, as in
            # MySQL's "LIMIT n"; anything else falls back to offset 0, 1 row.
            parts = str(limit).split('limit ').pop().split(',')
            if len(parts) > 1:
                return parts[0], parts[1]
            if parts[0].strip().isdigit():
                return '0', parts[0].strip()
            return '0', '1'

        # Select fields. "{{expr}}" marks a raw SQL expression; only the
        # braces are stripped.
        # NOTE(review): splitting on ',' breaks expressions that contain a
        # comma themselves, e.g. "{{concat(a,b)}}".
        if 'select' in criteria:
            fields = [field.strip() for field in criteria['select'].split(',')]
            sql += ','.join(
                field[2:-2] if field[:2] == '{{' and field[-2:] == '}}'
                else field
                for field in fields)
        else:
            sql += ' * '
        sql += ' FROM %s' % table

        # WHERE: a string is used verbatim; a dict becomes placeholders plus
        # bind values, combined with the caller's `join`.
        if 'where' in criteria:
            if isinstance(criteria['where'], str):
                sql += ' WHERE ' + criteria['where']
            else:
                where, whereValues = self.__contact_where(
                    criteria['where'], join)
                sql += ' WHERE ' + where if where else ''

        if 'group' in criteria:
            sql += ' GROUP BY ' + criteria['group']
        if 'having' in criteria:
            sql += ' HAVING ' + criteria['having']

        if self.DataName == 'MYSQL':
            if 'order' in criteria:
                sql += ' ORDER BY ' + criteria['order']
            if 'limit' in criteria:
                sql += ' LIMIT ' + str(criteria['limit'])
            if 'offset' in criteria:
                sql += ' OFFSET ' + str(criteria['offset'])
        elif self.DataName == 'POSTGRESQL':
            if 'order' in criteria:
                sql += ' ORDER BY ' + criteria['order']
            if 'limit' in criteria:
                strOffset, strRowcount = split_limit(criteria['limit'])
                sql += '  LIMIT %s OFFSET %s' % (strRowcount, strOffset)
        elif self.DataName == 'ORACLE' and 'limit' in criteria:
            strOffset, strRowcount = split_limit(criteria['limit'])
            strOrder = criteria['order'] if 'order' in criteria else 'ROWNUM'
            # Nested ROW_NUMBER() pagination, intended for large Oracle
            # result sets.
            sql = "select * from(select * from(select t.*,row_number() over(order by %s) as rownumber from(%s) t) p where p.rownumber>%s) where rownum<=%s" % (
                strOrder, sql, strOffset, strRowcount)
        elif self.DataName == 'ORACLE' and 'order' in criteria:
            sql += ' ORDER BY ' + criteria['order']

        return sql, whereValues

    # 将字符串和表达式分离
    def __split_expression(self, data):
        """Separate raw SQL expressions from plain values in ``data``.

        A string value wrapped in double braces (``'{{now()}}'``) is a raw SQL
        expression. It goes into the first dict with the braces stripped.
        Every other entry is copied unchanged into the second dict.

        :param data: mapping of column name to value.
        :returns: ``(funData, newData)`` — expressions and plain values.
        """
        funData = {}
        newData = {}
        for k, v in data.items():
            if isinstance(v, str) and v.startswith('{{') and v.endswith('}}'):
                funData[k] = v[2:-2]
            else:
                newData[k] = v
        return (funData, newData)

    # 拼Update字段
    def __contact_fields(self, data):
        """Build the assignment list of an UPDATE ... SET clause from ``data``.

        Plain values become ``col=%s`` placeholders, and their bind values are
        returned alongside. ``{{expr}}`` values (as separated by
        ``__split_expression``) are inlined as ``col=expr``. Placeholder
        assignments come first, expression assignments after them.

        :returns: ``(fields, values)`` — the comma-joined assignment string
            and the list of bind values for the placeholders.
        """
        expressions, plain = self.__split_expression(data)
        assignments = [column + '=%s' for column in plain.keys()]
        for column, expr in expressions.items():
            assignments.append(column + '=%s' % (expr))
        return (','.join(assignments), list(plain.values()))

    def __hasKeyword(self, key):
        """Return True if ``key`` contains its own comparison operator.

        ``__contact_where`` uses this to decide how a raw ``{{...}}`` condition
        is attached to its column:
        - with a space when it has an operator (``col LIKE 'a%'``);
        - with ``=`` otherwise (``col=now()``).

        Matching is case-insensitive, so upper-case SQL keywords (``LIKE``,
        ``IN (``) are recognised. ``!=`` is included because ``col=!= x``
        would never be valid SQL.
        """
        lowered = key.lower()
        markers = ('{{}}', 'in (', 'like ', '>', '<', '!=')
        return any(marker in lowered for marker in markers)

    # 拼Where条件
    def __contact_where(self, params, join='AND'):
        """Build a WHERE-clause body (without the ``WHERE`` keyword) from ``params``.

        Plain values are passed to ``self.__tParams`` (defined elsewhere). It
        returns two parallel sequences, which are joined as ``a=b`` pairs —
        presumably column names and placeholders; verify against __tParams.
        ``{{expr}}`` values are inlined:
        - with a space when the expression carries its own operator (see
          ``__hasKeyword``);
        - with ``=`` otherwise;
        - as the bare expression when the key is empty.

        :param params: mapping of column name to value or ``{{expr}}``.
        :param join: ``'AND'`` joins conditions with AND; any other value
            joins them with OR.
        :returns: ``(where, values)`` — the condition string and the bind
            values of the plain conditions (in dict order; assumes __tParams
            preserves that order).
        """
        glue = ' AND ' if join == 'AND' else ' OR '
        expressions, plain = self.__split_expression(params)

        # Placeholder conditions for plain values.
        keys, placeholders = self.__tParams(plain)
        plainWhere = glue.join(
            key + '=' + placeholder
            for key, placeholder in zip(keys, placeholders))
        values = list(plain.values())

        # Inlined conditions for raw expressions.
        exprParts = []
        for key, expr in expressions.items():
            if self.__hasKeyword(expr):
                link = ' '
            elif key:
                link = '='
            else:
                link = ''
            exprParts.append(key + link + expr)
        exprWhere = glue.join(exprParts)

        where = glue.join(part for part in (plainWhere, exprWhere) if part)
        return (where, values)

    def get_ids(self, list):  # extract the ids from a getAllToList result
        """Join the ids of a query result into a comma-separated string.

        ``list`` may be 1-D (``[1, 2]`` or ``['12', '34']``) or a list of
        rows (``[(1, ...), (2, ...)]``). For rows, the first column of each
        row is used. Strings are treated as scalar ids, not as rows.

        :returns: e.g. ``'1,2'``; ``''`` for an empty input.
        """
        rows = list  # the parameter name shadows the builtin; alias it
        try:
            first = rows[0]
            if isinstance(first, str):
                nested = False
            else:
                first[0]  # probe: succeeds only for indexable rows
                nested = True
        except Exception:
            # Empty input or scalar elements: treat as a flat list.
            nested = False

        if nested:
            return ','.join(str(row[0]) for row in rows)
        return ','.join(str(row) for row in rows)
class PostgresqlStorage(AbstractStorage):
    """Provenance storage backed by a PostgreSQL ``provenance`` table."""

    # psycopg2 SimpleConnectionPool, created in __init__.
    connectionpool = None

    def __init__(self, mode):
        print("Ignoring server mode requested: %s" % mode)
        # Pool with 1 connection opened up front and at most 100 in total.
        self.connectionpool = SimpleConnectionPool(1, 100, dsn=self.config())

    def get_id_by_name(self, id_rec):
        """Return the ``data`` of the provenance row whose ``rec_id`` is ``id_rec``."""
        return self.get_data(id_rec,
                             'select data from provenance where rec_id = %s')

    def get_abci_query(self, id_rec):
        """Return the ``data`` of the provenance row whose ``id`` is ``id_rec``."""
        return self.get_data(id_rec,
                             'select data from provenance where id = %s')

    def get_data(self, id_rec, sql):
        """Run ``sql`` with ``id_rec`` as its only parameter.

        Returns the first column of the first row, i.e. the stored (JSON)
        content. Raises TypeError when no row matches, because
        ``fetchone()`` then returns None.
        """
        with self.getConn() as conn:
            with conn.cursor() as cur:
                cur.execute(sql, (id_rec, ))
                return cur.fetchone()[0]

    def store(self, rec_id, content):
        """Insert ``content`` for ``rec_id`` and commit.

        :returns: ``(id, timestamp string, sha256 hex digest, id)``. The last
            slot is where a block height is expected. This backend reuses the
            generated id there (presumably as a placeholder; confirm with
            callers).
        """
        cc_content = get_content_data(content)
        with self.getConn() as conn:
            with conn.cursor() as cur:
                data = datetime.now()
                data_hash = hashlib.sha256(
                    cc_content.encode('utf-8')).hexdigest()
                cur.execute(
                    'insert into provenance(rec_id, data, date, hash) values (%s, %s, %s, %s) RETURNING id',
                    (rec_id, cc_content, data, data_hash))
                gen_id = cur.fetchone()[0]
            conn.commit()
            return gen_id, str(data), data_hash, gen_id

    def raw_store(self, id, content):
        pass

    def config(self, filename='database.ini', section='postgresql'):
        """Build a libpq connection string from ``section`` of ``filename``.

        Each option becomes ``key='value'``. Backslashes and single quotes in
        values are escaped as libpq requires, so values such as passwords
        containing them do not corrupt the string.

        :raises Exception: if the section is missing. A missing file also
            ends up here, because ConfigParser.read skips unreadable files.
        """
        parser = ConfigParser()
        parser.read(filename)
        if not parser.has_section(section):
            raise Exception('Section {0} not found in the {1} file'.format(
                section, filename))
        return ''.join("%s='%s' " % (key, self._quote_dsn_value(value))
                       for key, value in parser.items(section))

    @staticmethod
    def _quote_dsn_value(value):
        """Escape ``value`` for use inside single quotes in a libpq conninfo string."""
        return value.replace('\\', '\\\\').replace("'", "\\'")

    @contextmanager
    def getConn(self):
        """Borrow a pooled connection and always return it to the pool."""
        con = self.connectionpool.getconn()
        try:
            yield con
        finally:
            self.connectionpool.putconn(con)
class PostgresConnector:
    """Pooled PostgreSQL helper for the ``entity_lookup`` schema.

    It maintains per-entity tables and "latest row" views, bulk-loads
    DataFrames with COPY, and writes a job log.
    """

    # Process-wide instance returned by get_instance().
    instance = None

    # Upper-cased entity type name -> PostgreSQL column type for entity ids.
    data_type_mapping = {
        'INT64': 'BIGINT',
        'BOOL': 'BOOLEAN',
    }

    @classmethod
    def get_instance(cls, config=None) -> 'PostgresConnector':
        """Return the shared connector, creating it from ``config`` on first use.

        The return annotation is a string because the class name is not yet
        bound while the class body executes. ``config`` is ignored once an
        instance exists.

        :raises TypeError: if no instance exists yet and ``config`` is None.
        """
        if cls.instance is None:
            if config is None:
                raise TypeError(
                    'PostgresConnector.get_instance() needs a config on first call')
            cls.instance = cls(config)
        return cls.instance

    def __init__(self, config):
        """Open a pool of 1-10 connections.

        ``config`` is passed as keyword arguments to SimpleConnectionPool.
        """
        self.logger = logging.getLogger('postgres_connector')
        try:
            self.pool = SimpleConnectionPool(1, 10, **config)
            self.schema = 'entity_lookup'
        except Exception:
            self.logger.exception(
                'Exception occurred while connecting to the database')
            raise

    @contextmanager
    def _transaction(self, cursor_factory=None):
        """Yield ``(connection, cursor)`` inside one transaction.

        Commits when the with-body succeeds. On error, logs, rolls back and
        re-raises. The cursor is always closed and the connection always
        returned to the pool.
        """
        conn = self.pool.getconn()
        cursor = None
        try:
            cursor = conn.cursor(cursor_factory=cursor_factory)
            yield conn, cursor
            conn.commit()
        except Exception:
            self.logger.exception(
                'Exception occurred during database transaction')
            conn.rollback()
            raise
        finally:
            try:
                if cursor is not None:
                    cursor.close()
            finally:
                self.pool.putconn(conn)

    def create_schema_if_not_exists(self):
        """Create ``self.schema`` if it does not exist."""
        with self._transaction() as (_, cursor):
            cursor.execute(
                sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(
                    sql.Identifier(self.schema)))

    def create_log_table_if_not_exists(self):
        """Create the ``jobs`` log table in ``self.schema`` if missing."""
        with self._transaction() as (_, cursor):
            cursor.execute(
                sql.SQL("""
                    CREATE TABLE IF NOT EXISTS {schema}.jobs (
                        "id" SERIAL PRIMARY KEY,
                        "started" TIMESTAMP WITHOUT TIME ZONE,
                        "ended" TIMESTAMP WITHOUT TIME ZONE,
                        "status" VARCHAR(255),
                        "status_msg" TEXT,
                        "entity_names" CHARACTER VARYING[],
                        "feature_table" VARCHAR(255),
                        "path" TEXT
                    )
                    """).format(schema=sql.Identifier(self.schema)))

    def _create_entity_table_if_not_exists(self, table_name: str,
                                           entity_type: str, event_ts: str,
                                           created_ts: str):
        """Create one entity table if missing.

        The primary key is (id, feature_table, event_ts, created_ts).
        ``entity_type`` is inserted as raw SQL, so it must come from
        ``data_type_mapping``.
        """
        with self._transaction() as (_, cursor):
            query = sql.SQL("""
                CREATE TABLE IF NOT EXISTS {table} (
                    "id" {entity_type},
                    "feature_table" VARCHAR(255),
                    {event_ts} TIMESTAMP WITHOUT TIME ZONE,
                    {created_ts} TIMESTAMP WITHOUT TIME ZONE,
                    "path" TEXT,
                    PRIMARY KEY({cols})
                );
                """).format(
                table=sql.Identifier(self.schema, table_name),
                entity_type=sql.SQL(entity_type),
                event_ts=sql.Identifier(event_ts),
                created_ts=sql.Identifier(created_ts),
                cols=sql.SQL(', ').join(
                    map(sql.Identifier,
                        ['id', 'feature_table', event_ts, created_ts])),
            )
            cursor.execute(query)

    def create_entity_tables_if_not_exist(self, column_data: dict,
                                          parquet_path: str):
        """Create an ``entity_<name>`` table for every entity in ``column_data``.

        ``parquet_path`` is currently unused.

        :returns: dict mapping table name -> entity name.
        :raises KeyError: for entity types missing from ``data_type_mapping``.
        """
        created_ts = column_data['created_timestamp_column']
        event_ts = column_data['timestamp_column']
        table_names = {}
        for entity_name, entity_type in zip(column_data['entity_names'],
                                            column_data['entity_types']):
            table_name = f'entity_{entity_name}'
            self._create_entity_table_if_not_exists(
                table_name=table_name,
                entity_type=self.data_type_mapping[entity_type.upper()],
                event_ts=event_ts,
                created_ts=created_ts,
            )
            table_names[table_name] = entity_name
        return table_names

    def create_view_if_not_exists(self, table_names_for_entities,
                                  column_data: dict):
        """Create a ``max_<table>`` view for each entity table.

        For each id, the view keeps the row with the latest created timestamp
        (ties broken by path DESC). An already existing view (SQLSTATE
        42P07) is ignored.
        """
        for table_name in table_names_for_entities.keys():
            with self._transaction() as (_, cursor):
                query = sql.SQL("""
                    DO $$
                    BEGIN
                        CREATE VIEW {schema}.{view_name} AS
                            WITH groups AS (
                                SELECT id, feature_table, MAX({event_ts}) as {event_ts}, {created_ts}, path,
                                ROW_NUMBER() OVER(PARTITION BY id ORDER BY {created_ts} DESC, path DESC) AS rk
                                FROM {schema}.{entity_table}
                                GROUP BY id, feature_table, {created_ts}, path
                            )
                            SELECT * FROM groups WHERE rk = 1;
                    EXCEPTION
                    WHEN SQLSTATE '42P07' THEN
                        NULL;
                    END; $$
                    """).format(
                    view_name=sql.Identifier(f'max_{table_name}'),
                    schema=sql.Identifier(self.schema),
                    event_ts=sql.Identifier(column_data['timestamp_column']),
                    created_ts=sql.Identifier(
                        column_data['created_timestamp_column']),
                    entity_table=sql.Identifier(table_name),
                )
                cursor.execute(query)

    def get_columns(self, path_extract: str):
        """Look up metadata for the data source whose ``file_url`` contains ``path_extract``.

        The metadata covers entity names/types, the feature table and the
        timestamp columns.

        :returns: one RealDictCursor row, or None if nothing matches.
        """
        with self._transaction(RealDictCursor) as (_, cursor):
            query = sql.SQL("""-- Reduce to one data source
                WITH data_source_ltd AS (
                    select max(ds.id), en.name as entity_name, en.type as entity_type, ft.name as feature_table, ds.timestamp_column, ds.created_timestamp_column from public.data_sources ds
                    JOIN public.feature_tables ft ON ds.id = ft.batch_source_id
                    JOIN public.feature_tables_entities_v2 fte ON ft.id = fte.feature_table_id
                    JOIN public.entities_v2 en ON fte.entity_v2_id = en.id
                    JOIN public.projects pr ON ft.project_name = pr.name
                    where ds.config::json ->> 'file_url' like {path}
                    and ft.is_deleted = false
                    and pr.archived = false
                    GROUP BY en.name, en.type, ft.name, ds.timestamp_column, ds.created_timestamp_column
                )
                -- Reduce to one row
                SELECT array_agg(entity_name) as entity_names, array_agg(entity_type) as entity_types, feature_table, timestamp_column, created_timestamp_column FROM data_source_ltd
                GROUP BY feature_table, timestamp_column, created_timestamp_column;
                """).format(path=sql.Literal(f'%{path_extract}%'))
            cursor.execute(query)
            return cursor.fetchone()

    def copy_into_table(self, table_names_for_entities: dict,
                        df: pd.DataFrame):
        """Bulk-load ``df`` into each entity table with COPY (one transaction per table).

        For each ``table_name -> entity_name`` pair, the entity's column is
        loaded into ``id``; the remaining columns keep their names.
        """
        for table_name, entity_name in table_names_for_entities.items():
            # NOTE(review): this drops columns named like the *table* names
            # (the dict keys). If the intent is to drop the other entities'
            # columns, it should filter on the values instead — confirm.
            columns = [
                col for col in df.columns
                if col not in table_names_for_entities.keys()
                and col != table_name
            ]
            df_view = df[columns]
            columns[columns.index(entity_name)] = 'id'
            with self._transaction() as (_, cursor):
                s = StringIO()
                df_view.to_csv(s, header=False, index=False, sep='\t')
                s.seek(0)
                # psycopg2 >= 2.9 quotes copy_from()'s table argument as one
                # identifier, which breaks "schema.table". So the COPY is
                # built explicitly. Text format uses a tab delimiter and \N
                # for NULL, the same as copy_from's defaults.
                cursor.copy_expert(
                    sql.SQL('COPY {table} ({cols}) FROM STDIN').format(
                        table=sql.Identifier(self.schema, table_name),
                        cols=sql.SQL(', ').join(map(sql.Identifier, columns)),
                    ), s)

    def add_log(self, data: dict):
        """Insert one row into the ``jobs`` log table from ``data``."""
        with self._transaction() as (_, cursor):
            cursor.execute(
                sql.SQL("""
                    INSERT INTO {schema}.jobs (started, ended, status, status_msg, entity_names, feature_table, path)
                    VALUES({started}, {ended}, {status}, {status_msg}, {entity_names}, {feature_table}, {path})
                    """).format(
                    schema=sql.Identifier(self.schema),
                    started=sql.Literal(data['started']),
                    ended=sql.Literal(data['ended']),
                    status=sql.Literal(data['status']),
                    status_msg=sql.Literal(data['status_msg']),
                    entity_names=sql.Literal(data['entity_names']),
                    feature_table=sql.Literal(data['feature_table']),
                    path=sql.Literal(data['path']),
                ))
Beispiel #13
0
class DbController(object):
    """
    Provides postgres database connection pool and methods to manage the data
    """
    def __init__(self, user, password, host, port, database, pool_max_size):
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.database = database
        self.pool_max_size = pool_max_size

        # Connect unless a pool is already attached (e.g. as a subclass class
        # attribute). conn_pool is bound to None first so that it exists even
        # when connect() fails; later calls then return -1 instead of raising.
        if getattr(self, 'conn_pool', None) is None:
            self.conn_pool = None
            self.connect()

    def connect(self):
        """ Initialize database connection pool

        Failures are logged, not raised; ``conn_pool`` is left unchanged.
        """
        try:
            self.conn_pool = SimpleConnectionPool(minconn=1,
                                                  maxconn=self.pool_max_size,
                                                  user=self.user,
                                                  password=self.password,
                                                  host=self.host,
                                                  port=self.port,
                                                  database=self.database)
        except Exception as error:  # psycopg2.Error subclasses Exception
            logger.fatal(f'Error connecting to PostgreSQL: {error}')

    def close(self):
        """ Close database connection pool

        Safe to call more than once: the pool is dropped after closeall().
        """
        pool = getattr(self, 'conn_pool', None)
        if pool is not None:
            pool.closeall()
            self.conn_pool = None
        logger.info('PostgreSQL connection pool is closed')

    def _release(self, conn, discard=False):
        """Return ``conn`` to the pool, closing it if ``discard``; no-op for None."""
        if conn is not None and self.conn_pool is not None:
            self.conn_pool.putconn(conn, close=discard)

    def _write(self, label, run):
        """Run ``run(cursor)`` in its own transaction.

        On success, commits and returns None. On any error, logs
        ``Error executing <label>: <error>``, rolls back and returns -1. A
        connection whose rollback also fails is closed rather than reused.
        The connection is always handed back to the pool.
        """
        conn = None
        discard = False
        try:
            conn = self.conn_pool.getconn()
            with conn.cursor() as cursor:
                run(cursor)
            conn.commit()
            return None
        except Exception as error:
            logger.error(f'Error executing {label}: {error}')
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    logger.error(
                        f'Rollback failed, discarding connection: {rollback_error}')
                    discard = True
            return -1
        finally:
            self._release(conn, discard)

    def ping(self):
        """ Method to check availability of db. Used in /healtz/ready

        :returns: 0 if ``SELECT version()`` succeeds, -1 otherwise (logged).
        """
        conn = None
        try:
            conn = self.conn_pool.getconn()
            with conn.cursor() as cursor:
                cursor.execute("SELECT version();")
                cursor.fetchone()
            return 0
        except Exception as error:
            logger.error(f'Error connecting to PostgreSQL: {error}')
            return -1
        finally:
            self._release(conn)

    def select(self, query, args):
        """ Execute query and return results as json

        Rows are fetched as dicts (RealDictCursor) and round-tripped through
        JSON, so the result is a list of plain dicts. Returns -1 (and logs)
        on any error, including values JSON cannot serialise.
        """
        conn = None
        try:
            conn = self.conn_pool.getconn()
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute(query, args)
                record = json.dumps(cursor.fetchall())
            return json.loads(record)
        except Exception as error:
            logger.error(f'Error executing the SELECT query: {error}')
            return -1
        finally:
            self._release(conn)

    def insert(self, query, *args):
        """Execute an INSERT with ``args`` as its parameters.

        :returns: None on success, -1 on failure (logged, rolled back).
        """
        return self._write('the INSERT query',
                           lambda cursor: cursor.execute(query, args))

    def insert_bulk(self, query, params):
        """Execute ``query`` once per parameter set in ``params`` via execute_batch.

        :returns: None on success, -1 on failure (logged, rolled back).
        """
        return self._write('BULK INSERT query',
                           lambda cursor: execute_batch(cursor, query, params))

    def update(self, query, args):
        """Execute an UPDATE with ``args`` as its parameters.

        :returns: None on success, -1 on failure (logged, rolled back).
        """
        return self._write('UPDATE query',
                           lambda cursor: cursor.execute(query, args))

    def delete(self, query, args):
        """Execute a DELETE with ``args`` as its parameters.

        :returns: None on success, -1 on failure (logged, rolled back).
        """
        return self._write('DELETE query',
                           lambda cursor: cursor.execute(query, args))
Beispiel #14
0
class DBOperator(object):
    """psycopg2-pool access to the ETL schema (``ETL_SCHEMA``).

    Every method hands its connection back to the pool in a ``finally``
    block. The pool holds at most 3 connections, so leaking one per failed
    query would exhaust it quickly.
    """
    def __init__(self):
        # 2 connections opened up front, at most 3 in total.
        self.db_conn_pool = SimpleConnectionPool(2,
                                                 3,
                                                 host=DST_DB_HOST,
                                                 port=int(DST_DB_PORT),
                                                 user=DST_DB_USER,
                                                 password=DST_DB_PASSWORD,
                                                 database=DST_DB_DATABASE)

    def pg_insert_return_id(self, table_name, field_list, insert_values):
        """Insert rows into ``ETL_SCHEMA.table_name`` and return their ids.

        :param table_name: target table inside ``ETL_SCHEMA``.
        :param field_list: iterable of column names.
        :param insert_values: a ready-made SQL ``VALUES`` body, e.g.
            ``"(1,'a'),(2,'b')"``. It is spliced into the statement verbatim,
            so callers must escape any untrusted data themselves.
        :returns: list of ``[id]`` lists, one per inserted row.
        """
        field_list = ','.join(field_list)
        conn = self.db_conn_pool.getconn()
        try:
            with conn.cursor() as cursor:
                if table_name == 'attr_record':
                    # Realign the id sequence with max(id) before inserting
                    # (presumably because rows also get explicit ids elsewhere).
                    cursor.execute(
                        f"SELECT setval('{ETL_SCHEMA}.person_attr_record_id_seq', (SELECT max(id) FROM {ETL_SCHEMA}.attr_record));"
                    )
                cursor.execute(
                    f"insert into {ETL_SCHEMA}.{table_name} ({field_list}) values {insert_values} RETURNING id"
                )
                lt = [list(x) for x in cursor.fetchall()]
            conn.commit()
        finally:
            self.db_conn_pool.putconn(conn)
        return lt

    def append_pgsql(self, df, prop_info, tb_info):
        """Append ``df`` to ``tb_info['schema'].tb_info['tablename']``.

        First, the table is created from the DataFrame if it does not exist.
        This uses pandas' ``SQLTable`` with ``if_exists="append"`` on a
        temporary SQLAlchemy engine built from ``prop_info`` and
        ``tb_info['database']``. Then the rows are streamed in with
        ``COPY ... CSV``.

        Example: table='atest' , schema='analysis_etl_gd_ele_fence'
        """
        engine = create_engine(
            f"postgresql://{prop_info['user']}:{prop_info['password']}@{prop_info['host']}:{prop_info['port']}/{tb_info['database']}",
            max_overflow=0,
            pool_size=5,
            pool_timeout=30,
            pool_recycle=-1)
        try:
            # NOTE(review): pandasSQL_builder/SQLTable are pandas internals
            # whose behaviour changed in pandas 2 (the builder may hold its
            # own transaction). Verify that CREATE TABLE persists on the
            # pandas version in use.
            pd_sql_engine = pd.io.sql.pandasSQL_builder(engine)
            pd_table = pd.io.sql.SQLTable(tb_info["tablename"],
                                          pd_sql_engine,
                                          frame=df,
                                          index=False,
                                          if_exists="append",
                                          schema=tb_info["schema"])
            pd_table.create()
        finally:
            # The engine is only needed for CREATE TABLE; release its pool
            # instead of leaving it to garbage collection.
            engine.dispose()

        sio = StringIO()
        df.to_csv(sio, sep='|', encoding='utf-8', index=False)
        sio.seek(0)
        conn = self.db_conn_pool.getconn()
        try:
            with conn.cursor() as cursor:
                copy_cmd = f"COPY {tb_info['schema']}.{tb_info['tablename']} FROM STDIN HEADER DELIMITER '|' CSV"
                cursor.copy_expert(copy_cmd, sio)
            conn.commit()
        finally:
            self.db_conn_pool.putconn(conn)

    def read_pgsql_to_pandas_dataframe(self, table_name):
        """Read all of ``ETL_SCHEMA.table_name`` into a DataFrame.

        Reads in 1000-row chunks. If that raises TypeError/ValueError (e.g.
        ``pd.concat`` raises ValueError when no chunks are produced), it
        falls back to a single unchunked read.
        """
        conn = self.db_conn_pool.getconn()
        try:
            try:
                df = pd.concat(
                    pd.read_sql(f'''select * from {ETL_SCHEMA}.{table_name}''',
                                con=conn,
                                chunksize=1000))
            except (TypeError, ValueError):
                df = pd.read_sql(f'''select * from {ETL_SCHEMA}.{table_name}''',
                                 con=conn)
        finally:
            self.db_conn_pool.putconn(conn)
        return df

    def read_clue_rule_data(self):
        """Return the currently active 'mac' clue values from ``warn_clue_rule``."""
        conn = self.db_conn_pool.getconn()
        try:
            df = pd.read_sql(
                f"select clue_value from {ETL_SCHEMA}.warn_clue_rule where " +
                f"clue_type = 'mac' and start_time <= now() and end_time >= now()",
                con=conn)
        finally:
            self.db_conn_pool.putconn(conn)
        return df

    def read_attr_data(self):
        """Return ``attr_value`` of all attrs with attr_type_id 5 and sync_in_time true."""
        conn = self.db_conn_pool.getconn()
        try:
            df = pd.read_sql(f"select attr_value from {ETL_SCHEMA}.attr where " +
                             f"attr_type_id = 5 and sync_in_time = 't'",
                             con=conn)
        finally:
            self.db_conn_pool.putconn(conn)
        return df

    def update_attr(self, df):
        """Check whether new attr records need to be created.

        Ensures an ``attr`` row with attr_type_id 5 exists for every
        distinct ``df['probe_data']`` value:
        - existing rows get ``update_time`` set to now (epoch milliseconds);
        - missing values are inserted with create/update time set to now.

        Values are sent as bind parameters instead of being pasted into the
        SQL text, so quotes in the data cannot break the statement.

        :returns: DataFrame with columns ``id`` and ``probe_data`` for the
            updated and/or inserted rows, or None if ``df`` is empty.
        """
        if df.empty:
            return None
        attr_mac_list = tuple(set(df['probe_data'].tolist()))
        current_ts = int(time.time() * 1000)
        exist_attr_df = None
        insert_attr_df = None
        conn = self.db_conn_pool.getconn()
        try:
            with conn.cursor() as cursor:
                # psycopg2 renders a Python tuple as "(v1, v2, ...)" for IN.
                cursor.execute(
                    f"UPDATE {ETL_SCHEMA}.attr SET update_time = %s WHERE attr_type_id = 5 and attr_value in %s RETURNING id,attr_value",
                    (current_ts, attr_mac_list))
                conn.commit()
                query_result = cursor.fetchall()
                if query_result:
                    exist_attr_df = pd.DataFrame(query_result,
                                                 columns=['id', 'probe_data'])
                    insert_mac_list = list(
                        (collections.Counter(attr_mac_list) -
                         collections.Counter(
                             exist_attr_df['probe_data'].tolist())).elements())
                else:
                    insert_mac_list = attr_mac_list
                insert_df = pd.DataFrame(insert_mac_list,
                                         columns=['attr_value'])
                if not insert_df.empty:
                    insert_df['attr_type_id'] = 5
                    insert_df['create_time'] = current_ts
                    insert_df['update_time'] = current_ts
                    logger.info("以下是attr数据")  # "the following is the attr data"
                    logger.info(insert_df)
                    # One "(%s,%s,%s,%s)" group per row. Parameters are built
                    # from plain Python values, because psycopg2 cannot adapt
                    # numpy scalars taken from the DataFrame.
                    placeholders = ','.join(['(%s,%s,%s,%s)'] *
                                            len(insert_mac_list))
                    params = []
                    for value in insert_mac_list:
                        params.extend((5, str(value), current_ts, current_ts))
                    # Realign attr_id_seq with max(id) in the same ETL schema.
                    cursor.execute(
                        f"SELECT setval('{ETL_SCHEMA}.attr_id_seq', (SELECT max(id) FROM {ETL_SCHEMA}.attr));"
                    )
                    cursor.execute(
                        f"insert into {ETL_SCHEMA}.attr (attr_type_id,attr_value,create_time,update_time) values {placeholders} RETURNING id,attr_value",
                        params)
                    conn.commit()
                    insert_attr_df = pd.DataFrame(cursor.fetchall(),
                                                  columns=['id', 'probe_data'])
        finally:
            self.db_conn_pool.putconn(conn)

        if exist_attr_df is not None and insert_attr_df is not None:
            return pd.concat([exist_attr_df, insert_attr_df])
        if exist_attr_df is not None:
            return exist_attr_df
        return insert_attr_df
Beispiel #15
0
class Connection(object):
    """Tornado-coroutine PostgreSQL client.

    Uses psycopg2 asynchronous connections from a SimpleConnectionPool.
    """
    def __init__(self, database, host=None, port=None, user=None,
                 password=None, client_encoding="utf8",
                 minconn=1, maxconn=5,
                 **kwargs):

        self.host = "%s:%s" % (host, port)
        # Bound before the pool is built so __del__ is safe if that fails.
        self._pool = None

        _db_args = dict(
            database=database,
            client_encoding=client_encoding,
            **kwargs
        )
        # "async" is a reserved word since Python 3.7 and cannot be written
        # as a keyword argument; psycopg2.connect still accepts it as a key.
        _db_args["async"] = True
        if host is not None:
            _db_args["host"] = host
        if port is not None:
            _db_args["port"] = port
        if user is not None:
            _db_args["user"] = user
        if password is not None:
            _db_args["password"] = password

        try:
            self._pool = SimpleConnectionPool(
                minconn=minconn, maxconn=maxconn, **_db_args)
        except Exception:
            logging.error("Cannot connect to PostgreSQL on %s", self.host,
                          exc_info=True)

    def __del__(self):
        pool = getattr(self, '_pool', None)
        if pool is not None and not pool.closed:
            pool.closeall()

    def _connect(self, callback=None):
        """Get an existing database connection."""
        conn = self._pool.getconn()

        callback = functools.partial(callback, conn)
        Poller(conn, (callback, ))._update_handler()

    @gen.coroutine
    def _cursor(self):
        """Return a cursor on a freshly borrowed pooled connection."""
        conn = yield gen.Task(self._connect)
        cursor = conn.cursor()
        raise gen.Return(cursor)

    def putconn(self, conn, close=False):
        """Return ``conn`` to the pool, closing it if ``close``."""
        self._pool.putconn(conn, close=close)

    def _release(self, cursor, conn, discard):
        """Close ``cursor`` and then hand ``conn`` back to the pool.

        The cursor is closed first so it is never used after the connection
        could be lent to someone else.
        """
        try:
            cursor.close()
        finally:
            self.putconn(conn, close=discard)

    @gen.coroutine
    def query(self, query, parameters=()):
        """Returns a row list for the given query and parameters."""
        cursor = yield self._cursor()
        conn = cursor.connection
        discard = False
        try:
            yield gen.Task(self._execute, cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            raise gen.Return([Row(zip(column_names, row)) for row in cursor])
        except psycopg2.OperationalError:
            # The connection is presumed broken: close it instead of reusing it.
            discard = True
            raise
        finally:
            self._release(cursor, conn, discard)

    @gen.coroutine
    def get(self, query, parameters=()):
        """Returns the (singular) row returned by the given query.

        If the query has no results, returns None.  If it has
        more than one result, raises an exception.
        """
        rows = yield self.query(query, parameters)
        if not rows:
            raise gen.Return(None)
        elif len(rows) > 1:
            raise Exception("Multiple rows returned for Database.get() query")
        else:
            raise gen.Return(rows[0])

    @gen.coroutine
    def execute(self, query, parameters=()):
        """Executes the given query."""
        cursor = yield self._cursor()
        conn = cursor.connection
        discard = False
        try:
            yield gen.Task(self._execute, cursor, query, parameters)
        except psycopg2.OperationalError:
            discard = True
            raise
        finally:
            self._release(cursor, conn, discard)

    def _execute(self, cursor, query, parameters, callback=None):
        """Start ``query`` on ``cursor``; a Poller fires ``callback`` on completion.

        :raises TypeError: if ``parameters`` is not a tuple or list.
        :raises psycopg2.OperationalError: re-raised after logging, so the
            awaiting coroutine fails instead of waiting for a callback that
            never fires. query()/execute() then discard the connection.
        """
        if not isinstance(parameters, (tuple, list)):
            raise TypeError("parameters must be a tuple or list, got %r"
                            % type(parameters))

        try:
            cursor.execute(query, parameters)

            Poller(cursor.connection, (callback,))._update_handler()
        except psycopg2.OperationalError:
            logging.error("Error connecting to PostgreSQL on %s", self.host)
            raise
Beispiel #16
0
class PostgresqlWrapper(object):
    """PostgreSQL wrapper used to upload and download music data.

    Songs are stored as numpy arrays in the BYTEA ``data`` column of the
    ``music`` table (see ``register_adapters``). A dedicated connection
    (``self.conn`` / ``self.cur``) handles DDL and inserts, while a
    ``SimpleConnectionPool`` serves the reads in ``select_songs``.
    """
    # NOTE(review): credentials are embedded in source; prefer the environment.
    # DATABASE_URL is not referenced in this class -- connections use
    # LOCALHOST_STING.
    DATABASE_URL = "postgres://*****:*****@ec2-54-75-239-237.eu-west-1.compute.amazonaws.com:5432/d5jk6qjst0rku1"
    # Root folder scanned by to_database(); each sub-folder name is a genre.
    MUSIC_PATH = "genres"
    # libpq connection string (attribute name kept as-is for compatibility).
    LOCALHOST_STING = "host='localhost' dbname='music' user='******' password='******'"

    def __init__(self, conn_num=3):
        """Connect, build the pool, register numpy adapters and ensure the table.

        :param conn_num: minimum pool size and number of workers used by
            ``fetch_songs``; the pool may grow to ``conn_num + 5``.
        """
        self.__init_logger()
        self.log.info("Creating pool")
        self.conn_num = conn_num
        self.conn = psycopg2.connect(self.LOCALHOST_STING)
        self.cur = self.conn.cursor()
        self.pool = SimpleConnectionPool(self.conn_num, self.conn_num + 5,
                                         self.LOCALHOST_STING)
        self.register_adapters()
        self.create_table(False)

    def __init_logger(self):
        """Bind ``self.log`` to the module logger, attaching a console handler once."""
        self.log = logging.getLogger(__name__)
        self.log.setLevel(logging.INFO)
        # The logger is shared by every instance; adding a handler per
        # instance printed each message once per wrapper ever created.
        if not self.log.handlers:
            handler = logging.StreamHandler()  # console
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            self.log.addHandler(handler)

    def create_table(self, clean=False):
        """Create the ``music`` table if missing; drop it first when *clean*."""
        self.log.info("Creating table")
        if clean:
            # IF EXISTS so a clean start also works on an empty database.
            self.cur.execute('drop table if exists music')
            self.conn.commit()
        statement = ("CREATE TABLE if not exists music "
                     "(id serial PRIMARY KEY, genre varchar(100), data BYTEA);")

        self.cur.execute(statement)
        self.conn.commit()

    def insert_song(self, genre, song):
        """Insert one song on the dedicated cursor (not committed here).

        ``to_database`` commits once after the whole batch.
        """
        statement = "Insert into music(genre, data) values(%s, %s)"
        self.cur.execute(statement, (genre, song))

    def select_songs(self, limit=None, offset=None, genre=None):
        """Return one page of ``music`` rows ordered by id.

        :param limit: maximum number of rows (NULL/None means no limit).
        :param offset: number of rows to skip (NULL/None means none).
        :param genre: restrict the page to this genre when given.
        """
        if genre is None:
            statement = "Select * from music order by id limit %s offset %s"
            args = (limit, offset)
        else:
            statement = "Select * from music where genre = %s order by id limit %s offset %s"
            args = (genre, limit, offset)

        conn = self.pool.getconn()
        try:
            cur = conn.cursor()
            try:
                self.log.info("Statement %s", statement % args)
                cur.execute(statement, args)
                self.log.info("Done with %s", statement % args)
                return cur.fetchall()
            finally:
                cur.close()
        finally:
            # Always return the connection; a failing query used to consume
            # one pool slot permanently.
            self.pool.putconn(conn)

    def register_adapters(self):
        """Register psycopg2 adapters translating np.ndarray to binary and back.

        Arrays are serialised with ``np.save``. The typecaster is registered
        globally for psycopg2's BINARY (bytea) type, so every bytea value read
        afterwards is decoded with ``np.load``.
        """
        def _adapt_array(array):
            out = io.BytesIO()
            np.save(out, array)
            out.seek(0)
            return psycopg2.Binary(out.read())

        def _typecast_array(value, cur):
            if value is None:
                return None

            data = psycopg2.BINARY(value, cur)
            bdata = io.BytesIO(data)
            bdata.seek(0)
            return np.load(bdata)

        psycopg2.extensions.register_adapter(np.ndarray, _adapt_array)
        t_array = psycopg2.extensions.new_type(psycopg2.BINARY.values, "numpy",
                                               _typecast_array)
        psycopg2.extensions.register_type(t_array)
        self.log.info("Done register types")

    def to_database(self, folders=None, limit=1000):
        """Load audio files under MUSIC_PATH into ``music``, then close all connections.

        :param folders: optional collection of genre folder names to import.
        :param limit: maximum number of files imported per folder.
        """
        for root, _, files in os.walk(self.MUSIC_PATH):
            # basename honours the OS path separator (split('/') did not).
            genre = os.path.basename(root)
            if folders is not None and genre not in folders:
                continue
            for i, file_ in enumerate(files):
                if i == limit:
                    break
                self.log.info("Inserting song %s", file_)
                song = librosa.load(os.path.join(root, file_))[0]
                self.insert_song(genre, song)
        self.conn.commit()
        self.close_connection()

    def close_connection(self):
        """Close the pool and the dedicated cursor/connection."""
        self.log.info("Closing connection")
        self.pool.closeall()
        self.cur.close()
        self.conn.close()

    @staticmethod
    def _fetch_batches(count, limit, genres=None):
        """Build ``select_songs`` argument tuples covering the first *count* rows.

        Batches are ``(size, offset)`` or, per genre, ``(size, offset, genre)``;
        the last batch is trimmed so exactly *count* rows are requested.
        """
        if genres is not None:
            genres = list(genres)  # iterated once per offset
        batches = []
        for offset in range(0, count, limit):
            size = min(limit, count - offset)
            if genres is not None:
                batches.extend((size, offset, genre) for genre in genres)
            else:
                batches.append((size, offset))
        return batches

    def fetch_songs(self, count, limit=50, genres=None):
        """Fetch songs concurrently from the database.

        :param count: how many songs to fetch (per genre when *genres* is given).
        :param limit: how many songs each worker fetches per query.
        :param genres: optional iterable of genres; each batch is queried per genre.
        :returns: one ``select_songs`` result (list of rows) per batch.
        """
        self.log.info("Start fetching %s songs", count)
        # The old while-loop computed the offset before testing it and always
        # queued one batch past *count*.
        producer = self._fetch_batches(count, limit, genres)
        # NOTE(review): if ``Pool`` is multiprocessing.Pool, ``self.select_songs``
        # is pickled together with ``self`` (open psycopg2 connections), which
        # psycopg2 does not support -- confirm which Pool is imported.
        with Pool(self.conn_num) as pool:
            result = pool.starmap(self.select_songs, producer)
        return result
class ConnectionPool:
    """Wrapper around a psycopg2 ``SimpleConnectionPool`` for the PostGIS database.

    Connection parameters come from the module-level ``POSTGRES_*`` settings;
    the maximum pool size is read from the ``POSTGRES_MAX_CONNECTIONS``
    environment variable (default 10). If the pool cannot be created the
    error is reported and ``pool`` stays ``None`` (see ``isAvailable``).
    """
    pool = None

    def __init__(self):
        self.pool = self._create_pool()

    @staticmethod
    def _create_pool():
        """Build a new pool; report the failure and return None if that fails."""
        try:
            min_connections = 1
            max_connections = int(os.getenv("POSTGRES_MAX_CONNECTIONS", "10"))

            return SimpleConnectionPool(min_connections,
                                        max_connections,
                                        dbname=POSTGRES_DBNAME,
                                        user=POSTGRES_USER,
                                        host=POSTGRES_HOST,
                                        password=POSTGRES_PASSWORD,
                                        port=POSTGRES_PORT)
        except Exception as e:
            print(
                str(e), Logs.ERROR, {
                    "postgresql":
                    "{hostname}:{port}/{dbname}".format(hostname=POSTGRES_HOST,
                                                        port=POSTGRES_PORT,
                                                        dbname=POSTGRES_DBNAME)
                })
            return None

    def isAvailable(self):
        """True when a pool was created successfully."""
        return (self.pool is not None)

    def closeAll(self):
        """Close every connection managed by the pool."""
        self.pool.closeall()

    def getConnection(self):
        """Borrow a connection from the pool, with autocommit enabled.

        If the pool refuses (e.g. it is exhausted), every connection it
        manages is closed and a fresh pool is built before retrying once.

        :raises PoolError: if the retry fails or the pool cannot be rebuilt.
        """
        try:
            conn = self.pool.getconn()
        except PoolError:
            print("error obteniendo una conexión a PostGIS, reintentando...",
                  Logs.ERROR)
            # closeall() closes all connections (including ones still handed
            # out) AND marks the pool closed, so retrying getconn() on the same
            # pool always failed with "connection pool is closed". Rebuild it.
            if not self.pool.closed:
                self.pool.closeall()
            self.pool = self._create_pool()
            if self.pool is None:
                raise PoolError("could not recreate the PostGIS connection pool")
            conn = self.pool.getconn()

        conn.autocommit = True
        return conn

    def closeConnection(self, conn):
        """Return *conn* to the pool, asking the pool to close it."""
        self.pool.putconn(conn, close=True)

    def runScript(self, scriptDir):
        """Execute every ``*.sql`` file found (recursively) under *scriptDir*.

        Each file is split on ``utils.QUERY_DELIMITER`` and every statement is
        run on one autocommit connection; a failing statement is reported and
        skipped without aborting the rest (best effort).

        :returns: ``(insertCounter, geometryType)`` -- the number of executed
            statements starting with ``INSERT INTO`` across all files, and the
            geometry type taken from the first ``SELECT AddGeometryColumn``
            statement (None if there was none).
        """
        insertCounter = 0
        geometryType = None

        conn = self.getConnection()
        try:
            for root, _dirnames, filenames in os.walk(scriptDir):
                for filename in filenames:
                    if not filename.endswith('.sql'):
                        continue
                    scriptPath = os.path.join(root, filename)

                    print("Ejecutando " + scriptPath, Logs.INFO)

                    with open(scriptPath, "r") as script:
                        queries = script.read().split(utils.QUERY_DELIMITER)

                    cursor = conn.cursor()
                    try:
                        for query in queries:
                            # Split fragments may carry surrounding newlines;
                            # strip them so the startswith() checks see the SQL
                            # keyword and whitespace-only fragments are skipped
                            # instead of being sent as empty queries.
                            query = query.strip()
                            if not query:
                                continue
                            try:
                                cursor.execute(query)

                                if query.startswith("INSERT INTO"):
                                    insertCounter += 1
                                elif geometryType is None and query.startswith(
                                        "SELECT AddGeometryColumn"):
                                    # Second-to-last argument without its first
                                    # and last character (presumably quotes);
                                    # strip() tolerates a space after the comma.
                                    geometryType = query.split(",")[-2].strip()[1:-1]

                            except Exception as e:
                                print(str(e), Logs.ERROR, {"query": query})
                    finally:
                        cursor.close()

                    print("ok", Logs.INFO,
                          {"inserted_features": str(insertCounter)})
        finally:
            # Always hand the connection back, even if reading a file fails.
            self.closeConnection(conn)

        return insertCounter, geometryType
Beispiel #18
0
class Service:
    """Data-access layer for the train-ticketing database ``12306``.

    Each instance owns a ``SimpleConnectionPool`` (1-5 connections). Every
    query method borrows a connection from ``self.cp`` through the helpers
    below, which return it even when the statement raises (the former inline
    getconn/putconn pairs leaked the connection on any error).

    Unless ``public`` is true, the constructor authenticates the user and
    loads their passengers, setting ``id``, ``admin`` and ``passenger``.
    """

    # Any character outside ASCII letters/digits makes a username invalid.
    _INVALID_USERNAME = re.compile('([^a-z0-9A-Z])+')

    # remain_query(): select list and sub-select shared by every variant; the
    # t1/t2 station filter is appended after the trailing "and ".
    _REMAIN_HEAD = (
        "select ti.seat_type, t.train_id, t.dep, t.arr, t.shown_code,"
        " t1.station, t2.station,"
        " to_char(t.dep_time, 'HH24:MI'), to_char(t.arr_time, 'HH24:MI'), range / 60,"
        " ip.price_yz, ip.price_rz, ip.price_sw, ip.price_yw, ip.price_rw,"
        " remain_query(t.train_id, t.dep, t.arr, 1, %s),"
        " remain_query(t.train_id, t.dep, t.arr, 2, %s),"
        " remain_query(t.train_id, t.dep, t.arr, 3, %s),"
        " remain_query(t.train_id, t.dep, t.arr, 4, %s),"
        " remain_query(t.train_id, t.dep, t.arr, 5, %s)"
        " from ("
        " select t1.station_code dep_code, t2.station_code arr_code, t1.train_id train_id,"
        " t1.station_idx dep, t2.station_idx arr,"
        " t1.dep_time as dep_time, t2.arr_time as arr_time,"
        " t1.code_shown as shown_code,"
        " extract(epoch from (t2.arr_time - t1.dep_time"
        " + (t2.day_arr - t1.day_dep) * interval '1 day'))::int as range"
        " from timetable t1 join timetable t2 on t1.train_id = t2.train_id"
        " where t1.station_shown and t2.station_shown and ")
    _REMAIN_TAIL = (
        " and t1.station_idx < t2.station_idx ) t"
        " join train_info ti on t.train_id = ti.train_id"
        " join station_info t1 on t.dep_code = t1.code"
        " join station_info t2 on t.arr_code = t2.code"
        " join interval_price ip on t.train_id = ip.train_id"
        " and t.dep = ip.dep_idx and t.arr = ip.arr_idx"
        " order by t.range, t.dep_time;")

    # transfer_query(): everything after the "st"/"ed" CTE definitions.
    # NOTE(review): "total" adds a day when the layover is < 20 min while
    # "transfer" uses a 15 min threshold -- confirm which one is intended.
    _TRANSFER_BODY = (
        " select s1.station, to_char(t1.dep_time, 'HH24:MI'), t1.code_shown,"
        " to_char(t2.arr_time, 'HH24:MI'), s2.station,"
        " to_char(t3.dep_time, 'HH24:MI'), t3.code_shown,"
        " to_char(t4.arr_time, 'HH24:MI'), s3.station,"
        " extract(epoch from (t4.arr_time - t1.dep_time"
        " + '1 day'::interval * (t4.day_arr - t3.day_dep + t2.day_arr - t1.day_dep)"
        " + (case when (t3.dep_time - t2.arr_time) >= '0:20:00'::interval"
        " then '0 day'::interval else '1 day'::interval end)))::int / 60 as total,"
        " extract(epoch from (t3.dep_time - t2.arr_time"
        " + (case when (t3.dep_time - t2.arr_time) >= '0:15:00'::interval"
        " then '0 day'::interval else '1 day'::interval end)))::int / 60 as transfer"
        " from timetable t1"
        " join timetable t2 on t1.train_id = t2.train_id and t2.station_code in ("
        "select si.code from station_info si where si.code in ("
        " select distinct t2.station_code from timetable t1"
        " join timetable t2 on t1.train_id = t2.train_id and t1.station_idx < t2.station_idx"
        " and t1.station_code in (select code from st)"
        " ) and si.code in ("
        " select distinct t2.station_code from timetable t1"
        " join timetable t2 on t1.train_id = t2.train_id and t1.station_idx > t2.station_idx"
        " and t1.station_code in (select code from ed)"
        " ) ) and t1.station_idx < t2.station_idx"
        " join timetable t4 on t4.station_code in (select code from ed) and t4.train_id <> t2.train_id"
        " join timetable t3 on t3.train_id = t4.train_id and t3.station_code = t2.station_code"
        " and t3.station_idx < t4.station_idx"
        " join station_info s1 on s1.code = t1.station_code"
        " join station_info s2 on s2.code = t2.station_code"
        " join station_info s3 on s3.code = t4.station_code"
        " where t1.station_code in (select code from st)"
        " and t1.station_shown and t2.station_shown and t3.station_shown and t4.station_shown"
        " order by total limit 150;")

    def __init__(self, username: str, password: str, public=False):
        """Create the pool and, unless *public*, log in and load passengers.

        :param public: skip authentication; only methods that do not use
            ``self.id`` are usable then.
        :raises Exception: from ``password_check`` when login fails.
        """
        self.cp = SimpleConnectionPool(1,
                                       5,
                                       user="******",
                                       password="******",
                                       host="sh.wtd2.top",
                                       port="5432",
                                       database="12306")
        if not public:
            self.password_check(username, password.encode('utf-8'))
            self.passenger_info()

    # -- connection helpers -------------------------------------------------

    def _fetch_all(self, stmt, params=None):
        """Run a read statement and return all rows; always release the connection."""
        conn = self.cp.getconn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(stmt, params)
                return cursor.fetchall()
            finally:
                cursor.close()
        finally:
            self.cp.putconn(conn)

    def _execute_commit(self, stmt, params):
        """Run a write statement and commit; always release the connection."""
        conn = self.cp.getconn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(stmt, params)
                conn.commit()
            finally:
                cursor.close()
        finally:
            self.cp.putconn(conn)

    def _require_one_row(self, stmt, params, message):
        """Raise ``Exception(message)`` unless *stmt* returns exactly one row."""
        if len(self._fetch_all(stmt, params)) != 1:
            raise Exception(message)

    @classmethod
    def _check_username(cls, username):
        """Raise ``Exception('Invalid username')`` for non-alphanumeric names."""
        if cls._INVALID_USERNAME.search(username):
            raise Exception('Invalid username')

    @staticmethod
    def _station_codes(exact):
        """SQL sub-select of station codes matched by one ``%s`` placeholder.

        exact == 1: that station code; exact == 0: every station in the same
        city as that code; otherwise: stations whose name is LIKE
        ``station_name(%s)``.
        """
        if exact == 1:
            return "select code from station_info where code = %s"
        if exact == 0:
            return ("select code from station_info where city in "
                    "(select city from station_info where code = %s)")
        return "select code from station_info where station like station_name(%s)"

    # -- accounts -------------------------------------------------------------

    def password_check(self, username: str, password: bytes):
        """Authenticate *username* and store its id and admin flag.

        :returns: True on success.
        :raises Exception: 'Invalid username', 'No such user' or 'Password error'.
        """
        self._check_username(username)
        # NOTE(review): unsalted SHA-256 is a weak password hash; moving to a
        # salted KDF (hashlib.scrypt / pbkdf2_hmac) needs a schema migration.
        sha_pass = hashlib.sha256(password).hexdigest()
        result = self._fetch_all(
            "select * from user_info where user_name = %s", (username, ))
        if not result:
            raise Exception('No such user')
        user = result[0]
        # Column positions as relied on originally: presumably
        # (id, name, password hash, privilege) -- confirm against the schema.
        if user[2] != sha_pass:
            raise Exception('Password error')
        self.id = user[0]
        self.admin = user[3]
        return True

    def register(self, username: str, password: bytes):
        """Create a non-privileged user storing the SHA-256 hex digest of *password*."""
        self._check_username(username)
        sha_pass = hashlib.sha256(password).hexdigest()
        self._execute_commit(
            "insert into user_info(user_name, user_pass, user_privilege) values (%s, %s, %s)",
            (username, sha_pass, False))

    # -- timetable / availability queries --------------------------------

    def station_list(self):
        """Return ``(code, station, city, province)`` for every station."""
        return self._fetch_all(
            "select si.code, si.station, si.city, cp.province from station_info si join city_province cp on si.city = cp.city")

    def remain_query(self, dep: str, arr: str, date: str, exact: int = 1):
        """List direct trains from *dep* to *arr* with prices and seats left on *date*.

        :param exact: 1 matches the station codes exactly, 0 matches any
            station in the same city, anything else matches station names
            through ``station_name()``.
        :returns: rows ordered by travel duration, then departure time.
        """
        if exact == 1:
            station_filter = "t1.station_code = %s and t2.station_code = %s"
        else:
            codes = self._station_codes(exact)
            station_filter = ("t1.station_code in (" + codes + ")"
                              " and t2.station_code in (" + codes + ")")
        stmt = self._REMAIN_HEAD + station_filter + self._REMAIN_TAIL
        # Five seat-class availability lookups for *date*, then the stations.
        return self._fetch_all(stmt, (date, date, date, date, date, dep, arr))

    def transfer_query(self, dep: str, arr: str, exact: int = 1):
        """List one-transfer itineraries from *dep* to *arr* (best 150 by total time).

        :param exact: same station-matching modes as ``remain_query``.
        """
        codes = self._station_codes(exact)
        stmt = ("with st as (" + codes + "), ed as (" + codes + ")"
                + self._TRANSFER_BODY)
        return self._fetch_all(stmt, (dep, arr))

    def price_query(self, train: str, dep: str, arr: str, date: str):
        """Return the price row for one leg followed by the five remain_query counts.

        :raises IndexError: if no price row exists for the leg.
        """
        conn = self.cp.getconn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'select ti.seat_type, t1.code_shown, si1.station, si2.station, ip.* from interval_price ip join timetable t1 on ip.train_id = t1.train_id and ip.dep_idx = t1.station_idx join timetable t2 on ip.train_id = t2.train_id and ip.arr_idx = t2.station_idx join station_info si1 on t1.station_code = si1.code join station_info si2 on t2.station_code = si2.code join train_info ti on ip.train_id = ti.train_id where ip.train_id = %s and ip.dep_idx = %s and ip.arr_idx = %s',
                    (train, dep, arr))
                result = cursor.fetchall()[0]
                # remain_query(train, dep, arr, <seat class 1..5>, date)
                cursor.execute(
                    'select remain_query(%s, %s, %s, 1, %s), remain_query(%s, %s, %s, 2, %s), remain_query(%s, %s, %s, 3, %s), remain_query(%s, %s, %s, 4, %s), remain_query(%s, %s, %s, 5, %s);',
                    (train, dep, arr, date) * 5)
                return result + cursor.fetchall()[0]
            finally:
                cursor.close()
        finally:
            self.cp.putconn(conn)

    def timetable_query(self, train: str):
        """Return the visible stops of trains whose shown code is LIKE *train*."""
        return self._fetch_all(
            "select t.station_idx, coalesce(t.code_shown, ti.train_full_name),s.station, to_char(t.arr_time, 'HH24:MI'),to_char(t.dep_time, 'HH24:MI'), t.day_arr, t.day_dep, extract(epoch from (t.dep_time - t.arr_time) + (t.day_dep - t.day_arr) * '1 day'::interval) / 60, t.train_id from timetable t join station_info s on t.station_code = s.code join train_info ti on t.train_id = ti.train_id left join interval_price ip on t.train_id = ip.train_id and ip.dep_idx = 1 and ip.arr_idx = t.station_idx where t.train_id in (select distinct train_id from timetable where code_shown like %s) and t.station_idx >= 1 and t.station_shown order by t.station_idx;",
            (train, ))

    # -- passengers -----------------------------------------------------------

    def passenger_info(self):
        """Load (and cache in ``self.passenger``) the user's visible passengers."""
        self.passenger = self._fetch_all(
            "select passenger_id, passenger_name, idcard_number, phone_number from passenger_info where related_user = %s and shown = true order by passenger_id;",
            (self.id, ))
        return self.passenger

    def insert_passenger(self, name: str, idcard: str, phone: str):
        """Add a passenger owned by the current user."""
        self._execute_commit(
            "insert into passenger_info (passenger_name, idcard_number, phone_number, related_user) values (%s, %s, %s, %s);",
            (name, idcard, phone, self.id))

    def delete_passenger(self, num: int):
        """Hide passenger *num* (soft delete).

        :raises Exception: 'No such privilege' if it is not the user's passenger.
        """
        self._require_one_row(
            "select passenger_id from passenger_info where passenger_id = %s and related_user = %s;",
            (num, self.id), 'No such privilege')
        self._execute_commit(
            "update passenger_info set shown = false where passenger_id = %s;",
            (num, ))

    def update_passenger(self, num: int, name: str, idcard: str, phone: str):
        """Overwrite name, id card and phone of passenger *num*.

        :raises Exception: 'No such privilege' if it is not the user's passenger.
        """
        self._require_one_row(
            "select passenger_id from passenger_info where passenger_id = %s and related_user = %s;",
            (num, self.id), 'No such privilege')
        self._execute_commit(
            "update passenger_info set (passenger_name, idcard_number, phone_number) = (%s, %s, %s) where passenger_id = %s;",
            (name, idcard, phone, num))

    # -- orders ---------------------------------------------------------------

    def purchase_ticket(self, train: int, dep: int, arr: int, date: str,
                        hierarchy: int, passenger: int):
        """Buy a ticket for one of the current user's passengers.

        :returns: ``[departure date, seat cabin, seat code, order_no]``, or
            None if the ``purchase_order`` call raised.
        :raises Exception: 'No such privilege' if the passenger does not
            belong to the current user.
        """
        self._require_one_row(
            "select * from passenger_info where passenger_id = %s and related_user = %s",
            (passenger, self.id), 'No such privilege')
        conn = self.cp.getconn()
        try:
            cursor = conn.cursor()
            try:
                try:
                    cursor.execute(
                        "select purchase_order(%s, %s, %s, %s, %s, %s);",
                        (passenger, train, dep, arr, hierarchy, date))
                except Exception:
                    # Any failure inside purchase_order() is reported as None.
                    return None
                order_no = cursor.fetchall()[0][0]
                cursor.execute(
                    "select to_char(si.dep_date, 'YYYY年MM月DD日'), si.seat_cabin, si.seat_code from order_info oi join seat_info si on oi.train_id = si.train_id and oi.seat_id = si.seat_no where order_id = %s;",
                    (order_no, ))
                res = cursor.fetchall()
                conn.commit()
            finally:
                cursor.close()
        finally:
            self.cp.putconn(conn)
        return list(res[0]) + [order_no]

    def refund_ticket(self, order_no: int):
        """Refund order *order_no* of one of the user's passengers.

        :returns: True on success, False if ``refund_order`` raised.
        :raises Exception: 'No such privilege' if the order is not the user's.
        """
        self._require_one_row(
            "select * from order_info where passenger_id in (select passenger_id from passenger_info where related_user = %s) and order_id = %s;",
            (self.id, order_no), 'No such privilege')
        conn = self.cp.getconn()
        try:
            cursor = conn.cursor()
            try:
                try:
                    cursor.execute("select refund_order(%s);", (order_no, ))
                except Exception:
                    return False
                conn.commit()
            finally:
                cursor.close()
        finally:
            self.cp.putconn(conn)
        return True

    def order_query(self):
        """Return the user's active orders (``order_status = 1``), newest first."""
        # NOTE(review): the IN-subquery only references outer columns
        # (oi.passenger_id, pi.related_user), so it effectively filters on
        # pi.related_user = %s.
        return self._fetch_all(
            "select oi.order_id as id, t1.code_shown as num, pi.passenger_name as name, s1.station as dep, s2.station as arr, oi.order_price as price, to_char(oi.create_time, 'YYYY-MM-DD HH24:MI') as create_time, si.seat_type, si.seat_cabin, si.seat_code, to_char(oi.dep_date::timestamp + t1.day_dep * '1 day'::interval + t1.dep_time, 'YYYY-MM-DD HH24:MI') as dt, to_char(oi.dep_date::timestamp + t2.day_arr * '1 day'::interval + t2.arr_time, 'YYYY-MM-DD HH24:MI') as at, ti.seat_type, pi.idcard_number from order_info oi join passenger_info pi on oi.passenger_id = pi.passenger_id join timetable t1 on oi.train_id = t1.train_id and oi.dep_idx = t1.station_idx join timetable t2 on oi.train_id = t2.train_id and oi.arr_idx = t2.station_idx join station_info s1 on t1.station_code = s1.code join station_info s2 on t2.station_code = s2.code join seat_info si on oi.train_id = si.train_id and oi.seat_id = si.seat_no and oi.dep_date = si.dep_date join train_info ti on oi.train_id = ti.train_id where oi.passenger_id in (select oi.passenger_id from passenger_info where pi.related_user = %s) and oi.order_status = 1 order by order_id desc;",
            (self.id, ))

    # -- timetable editing ---------------------------------------------------

    def add_stop(self, train_id: int, code: str, dt: str, at: str, dd: int,
                 ad: int):
        """Call the ``add_stop`` DB function for station *code* on *train_id*.

        :raises Exception: 'No such station' if *code* is unknown.
        """
        self._require_one_row(
            "select code from station_info where code = %s;", (code, ),
            'No such station')
        self._execute_commit("select add_stop(%s, %s, %s, %s, %s, %s);",
                             (train_id, code, dt, at, dd, ad))

    def edit_stop(self, train_id: int, idx: int, code: str, dt: str, at: str,
                  dd: int, ad: int):
        """Call the ``modify_stop`` DB function for stop *idx* of *train_id*.

        :raises Exception: 'No such station' if *code* is unknown.
        """
        self._require_one_row(
            "select code from station_info where code = %s;", (code, ),
            'No such station')
        self._execute_commit("select modify_stop(%s, %s, %s, %s, %s, %s, %s);",
                             (train_id, idx, code, dt, at, dd, ad))

    def remove_stop(self, train_id: int, idx: int):
        """Hide stop *idx* of *train_id* (soft delete)."""
        self._execute_commit(
            "update timetable set station_shown = false where train_id = %s and station_idx = %s;",
            (train_id, idx))
Beispiel #19
0
class DB(object):
    """Multi-backend database helper for MySQL, PostgreSQL and Oracle.

    Usage: build it with a pool config, call ``open()`` to borrow a
    connection and cursor, use the query helpers, then call ``close()``.

    Conventions used throughout:

    * A string value wrapped in ``{{ }}`` is a raw SQL expression and is
      inserted verbatim, e.g. ``{'created': '{{now()}}'}``. Such values must
      never come from untrusted input.
    * INSERT/UPDATE/DELETE render plain values inline as escaped SQL literals
      (see ``__literal``). Lookups bind them as driver parameters with ``%s``
      placeholders, except on Oracle, whose driver does not use ``%s``; there
      they are rendered inline too.
    * Errors are re-raised as ``BaseError`` with these codes: 701 pool
      creation, 702 open, 704/705 writes, 706 reads, 707 counts.
    """

    def __init__(self, config):
        """Create the connection pool described by ``config``.

        :param config: pool keyword arguments plus an optional ``'datatype'``
            key ('MYSQL' (default), 'POSTGRESQL' or 'ORACLE') and, for Oracle,
            an optional ``'NLS_LANG'`` key exported to the environment. The
            caller's dict is not modified.
        :raises BaseError: 701 when the pool cannot be created.
        """
        config = dict(config)  # work on a copy; the keys below are popped
        self.DataName = config.pop('datatype', 'MYSQL')
        self.__conn = None
        self.__cursor = None
        self.cnx = self.cur = None

        if self.DataName == 'MYSQL':
            try:
                self.pool = mysql.connector.pooling.MySQLConnectionPool(**config)
            except mysql.connector.Error as err:
                # TODO: this should be written to the operation log
                logging.debug(err.msg)
                raise BaseError(701)  # cannot connect to the database
        elif self.DataName == 'POSTGRESQL':
            try:
                self.pool = SimpleConnectionPool(**config)
            except Exception:
                raise BaseError(701)  # cannot connect to the database
        elif self.DataName == 'ORACLE':
            nls_lang = config.pop('NLS_LANG', None)
            if nls_lang:
                os.environ['NLS_LANG'] = nls_lang
            try:
                self.pool = cx_Oracle.SessionPool(**config)
            except Exception:
                raise BaseError(701)  # cannot connect to the database

    def open(self):
        """Borrow a connection from the pool and open a cursor on it.

        Autocommit is switched off. The connection is exposed as ``self.cnx``
        and the cursor as ``self.cur``.

        :raises BaseError: 702 when no connection/cursor can be obtained.
        """
        try:
            self.__conn = self.__get_connection()
            if self.DataName in ('ORACLE', 'POSTGRESQL'):
                self.__cursor = self.__conn.cursor()
            else:  # MySQL is the default
                self.__cursor = self.__conn.cursor(buffered=True)
            self.__conn.autocommit = False
            self.cnx = self.__conn
            self.cur = self.__cursor
        except Exception:
            raise BaseError(702)  # cannot get a connection from the pool

    def close(self):
        """Commit pending work, close the cursor and release the connection.

        PostgreSQL connections are handed back with ``pool.putconn``. Other
        backends call ``close()`` on the connection, which is expected to
        return pooled MySQL/Oracle connections to their pool (confirm for
        your driver version). The cursor and connection are released even if
        the commit fails. Calling ``close`` without an open connection does
        nothing.
        """
        conn, cursor = self.__conn, self.__cursor
        if conn is None:
            return
        self.__conn = self.__cursor = None
        try:
            conn.commit()
        finally:
            try:
                if cursor is not None:
                    cursor.close()
            finally:
                if self.DataName == 'POSTGRESQL':
                    self.pool.putconn(conn)
                else:
                    conn.close()

    def begin(self):
        """Start an explicit transaction by turning autocommit off."""
        self.__conn.autocommit = False

    def commit(self):
        """Commit the current transaction."""
        self.__conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.__conn.rollback()

    # -----------------------------------------------------------------------

    def findBySql(self, sql, params=None, limit=0, join='AND', lock=False):
        """Run a custom SELECT, appending a WHERE clause built from ``params``.

        :param sql: SELECT statement without a WHERE clause.
        :param params: ``{field: value}`` equality conditions; ``'{{expr}}'``
            values are inlined as ``field=expr``.
        :param limit: when > 0, return at most this many rows.
        :param join: 'AND' or 'OR' between the conditions.
        :param lock: unused; kept for backward compatibility.
        :return: list of ``{column: value}`` dicts, or None when nothing matched.
        :raises BaseError: 706 on any error.
        """
        params = params or {}
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            # Only plain values have %s placeholders; {{expr}} values are inlined.
            values = list(self.__split_expression(params)[1].values())
            self.__execute(cursor, sql, values)
            rows = cursor.fetchmany(size=limit) if limit > 0 else cursor.fetchall()
            if not rows:
                return None
            columns = self.__column_names(cursor)
            return [dict(zip(columns, row)) for row in rows]
        except Exception:
            raise BaseError(706)

    def countBySql(self, sql, params=None, join='AND'):
        """Run a custom ``SELECT COUNT(...)`` filtered by ``params``.

        :return: the first column of the first row, or 0 when there is no row.
        :raises BaseError: 707 on any error.
        """
        params = params or {}
        try:
            cursor = self.__getCursor()
            sql = self.__joinWhere(sql, params, join)
            values = list(self.__split_expression(params)[1].values())
            self.__execute(cursor, sql, values)
            result = cursor.fetchone()
            return result[0] if result else 0
        except Exception:
            raise BaseError(707)

    def deleteByAttr(self, table, params=None, join='AND'):
        """Delete rows of ``table`` matching ``params`` and commit immediately.

        Warning: with empty ``params`` no WHERE clause is added, so every row
        of ``table`` is deleted.

        :return: number of deleted rows.
        :raises BaseError: 704 on any error.
        """
        params = params or {}
        try:
            # Backtick quoting is MySQL syntax only.
            target = '`%s`' % table if self.DataName == 'MYSQL' else table
            sql = self.__joinWhere("DELETE FROM %s " % target, params, join)
            values = list(self.__split_expression(params)[1].values())
            cursor = self.__getCursor()
            self.__execute(cursor, sql, values)
            self.__conn.commit()
            return cursor.rowcount
        except Exception as err:
            raise BaseError(704, self.__error_detail(err))

    def findByAttr(self, table, criteria=None):
        """Return the first row matching ``criteria`` as a dict ({} if none)."""
        return self.__query(table, criteria)

    def findByPk(self, table, id, pk='id'):
        """Return the row whose ``pk`` column equals ``id`` as a dict ({} if none)."""
        return self.findByAttr(table, {'where': pk + '=' + self.__literal(id)})

    def findAllByAttr(self, table, criteria=None):
        """Return every row matching ``criteria`` as a list of dicts."""
        return self.__query(table, criteria, True)

    def exit(self, table, params=None, join='AND'):
        """Return True when at least one row of ``table`` matches ``params``."""
        return self.count(table, params, join) > 0

    # Public helpers ---------------------------------------------------------

    def count(self, table, params=None, join='AND'):
        """Count the rows of ``table`` matching ``params`` (all rows if empty).

        :raises BaseError: 707 on any error.
        """
        try:
            sql = 'SELECT COUNT(*) FROM %s' % table
            where_values = []
            if params:
                where, where_values = self.__contact_where(params, join)
                if where:
                    sql += ' WHERE ' + where
            cursor = self.__getCursor()
            self.__display_Debug_IO(sql, tuple(where_values))  # DEBUG
            self.__execute(cursor, sql, where_values)
            result = cursor.fetchone()
            return result[0] if result else 0
        except Exception as err:
            raise BaseError(707, self.__error_detail(err))

    def getToListByPk(self, table, criteria=None, id=None, pk='id'):
        """Return one row as a sequence, by ``criteria`` or else by ``pk`` = ``id``."""
        return self.__query(table, self.__with_pk(criteria, id, pk), isDict=False)

    def getAllToList(self, table, criteria=None, id=None, pk='id', join='AND'):
        """Return all matching rows as sequences, by ``criteria`` or else by ``pk`` = ``id``."""
        return self.__query(table, self.__with_pk(criteria, id, pk), all=True,
                            isDict=False, join=join)

    def getToObjectByPk(self, table, criteria=None, id=None, pk='id'):
        """Return one row as a dict, by ``criteria`` or else by ``pk`` = ``id``."""
        return self.__query(table, self.__with_pk(criteria, id, pk))

    def getAllToObject(self, table, criteria=None, id=None, pk='id', join='AND'):
        """Return all matching rows as dicts, by ``criteria`` or else by ``pk`` = ``id``."""
        return self.__query(table, self.__with_pk(criteria, id, pk), all=True, join=join)

    def insert(self, table, data, commit=True):
        """Insert one row and return its new id.

        :param data: ``{field: value}``; ``'{{expr}}'`` values are inserted as
            raw SQL expressions, other values as escaped literals.
        :param commit: commit on success / roll back on failure. For Oracle
            the id is read with ``SEQ_<table>_ID.CURRVAL`` (schema-qualified
            when ``table`` is), so every table needs such a sequence.
        :return: the new id (``lastrowid`` on MySQL, ``RETURNING id`` on
            PostgreSQL, the sequence's CURRVAL on Oracle).
        :raises BaseError: 705 on failure.
        """
        try:
            funData, newData = self.__split_expression(data)
            fields = list(newData) + list(funData)
            placeholders = ['%s'] * len(newData) + list(funData.values())
            sql = 'INSERT INTO %s (%s) VALUES (%s)' % (
                table, ','.join(fields), ','.join(placeholders))
            literals = tuple(self.__literal(v) for v in newData.values())
            self.__display_Debug_IO(sql, literals)  # DEBUG
            sql = sql % literals
            if self.DataName == 'POSTGRESQL':
                sql += ' RETURNING id'

            cursor = self.__getCursor()
            cursor.execute(sql)

            insert_id = 0
            if self.DataName == 'ORACLE':
                parts = table.split('.')
                if len(parts) > 1:
                    seq_name = parts[0] + '.SEQ_' + parts[1] + '_ID'
                else:
                    seq_name = 'SEQ_' + parts[0] + '_ID'
                cursor.execute('SELECT %s.CURRVAL FROM dual' % seq_name.upper())
                result = cursor.fetchone()
                insert_id = result[0] if result else 0
            elif self.DataName == 'MYSQL':
                insert_id = cursor.lastrowid
            elif self.DataName == 'POSTGRESQL':
                insert_id = cursor.fetchone()[0]

            if commit:
                self.commit()
            return insert_id
        except Exception as err:
            if commit:
                self.__rollback_quietly()
            raise BaseError(705, self.__error_detail(err))

    def update(self, table, data, params=None, join='AND', commit=True, lock=True):
        """Update the rows of ``table`` matching ``params`` with ``data``.

        :param data: ``{field: value}``; ``'{{expr}}'`` values are set as raw SQL.
        :param params: ``{field: value}`` conditions. Required: an update
            without conditions is refused.
        :param join: 'AND' or 'OR' between the conditions.
        :param commit: commit on success / roll back on failure.
        :param lock: first lock the matching rows with ``SELECT ... FOR UPDATE``.
        :return: number of affected rows.
        :raises BaseError: 705 on failure.
        """
        if not params:
            raise BaseError(705, 'update requires at least one condition')
        try:
            fields, values = self.__contact_fields(data)
            where, where_values = self.__contact_where(params, join)
            sql_where = ' WHERE ' + where if where else ''
            cursor = self.__getCursor()

            if commit:
                self.begin()

            if lock:
                # Row lock: only FOR UPDATE matters, the select list is irrelevant.
                lock_sql = 'SELECT 1 FROM %s%s FOR UPDATE' % (table, sql_where)
                cursor.execute(lock_sql % tuple(self.__literal(v) for v in where_values))

            literals = tuple(self.__literal(v) for v in values + where_values)
            sql_update = 'UPDATE %s SET %s ' % (table, fields) + sql_where
            self.__display_Debug_IO(sql_update, literals)  # DEBUG
            cursor.execute(sql_update % literals)

            if commit:
                self.commit()
            return cursor.rowcount
        except Exception as err:
            if commit:
                self.__rollback_quietly()
            raise BaseError(705, self.__error_detail(err))

    def updateByPk(self, table, data, id, pk='id', commit=True, lock=True):
        """Update the row whose ``pk`` column equals ``id``."""
        return self.update(table, data, {pk: id}, commit=commit, lock=lock)

    def delete(self, table, params=None, join='AND', commit=True, lock=True):
        """Delete the rows of ``table`` matching ``params``.

        :param params: ``{field: value}`` conditions. Required: deleting
            without conditions is refused.
        :param join: 'AND' or 'OR' between the conditions.
        :param commit: commit on success / roll back on failure.
        :param lock: unused; kept for backward compatibility.
        :return: number of deleted rows.
        :raises BaseError: 705 on failure.
        """
        if not params:
            raise BaseError(705, 'delete requires at least one condition')
        try:
            where, where_values = self.__contact_where(params, join)
            sql_where = ' WHERE ' + where if where else ''
            cursor = self.__getCursor()

            if commit:
                self.begin()

            literals = tuple(self.__literal(v) for v in where_values)
            sql_delete = 'DELETE FROM %s %s' % (table, sql_where)
            self.__display_Debug_IO(sql_delete, literals)  # DEBUG
            cursor.execute(sql_delete % literals)

            if commit:
                self.commit()
            return cursor.rowcount
        except Exception as err:
            if commit:
                self.__rollback_quietly()
            raise BaseError(705, self.__error_detail(err))

    def deleteByPk(self, table, id, pk='id', commit=True, lock=True):
        """Delete the row whose ``pk`` column equals ``id``."""
        return self.delete(table, {pk: id}, commit=commit, lock=lock)

    # Internal helpers -------------------------------------------------------

    def __display_Debug_IO(self, sql, params):
        """Print the statement (with ``params`` substituted) when DEBUG is on."""
        if not DEBUG:
            return
        debug_now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if params:
            try:
                text = sql % tuple(params)
            except (TypeError, ValueError):
                # Debug output must never break the actual query.
                text = '%s %r' % (sql, params)
        else:
            text = sql
        print('[S ' + debug_now_time + ' SQL:] ' + text)

    def __get_connection(self):
        """Borrow a raw connection from the backend-specific pool."""
        if self.DataName == 'ORACLE':
            return self.pool.acquire()
        if self.DataName == 'POSTGRESQL':
            return self.pool.getconn()
        return self.pool.get_connection()  # MySQL is the default

    def __getCursor(self):
        """Return the current cursor, creating one on the open connection if needed."""
        if self.__cursor is None:
            self.__cursor = self.__conn.cursor()
        return self.__cursor

    def getCursor(self):
        """Public alias of the internal cursor getter."""
        return self.__getCursor()

    def __execute(self, cursor, sql, values=None):
        """Execute ``sql`` whose plain values use ``%s`` placeholders.

        MySQL/PostgreSQL bind ``values`` through the driver. Oracle's driver
        does not use ``%s`` placeholders, so there the values are rendered
        inline as escaped literals.
        """
        values = tuple(values or ())
        if self.DataName == 'ORACLE':
            if values:
                sql = sql % tuple(self.__literal(v) for v in values)
            cursor.execute(sql)
        else:
            cursor.execute(sql, values)

    def __literal(self, value):
        """Render ``value`` as an inline SQL literal.

        ``None`` becomes NULL and numbers become their ``str()``. Anything
        else becomes a quoted string with single quotes doubled. For MySQL,
        backslashes are doubled as well, because MySQL treats them as escape
        characters by default.
        """
        import numbers

        if value is None:
            return 'NULL'
        if isinstance(value, numbers.Number):
            return str(value)
        text = str(value)
        if self.DataName == 'MYSQL':
            text = text.replace('\\', '\\\\')
        return "'" + text.replace("'", "''") + "'"

    @staticmethod
    def __column_names(cursor):
        """Column names of the last result, via the DB-API ``description``."""
        return [column[0] for column in (cursor.description or ())]

    @staticmethod
    def __error_detail(err):
        """Best available message for ``err``: MySQL's ``_full_msg`` if present, else ``args``."""
        return getattr(err, '_full_msg', None) or err.args

    def __rollback_quietly(self):
        """Best-effort rollback while another error is being reported."""
        try:
            self.__conn.rollback()
        except Exception as err:
            # The original error is what the caller needs; just note this one.
            logging.debug('rollback failed: %s', err)

    def __with_pk(self, criteria, id, pk):
        """Copy ``criteria`` and, if it has no 'where' and ``id`` is given, filter on ``pk``."""
        criteria = dict(criteria or {})
        if 'where' not in criteria and id is not None:
            criteria['where'] = pk + '=' + self.__literal(str(id))
        return criteria

    def __joinWhere(self, sql, params, join):
        """Append ``WHERE`` conditions built from ``params`` to ``sql``.

        Plain values become ``field=%s``; ``'{{expr}}'`` values become
        ``field=expr``. Conditions are joined with AND, or OR when
        ``join != 'AND'``.
        """
        if not params:
            return sql
        funParams, newParams = self.__split_expression(params)
        keys, placeholders = self.__tParams(newParams)
        conditions = [k + '=' + p for k, p in zip(keys, placeholders)]
        conditions += [k + '=' + v for k, v in funParams.items()]
        joiner = ' AND ' if join == 'AND' else ' OR '
        return sql + ' WHERE ' + joiner.join(conditions)

    def __tParams(self, params):
        """Return the column names (``{{name}}`` unwrapped) and one ``%s`` per column."""
        keys = [k[2:-2] if k[:2] == '{{' else k for k in params]
        return keys, ['%s'] * len(keys)

    def __query(self, table, criteria, all=False, isDict=True, join='AND'):
        """Run a SELECT built from ``criteria`` (see ``__contact_sql``).

        :param table: table name.
        :param criteria: query description dict. It is copied, never modified.
        :param all: return every row (True) or only the first (default).
        :param isDict: return rows as dicts (default) or as sequences.
        :param join: 'AND' or 'OR' between dict-style where conditions.
        :raises BaseError: 706 on any error.
        """
        try:
            criteria = dict(criteria or {})
            if all is not True:
                criteria['limit'] = 1  # only one row is needed
            sql, params = self.__contact_sql(table, criteria, join)
            cursor = self.__getCursor()
            self.__display_Debug_IO(sql, params)  # DEBUG
            self.__execute(cursor, sql, params)

            if all:
                rows = cursor.fetchall()
                if isDict:
                    columns = self.__column_names(cursor)
                    return [dict(zip(columns, row)) for row in rows]
                return list(rows)

            row = cursor.fetchone()
            if isDict:
                return dict(zip(self.__column_names(cursor), row)) if row else {}
            return row if row else []
        except Exception as err:
            raise BaseError(706, self.__error_detail(err))

    def __contact_sql(self, table, criteria, join='AND'):
        """Build a SELECT statement for ``table`` from ``criteria``.

        Recognised keys:

        * select: comma-separated fields (``{{expr}}`` unwrapped)
        * where: a raw string, or a ``{field: value}`` dict
        * group, having, order
        * limit: 'n' or 'offset,n', optionally prefixed with 'limit '
        * offset: MySQL only

        :return: ``(sql, where_values)``; ``where_values`` is None unless
            ``where`` was a dict.
        """
        where_values = None
        if not (criteria and isinstance(criteria, dict)):
            return 'SELECT  * FROM %s' % table, where_values

        if 'select' in criteria:
            fields = (f.strip() for f in criteria['select'].split(','))
            sql = 'SELECT ' + ','.join(
                f[2:-2] if f[:2] == '{{' and f[-2:] == '}}' else f for f in fields)
        else:
            sql = 'SELECT  * '
        sql += ' FROM %s' % table

        if 'where' in criteria:
            where = criteria['where']
            if isinstance(where, str):  # a single raw condition string
                sql += ' WHERE ' + where
            else:  # a dict of conditions
                condition, where_values = self.__contact_where(where, join)
                if condition:
                    sql += ' WHERE ' + condition

        if 'group' in criteria:
            sql += ' GROUP BY ' + criteria['group']
        if 'having' in criteria:
            sql += ' HAVING ' + criteria['having']

        if self.DataName == 'MYSQL':
            if 'order' in criteria:
                sql += ' ORDER BY ' + criteria['order']
            if 'limit' in criteria:
                sql += ' LIMIT ' + str(criteria['limit'])
            if 'offset' in criteria:
                sql += ' OFFSET ' + str(criteria['offset'])
        elif self.DataName == 'POSTGRESQL':
            if 'order' in criteria:
                sql += ' ORDER BY ' + criteria['order']
            if 'limit' in criteria:
                offset, rowcount = self.__parse_limit(criteria['limit'])
                sql += '  LIMIT %s OFFSET %s' % (rowcount, offset)
        elif self.DataName == 'ORACLE' and 'limit' in criteria:
            offset, rowcount = self.__parse_limit(criteria['limit'])
            order = criteria.get('order', 'ROWNUM')
            # Paging via row_number(), intended for large Oracle result sets.
            sql = ("select * from(select * from(select t.*,row_number() over(order by %s) "
                   "as rownumber from(%s) t) p where p.rownumber>%s) where rownum<=%s"
                   % (order, sql, offset, rowcount))
        elif self.DataName == 'ORACLE' and 'order' in criteria:
            sql += ' ORDER BY ' + criteria['order']

        return sql, where_values

    @staticmethod
    def __parse_limit(limit):
        """Split a MySQL-style limit ('n', 'offset,n' or 'limit offset,n') into (offset, rowcount)."""
        parts = [p.strip() for p in str(limit).split('limit ').pop().split(',')]
        if len(parts) > 1:
            return parts[0], parts[1]
        return '0', parts[0] or '1'

    def __split_expression(self, data):
        """Split ``data`` into (raw ``{{expr}}`` values unwrapped, plain values)."""
        funData, newData = {}, {}
        for k, v in data.items():
            if isinstance(v, str) and v[:2] == '{{' and v[-2:] == '}}':
                funData[k] = v[2:-2]
            else:
                newData[k] = v
        return funData, newData

    def __contact_fields(self, data):
        """Build the ``SET`` list of an UPDATE.

        :return: ``(fields, values)``. ``fields`` is ``'a=%s,b=expr'`` with
            plain fields first. ``values`` holds the plain values, in order.
        """
        funData, newData = self.__split_expression(data)
        assignments = [k + '=%s' for k in newData]
        assignments += [k + '=' + v for k, v in funData.items()]
        return ','.join(assignments), list(newData.values())

    def __hasKeyword(self, key):
        """True when an expression already carries its own operator (in/like/>/<)."""
        if '{{}}' in key: return True
        if 'in (' in key: return True
        if 'like ' in key: return True
        if '>' in key: return True
        if '<' in key: return True
        return False

    def __contact_where(self, params, join='AND'):
        """Build a WHERE condition (without the keyword) from ``params``.

        Plain values become ``field=%s``. ``'{{expr}}'`` values become
        ``field expr`` when expr has its own operator, ``field=expr``
        otherwise, or just ``expr`` when the field name is empty.

        :return: ``(condition, values)`` where ``values`` holds the plain
            values in placeholder order.
        """
        funParams, newParams = self.__split_expression(params)
        keys, placeholders = self.__tParams(newParams)
        conditions = [k + '=' + p for k, p in zip(keys, placeholders)]
        for k, v in funParams.items():
            separator = ' ' if self.__hasKeyword(v) else ('=' if k else '')
            conditions.append(k + separator + v)
        joiner = ' AND ' if join == 'AND' else ' OR '
        return joiner.join(c for c in conditions if c), list(newParams.values())

    def get_ids(self, list):
        """Join the ids in ``list`` into a comma-separated string.

        ``list`` is either a sequence of rows (e.g. from ``getAllToList``),
        in which case each row's first column is used, or a flat sequence of
        ids. Strings count as ids, not rows.
        """
        rows = list  # parameter name kept for backward compatibility

        def _is_row(item):
            if isinstance(item, (str, bytes)):
                return False
            try:
                item[0]
            except Exception:
                return False
            return True

        as_rows = len(rows) > 0 and _is_row(rows[0])
        return ','.join(str(r[0]) if as_rows else str(r) for r in rows)
Beispiel #20
0
class GPPool:
    """Small helper around a psycopg2 connection pool (Greenplum / PostgreSQL).

    Each ``*_sql`` helper borrows a connection, runs one statement, commits,
    and hands the connection back to the pool. If the statement fails, the
    transaction is rolled back before the connection is returned.
    """
    # Class-level placeholders; every instance overwrites them in __init__.
    dbname = 'localhost'
    user = '******'
    host = '127.0.0.1'
    password = '******'
    port = 5432
    gp_pool = None

    def __init__(self,
                 gp_host,
                 gp_port,
                 gp_dbname,
                 gp_user,
                 password,
                 minconn=1,
                 maxconn=5,
                 multithreading=True):
        """Create the pool.

        With ``multithreading`` (default) a ThreadedConnectionPool, which can
        be shared between threads, is used. Otherwise a SimpleConnectionPool,
        for single-threaded programs only, is used.
        """
        self.host = gp_host
        self.port = gp_port
        self.dbname = gp_dbname
        self.user = gp_user
        self.password = password
        pool_cls = ThreadedConnectionPool if multithreading else SimpleConnectionPool
        self.gp_pool = pool_cls(minconn,
                                maxconn,
                                host=gp_host,
                                port=gp_port,
                                dbname=gp_dbname,
                                user=gp_user,
                                password=password)

    def _run(self, sql, fetch=None):
        """Execute ``sql`` on a pooled connection and return ``fetch(cursor)``.

        The result is fetched before the connection goes back to the pool.
        The cursor is always closed and the connection always returned. On
        failure the transaction is rolled back and the error re-raised.
        """
        conn = self.gp_pool.getconn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(sql)
                result = fetch(cursor) if fetch is not None else None
            finally:
                cursor.close()
            conn.commit()  # commit after every statement
            return result
        except Exception:
            conn.rollback()
            raise
        finally:
            self.gp_pool.putconn(conn)  # always give the connection back

    def exe_conn(self, sql):
        """Execute and commit ``sql``, then return the (still open) cursor.

        The connection is already back in the pool when the cursor is
        returned. Prefer the ``fetch*_sql`` helpers, which read results before
        releasing it. The caller must close the returned cursor.
        """
        conn = self.gp_pool.getconn()
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            # cursor.mogrify(query) shows the generated SQL, useful for debugging.
            conn.commit()
            return cursor
        except Exception:
            if cursor is not None:
                cursor.close()
            conn.rollback()
            raise
        finally:
            self.gp_pool.putconn(conn)

    def fetchone_sql(self, sql):
        """Return the first row of ``sql``'s result (None if there is none)."""
        return self._run(sql, lambda cursor: cursor.fetchone())

    def fetchall_sql(self, sql):
        """Return every row of ``sql``'s result as a list of tuples."""
        return self._run(sql, lambda cursor: cursor.fetchall())

    def fetchmany_sql(self, sql, size=1):
        """Return at most ``size`` rows of ``sql``'s result."""
        return self._run(sql, lambda cursor: cursor.fetchmany(size))

    def exe_sql(self, sql):
        """Execute and commit ``sql`` without returning anything."""
        self._run(sql)

    def close_all(self):
        """Close every connection held by the pool."""
        self.gp_pool.closeall()
Beispiel #21
0
class PSQL(object):
    """Stateful PostgreSQL helper holding one connection and one cursor.

    Write helpers (``insert``/``update``/``delete``) only mark the helper as
    dirty; the pending changes are committed by ``close()``.
    """

    def __init__(self):
        self.pool = None
        self.conn = None
        self.cursor = None
        self.need_update = None

    def close(self):
        """Close the current cursor and commit if a write happened since the last commit."""
        if self.cursor is not None:
            self.cursor.close()
            self.cursor = None
            if self.need_update:
                self.conn.commit()
                self.need_update = None

    def rollback(self):
        """Discard uncommitted changes on the current connection."""
        self.conn.rollback()
        self.need_update = None

    def __getCursor(self):
        """Return the current cursor, creating one on ``self.conn`` if needed."""
        if self.cursor is None:
            self.cursor = self.conn.cursor()
        return self.cursor

    def get_conn(self, key=None):
        """Attach a connection with autocommit off.

        With ``key``, the connection is borrowed from a pool built lazily from
        ``config.PostgresqlDbConfig``. Without it, a standalone connection is
        opened from ``config_conn``.
        """
        if key:
            if not self.pool:
                self.pool = SimpleConnectionPool(**config.PostgresqlDbConfig)
            self.conn = self.pool.getconn(key)
        else:
            self.conn = connect(**config_conn)
        self.conn.autocommit = False

    def put_conn(self, key, close=False):
        """Return the keyed connection to the pool (closing it when ``close``)."""
        self.pool.putconn(self.conn, key, close)

    def insert(self, data=None, table=None):
        """Insert one row ``{column: value}`` into ``table``; committed by ``close()``."""
        data = data or {}
        columns = list(data)
        sql = 'INSERT INTO %s (%s) VALUES (%s)' % (
            table, ','.join(columns), ','.join(['%s'] * len(columns)))
        self.cursor = self.__getCursor()
        self.cursor.execute(sql, tuple(data[c] for c in columns))
        self.need_update = 1

    def findBySql(self, sql='', args=None):
        """Run ``sql`` with optional ``args`` and return all rows."""
        self.cursor = self.__getCursor()
        # None means "no parameters"; tuple(None) used to raise TypeError.
        self.cursor.execute(sql, tuple(args) if args is not None else None)
        return self.cursor.fetchall()

    def group_location(self, locate=None):
        """Build a WHERE clause from ``locate``.

        ``locate`` maps column -> value:

        * strings starting or ending with '%' become ``LIKE``
        * other strings and numbers become ``=``
        * lists/tuples become ``IN (...)``; an empty one matches nothing

        The optional ``'union'`` key ('AND' / 'OR') sets the connector
        between conditions (default AND).

        :return: ``(' WHERE ...', args_tuple)``
        :raises TypeError: for a value of any other type.
        :raises ValueError: when no condition is given.
        """
        cond = dict(locate or {})
        union = ' %s ' % cond.pop('union') if 'union' in cond else ' AND '
        clauses, args = [], []
        for k, v in cond.items():
            if isinstance(v, str):
                if v.startswith('%') or v.endswith('%'):
                    clauses.append('%s LIKE %%s' % k)
                else:
                    clauses.append('%s=%%s' % k)
                args.append(v)
            elif isinstance(v, (int, float)):
                clauses.append('%s=%%s' % k)
                args.append(v)
            elif isinstance(v, (list, tuple)):
                if v:
                    clauses.append('%s in (%s)' % (k, ','.join(['%s'] * len(v))))
                    args.extend(v)
                else:
                    clauses.append('FALSE')  # "x in ()" is invalid SQL
            else:
                raise TypeError('unsupported condition value for %r: %r' % (k, v))
        if not clauses:
            raise ValueError('locate contains no conditions')
        # Joining (instead of trimming a trailing connector) works for any union.
        return ' WHERE ' + union.join(clauses), tuple(args)

    def query(self, table=None, columns=None, locate=None, order_by='', limit=0):
        """Select rows from ``table``.

        :param columns: None (use ``config.column_<table>``), a list/tuple of
            column expressions, or a single expression string.
        :param locate: conditions, see ``group_location``.
        :param order_by: raw ORDER BY expression.
        :param limit: maximum number of rows (0 = no limit).
        :return: for a string ``columns`` the first cell (None if no row).
            Otherwise ``[]`` when nothing matched, a single dict when
            ``limit == 1``, else a list of dicts keyed by the column name
            (the alias after ``' as '`` when present).
        :raises TypeError: for an unsupported ``columns`` type.
        """
        if not columns:
            columns = getattr(config, 'column_%s' % table)
            select = ','.join(columns)
        elif isinstance(columns, str):
            select = columns
        elif isinstance(columns, (list, tuple)):
            select = ','.join(columns)
        else:
            raise TypeError('columns must be a str or a list, not %s'
                            % type(columns).__name__)
        sql = "SELECT %s FROM %s " % (select, table)

        args = ()
        if locate:
            sql_cond, args = self.group_location(locate)
            sql += sql_cond
        if order_by:
            sql += ' ORDER BY %s' % order_by
        if limit:
            sql += ' LIMIT %s' % limit

        self.cursor = self.__getCursor()
        self.cursor.execute(sql, args)
        res = self.cursor.fetchall()
        if isinstance(columns, str):
            return res[0][0] if res else None
        if not res:
            return res

        names = [column.split(' as ')[-1] for column in columns]
        if limit == 1:
            return dict(zip(names, res[0]))
        return [dict(zip(names, row)) for row in res]

    def update(self, table=None, data=None, locate=None):
        """Update ``table`` rows matching ``locate``; committed by ``close()``.

        A key ending in '+' or '-' increments/decrements the column, e.g.
        ``{'n+': 1}`` -> ``n=n+1``. Values are passed as bound parameters
        (as strings, None as NULL).

        :raises ValueError: when ``data`` is empty.
        """
        if not data:
            raise ValueError('update needs at least one column')
        assignments, args = [], []
        for k, v in data.items():
            if k[-1] in '+-':
                column = k[:-1]
                assignments.append('%s=%s%s%%s' % (column, column, k[-1]))
            else:
                assignments.append('%s=%%s' % k)
            args.append(None if v is None else str(v))
        sql = 'UPDATE %s SET ' % table + ','.join(assignments)
        if locate:
            sql_cond, where_args = self.group_location(locate)
            sql += sql_cond
            args.extend(where_args)

        self.cursor = self.__getCursor()
        self.cursor.execute(sql, tuple(args))
        self.need_update = 1

    def delete(self, table=None, locate=None, delete_all=False):
        """Delete ``table`` rows matching ``locate``; committed by ``close()``.

        :raises ValueError: when ``locate`` is empty and ``delete_all`` is
            False (guard against wiping the table by accident).
        """
        sql = 'DELETE FROM %s ' % table
        args = ()
        if locate:
            sql_cond, args = self.group_location(locate)
            sql += sql_cond
        elif not delete_all:
            raise ValueError("can't delete all")
        self.cursor = self.__getCursor()
        self.cursor.execute(sql, args)
        self.need_update = 1
Beispiel #22
0
    host = parsed.hostname
    port = parsed.port
    if port is None:
        port = '5432' # postgres default port
    dsn = "dbname={} host={} port={}".format(dbname, host, port)
    if user:
        dsn += ' username={}'.format(user)
    if password:
        dsn += ' password={}'.format(password)
    return dsn

if __name__ == "__main__":
    # Usage: python <script> [postgres://user:password@host:port/dbname]
    # Teach urlparse about postgres:// URLs so it splits user, password,
    # host and port out of the netloc.
    if 'postgres' not in urlparse.uses_netloc:
        urlparse.uses_netloc.append('postgres')
    if len(sys.argv) > 1:
        conn_string = url_to_dsn(sys.argv[1])
    else:
        conn_string = url_to_dsn("postgres://localhost:5432/test_erp")

    # Each question borrows one pooled connection and gives it back, even on
    # error; the pool itself is closed even if a question fails.
    pool = SimpleConnectionPool(1, 5, dsn=conn_string)
    module = sys.modules[__name__]
    try:
        for i in range(1, 6):
            question = getattr(module, 'question{}'.format(i))
            print("Question {}:\n\r{}".format(i, question.__doc__))
            conn = pool.getconn()
            try:
                print(question(conn))
            finally:
                pool.putconn(conn)
            print("=" * 20)
    finally:
        pool.closeall()
Beispiel #23
0
class PgDbModel:
    """Helper around a PostgreSQL connection pool (``DbPool``).

    Typical flow: ``conn = get_conn()``, ``cur = get_cur(conn)``, run one of the
    statement helpers, then ``conn_commit``/``conn_rollback`` and finally
    ``conn_close`` to hand the connection back. None of the statement helpers
    commits, rolls back or returns the connection on its own.
    """

    def __init__(self, conf: dict):
        """Create the connection pool.

        :param conf: keyword arguments forwarded as-is to ``DbPool``, e.g.::

            dict(
                database='postgres',    # database name
                user='******',          # user
                password='******',      # password
                host='127.0.0.1',       # host IP
                port='5432',            # port
                minconn=1,              # minimum number of connections
                maxconn=5               # maximum number of connections
            )
        """
        self.dpool = DbPool(**conf)

    def get_conn(self):
        """Borrow a connection from the pool."""
        return self.dpool.getconn()

    def get_cur(self, conn, dictcur=False):
        """Create a cursor on ``conn``.

        :param dictcur: when truthy, build the cursor with ``RealDictCursor``
            (presumably rows come back as dicts keyed by column name);
            otherwise the connection's default cursor is used.
        """
        if dictcur:
            return conn.cursor(cursor_factory=RealDictCursor)
        return conn.cursor()

    def conn_commit(self, conn):
        """Commit the current transaction on ``conn``."""
        return conn.commit()

    def conn_rollback(self, conn):
        """Roll back the current transaction on ``conn``."""
        return conn.rollback()

    def conn_close(self, conn):
        """Hand ``conn`` back to the pool via ``putconn``.

        Whether the underlying connection stays open is up to the pool.
        """
        return self.dpool.putconn(conn)

    def close_all(self):
        """Close the pool via ``closeall``."""
        return self.dpool.closeall()

    def fetch_all(self, cur):
        """Return all remaining rows of the last query executed on ``cur``."""
        return cur.fetchall()

    def row_count(self, cur):
        """Return ``cur.rowcount`` for the last statement executed on ``cur``."""
        return cur.rowcount

    def executemany(self, cur, sql, data_list):
        """Execute ``sql`` once per parameter set in ``data_list``.

        NOTE: only a ``list`` is executed; any other value (tuple, generator,
        None) is silently ignored, so a following ``row_count`` reports the
        previous statement's count.
        """
        if isinstance(data_list, list):
            cur.executemany(sql, data_list)

    def execute(self, cur, sql, data):
        """Execute one statement, binding ``data`` only when it is not None."""
        if data is not None:
            cur.execute(sql, data)
        else:
            cur.execute(sql)

    # Query, wrapped by the slow-query logging decorator.
    @slow_log(
        'query', configs.settings.SLOW_LOGGER_QUERY_THRESHOLD,
        slow_off=configs.settings.SLOW_LOGGER_QUERY_OFF,
        log_title=configs.settings.SLOW_LOGGER_QUERY_TITLE,
        timezone=configs.settings.SLOW_LOGGER_QUERY_TIMEZONE
    )
    def select(self, cur, sql, data=None):
        """Run a query and return all of its rows."""
        self.execute(cur, sql, data)
        return self.fetch_all(cur)

    def insert_all(self, cur, sql, data_list):
        """Batch insert; returns ``cur.rowcount`` afterwards."""
        self.executemany(cur, sql, data_list)
        return self.row_count(cur)

    def insert_one(self, cur, sql, data):
        """Single-row insert; returns ``cur.rowcount`` afterwards."""
        self.execute(cur, sql, data)
        return self.row_count(cur)

    def update_all(self, cur, sql, data_list):
        """Batch update; returns ``cur.rowcount`` afterwards."""
        self.executemany(cur, sql, data_list)
        return self.row_count(cur)

    def update_one(self, cur, sql, data):
        """Single update; returns ``cur.rowcount`` afterwards."""
        self.execute(cur, sql, data)
        return self.row_count(cur)

    def get_query(self, cur):
        """Return the last SQL text sent through ``cur`` as ``str``.

        Returns None when nothing has been executed on ``cur`` yet
        (``cur.query`` is None then), instead of raising AttributeError.
        """
        query = cur.query
        return query.decode() if query is not None else None

    def check_sql(self, cur, sql, data):
        """Return ``sql`` with ``data`` bound (via ``mogrify``) without executing it."""
        return cur.mogrify(sql, data).decode()
Beispiel #24
0
class DBUtil:
    """Pooled PostgreSQL helper configured from an INI file.

    The pool is created lazily: the first borrow reads section ``db_name`` of
    ``config_file`` and builds a ``SimpleConnectionPool(1, 3, ...)``. Every helper
    below returns its connection to the pool, even when the statement fails
    (except the server-side-cursor mode of :meth:`get_df_from_query`).
    """

    def __init__(self, db_name, config_file):
        """
        :param db_name: INI section holding the connection settings.
        :param config_file: path of the INI file read with ``configparser``.
        """
        self.db_name = db_name
        self.config_file = config_file

    def connect_to_db(self):
        """(Re)create ``self.conn_pool`` from the ``db_name`` section of the config.

        The section must define ``password``, ``user``, ``database``, ``host`` and
        ``port``; a missing section/option raises the matching ``configparser``
        error.
        """
        section = self.db_name
        cp = configparser.ConfigParser()
        cp.read(self.config_file)
        password = cp.get(section, "password")
        user = cp.get(section, "user")
        database = cp.get(section, "database")
        host = cp.get(section, "host")
        port = cp.get(section, "port")

        kwargs = {"host": host, "password": password,
                  "user": user, "dbname": database, "port": port}
        self.conn_pool = SimpleConnectionPool(1, 3, **kwargs)

    def get_conn(self):
        """Borrow a connection, building the pool first if needed.

        If borrowing fails (typically AttributeError because the pool does not
        exist yet), the pool is rebuilt and the borrow retried once.
        NOTE(review): any other failure (e.g. an exhausted pool) also rebuilds
        the pool, orphaning the old one; consider narrowing this.
        """
        try:
            return self.conn_pool.getconn()
        except Exception:
            self.connect_to_db()
            return self.conn_pool.getconn()

    def _fetch_all(self, query, params=None):
        """Run ``query`` on a pooled connection and return ``(rows, column_names)``.

        The connection always goes back to the pool, even if the query fails.
        """
        conn = self.get_conn()
        try:
            with conn.cursor() as cur:
                cur.execute(query, params)
                data = cur.fetchall()
                columns = [desc[0] for desc in cur.description]
        finally:
            self.conn_pool.putconn(conn)
        return data, columns

    def get_df_from_query(self, query, params=None, pprint=False, to_df=True, server_cur=False, itersize=20000):
        """Run a read query and return its result.

        :param query: SQL text, with driver placeholders for ``params``.
        :param params: parameters bound to ``query``.
        :param pprint: print the query first; uses ``self.format_sql`` when a
            subclass provides it, otherwise prints the raw query.
        :param to_df: return a ``pandas.DataFrame`` instead of ``(rows, columns)``.
        :param server_cur: return an executed named (server-side) cursor that
            fetches ``itersize`` rows per round trip. Its connection stays
            borrowed; the caller must close the cursor and hand
            ``cur.connection`` back with ``self.conn_pool.putconn``.
        :param itersize: batch size for the server-side cursor.
        """
        if pprint:
            formatter = getattr(self, "format_sql", None)
            print(formatter(query) if formatter is not None else query)

        if server_cur:
            conn = self.get_conn()
            cur = conn.cursor('server_side_cursor')
            cur.itersize = itersize
            try:
                cur.execute(query, params)
            except Exception:
                self.conn_pool.putconn(conn)
                raise
            return cur

        data, columns = self._fetch_all(query, params)
        if to_df:
            return pd.DataFrame(data, columns=columns)
        return data, columns

    def get_arr_from_query(self, query, params=None):
        """Return the query result as a list: the column names first, then one list per row."""
        data, columns = self._fetch_all(query, params)
        return [columns] + [list(row) for row in data]

    def update_db(self, query, params=None):
        """Execute a write statement and commit it; returns 0 on success.

        The statement error is printed and re-raised; the connection is handed
        back to the pool in every case.
        """
        conn = self.get_conn()
        try:
            with conn.cursor() as cur:
                try:
                    cur.execute(query, params)
                except Exception as e:
                    print(e)
                    raise
            conn.commit()
        finally:
            self.conn_pool.putconn(conn)
        return 0

    def write_df_to_table(self, df, tablename):
        """Write every row of ``df`` to ``tablename``, recreating the table.

        Values are converted with ``str`` and cut to 255 characters to fit the
        ``varchar(256)`` columns created by :meth:`write_arr_to_table`.
        """
        rows = [[str(value)[:255] for value in row.tolist()] for _, row in df.iterrows()]
        self.write_arr_to_table(rows, tablename, df.columns)

    def write_arr_to_table(self, arr, tablename, columns, new_table=True):
        """Insert the rows of ``arr`` into ``tablename`` and commit; returns 0.

        :param arr: sequence of rows, each with one value per column.
        :param columns: column names, in row order.
        :param new_table: when truthy, drop and recreate the table with every
            column typed ``varchar(256)`` first.

        ``tablename`` and ``columns`` are interpolated directly into the SQL, so
        they must be trusted identifiers, never user input.
        """
        column_str = "({0})".format(",".join(columns))
        column_def = "({0} varchar(256) )".format(" varchar(256),".join(columns))
        value_str = "({0})".format(",".join(["%s"] * len(columns)))

        sql = "insert into {0} {1} values {2};".format(tablename, column_str, value_str)

        # Debug echo of the statement and its first row.
        try:
            print(sql, arr[0])
        except IndexError as e:
            print(e, len(arr))

        conn = self.get_conn()
        try:
            with conn.cursor() as cur:
                if new_table:
                    cur.execute("DROP TABLE IF EXISTS {0}".format(tablename))
                    cur.execute("CREATE TABLE {0} {1}".format(tablename, column_def))
                try:
                    for row in arr:
                        cur.execute(sql, row)
                except Exception as e:
                    print(e)
                    raise
            conn.commit()
        finally:
            self.conn_pool.putconn(conn)
        return 0