Example #1
def move_data(dbset):
    try:
        db = DB(**dbset)
        db.begin()
        query = db.query("SELECT bluetooth.move_raw_data();")
        if query.getresult()[0][0] != 1:
            db.rollback()
            raise DatabaseError(
                'bluetooth.move_raw_data did not complete successfully')
        query = db.query("TRUNCATE bluetooth.raw_data;")
        query = db.query("SELECT king_pilot.load_bt_data();")
        if query.getresult()[0][0] != 1:
            db.rollback()
            raise DatabaseError(
                'king_pilot.load_bt_data did not complete successfully')
        db.query(
            'DELETE FROM king_pilot.daily_raw_bt WHERE measured_timestamp < now()::DATE;'
        )
        db.commit()
    except DatabaseError as dberr:
        LOGGER.error(dberr)
        db.rollback()
    except IntegrityError:
        LOGGER.critical(
            'Moving data failed due to violation of a constraint. Data will have to be moved manually'
        )
    finally:
        db.close()
Example #2
def main(**kwargs):
    CONFIG = configparser.ConfigParser()
    CONFIG.read('db.cfg')
    dbset = CONFIG['DBSETTINGS']

    logger.info('Connecting to Database')
    db = DB(dbname=dbset['database'],
            host=dbset['host'],
            user=dbset['user'],
            passwd=dbset['password'])
    proxies = {'https': kwargs.get('proxy', None)}

    # Update Venue List
    venues = []
    curId = db.query('SELECT max(id) FROM city.venues').getresult()[0][0]

    logger.info('Updating venues table')
    venues, inserted_venues = update_venues(db, proxies, curId)

    # Get Events from List of Venues
    #cla = []
    logger.info('Finished updating venues tables, %s new venues inserted',
                inserted_venues)

    inserted_count = update_events(db, proxies, venues)
    logger.info('Finished processing events, %s events inserted',
                inserted_count)
    db.close()
Example #3
def insert_data(data: list, dbset: dict, live: bool):
    '''
    Upload data to the database.

    :param data:
        List of dictionaries, gets converted to a list of tuples
    :param dbset:
        DB settings passed to PyGreSQL to create a connection
    :param live:
        If True, insert into king_pilot.daily_raw_bt; otherwise into bluetooth.raw_data
    '''
    num_rows = len(data)
    if num_rows > 0:
        LOGGER.info('Uploading %s rows to PostgreSQL', len(data))
        LOGGER.debug(data[0])
    else:
        LOGGER.warning('No data to upload')
        return
    to_insert = []
    for dic in data:
        # convert each observation dictionary into a tuple row for inserting
        row = (dic["userId"], dic["analysisId"], dic["measuredTime"],
               dic["measuredTimeNoFilter"], dic["startPointNumber"],
               dic["startPointName"], dic["endPointNumber"],
               dic["endPointName"], dic["measuredTimeTimestamp"],
               dic["outlierLevel"], dic["cod"], dic["deviceClass"])
        to_insert.append(row)

    db = DB(**dbset)
    if live:
        db.inserttable('king_pilot.daily_raw_bt', to_insert)
    else:
        db.inserttable('bluetooth.raw_data', to_insert)
    db.close()
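
A minimal usage sketch for insert_data, assuming a db.cfg file whose [DBSETTINGS] keys match pg.DB's keyword arguments (dbname, host, user, passwd) as in the other examples; the sample observation dictionary is hypothetical.

import configparser

config = configparser.ConfigParser()
config.read('db.cfg')
dbset = dict(config['DBSETTINGS'])

# One hypothetical Bluetooth observation; keys mirror the ones insert_data expects.
sample = [{
    "userId": 1, "analysisId": 42, "measuredTime": 55, "measuredTimeNoFilter": 57,
    "startPointNumber": 10, "startPointName": "A", "endPointNumber": 11,
    "endPointName": "B", "measuredTimeTimestamp": "2020-01-01 00:00:00",
    "outlierLevel": 0, "cod": 0, "deviceClass": 1,
}]

# live=False routes the rows to bluetooth.raw_data instead of king_pilot.daily_raw_bt.
insert_data(sample, dbset, live=False)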
Example #4
def run_analyze(table):
    global counter
    db = DB(dbname = vDatabase, host = vHost)
    db.query('analyze %s' %table)
    with counter.get_lock():
        counter.value += 1
    if counter.value % 10 == 0 or counter.value == total_tables:
        logging.info(str(counter.value) + " tables completed out of " + str(total_tables) + " tables")
    db.close()
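
run_analyze relies on module-level globals (vDatabase, vHost, counter, total_tables) that the calling script is expected to set up; counter.get_lock() implies a multiprocessing.Value shared across worker processes. A minimal sketch of that setup, assuming a fork-based start method so the globals are inherited by the pool workers; the names and connection settings are hypothetical.

import multiprocessing

vDatabase, vHost = 'mydb', 'localhost'    # assumed connection settings
counter = multiprocessing.Value('i', 0)   # shared progress counter read by run_analyze
tables = get_tables()                     # e.g. the helper shown in Example #8
total_tables = len(tables)

if __name__ == '__main__':
    # each worker opens its own pg.DB connection inside run_analyze
    with multiprocessing.Pool(processes=4) as pool:
        pool.map(run_analyze, tables)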
Example #5
def get_csv_data():
    """ Connect to the PostgreSQL database server """
    conn = None
    try:
        # read connection parameters
        params = config()

        # connect to the PostgreSQL server
        print('Connecting to the PostgreSQL database...')
        conn = DB(**params)

        # execute a statement
        print('PostgreSQL ALL accepted users who signed up for team matching')
        q = conn.query(
            "SELECT * FROM users JOIN event_applications ON users.id = event_applications.user_id WHERE custom_fields ->> 'team_forming' = 'Yes, sign me up!' AND status = 'accepted';"
        )

        data = q.dictresult()

        f = StringIO()
        writer = csv.writer(f,
                            delimiter=',',
                            quotechar='"',
                            quoting=csv.QUOTE_MINIMAL)
        writer.writerow(features)

        print(f'Adding {len(data)} entries to csv file')

        for row in data:
            str_id = str(row['user_id'])
            # get the MD5 hash of id
            result = hashlib.md5(str_id.encode())
            hashed_id = result.hexdigest()

            full_duration = (row['custom_fields'].get(
                'arrival_time', '') == "I'm staying for the entire event")
            user_features = [hashed_id, row['first_name'], row['last_name'], row['email'], row['phone'], full_duration, \
                           row['age'], row['pronoun'], row['university'], row['education_lvl'], row['major'], \
                           row['grad_year'], row['custom_fields'].get("travel", None), row['custom_fields'].get("programming_skills", None), \
                           row['custom_fields'].get("been_to_ttb", None), None, None, row['custom_fields'].get("linkedin_url", None), \
                           row['custom_fields'].get("github_url", None), row['custom_fields'].get("other_url", None), row['custom_fields'].get("how_did_you_hear", None), \
                           None, row['custom_fields'].get("programming_experience", None), row['custom_fields'].get("how_many_hackathons", None), None, \
                           row['custom_fields'].get("other_skills", None), row['custom_fields'].get("particular_topic", None), row['custom_fields'].get("goals", None), \
                           row['custom_fields'].get("experience_area", None), row['custom_fields'].get("teammate_preference", None), None]
            writer.writerow(user_features)
        # move the pointer back to beginning of file
        f.seek(0)
        return f
    except Exception as error:
        print("Error:", error)
    finally:
        if conn is not None:
            conn.close()
            print('Database connection closed.')
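
A hypothetical caller for get_csv_data: the function returns an in-memory StringIO (or None if the query failed), so persisting it is just a matter of writing its contents out; the output filename is an assumption.

csv_buffer = get_csv_data()
if csv_buffer is not None:
    with open('team_matching_users.csv', 'w', newline='') as out:
        out.write(csv_buffer.getvalue())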
Example #6
def update_configs(all_analyses, dbset):
    '''
    Syncs configs from blip server with database and returns configs to pull 
    data from. 
    :param all_analyses:
        List of blip configurations
    :param dbset:
        Dictionary to connect to PostgreSQL database
    '''

    db = DB(**dbset)
    db.begin()
    db.query('''TRUNCATE bluetooth.all_analyses_day_old;
    INSERT INTO bluetooth.all_analyses_day_old SELECT * FROM bluetooth.all_analyses;'''
             )
    db.commit()
    analyses_pull_data = {}
    for report in all_analyses:
        report.outcomes = [outcome.__json__() for outcome in report.outcomes]
        report.routePoints = [
            route_point.__json__() for route_point in report.routePoints
        ]
        row = dict(device_class_set_name=report.deviceClassSetName,
                   analysis_id=report.id,
                   minimum_point_completed=db.encode_json(
                       report.minimumPointCompleted.__json__()),
                   outcomes=report.outcomes,
                   report_id=report.reportId,
                   report_name=report.reportName,
                   route_id=report.routeId,
                   route_name=report.routeName,
                   route_points=report.routePoints)
        #If upsert fails, log error and continue, don't add analysis to analyses to pull
        try:
            upserted = db.upsert('bluetooth.all_analyses',
                                 row,
                                 pull_data='included.pull_data')
            analyses_pull_data[upserted['analysis_id']] = {
                'pull_data': upserted['pull_data'],
                'report_name': upserted['report_name']
            }
        except IntegrityError as err:
            LOGGER.error(err)

    db.close()

    analyses_to_pull = {
        analysis_id: analysis
        for (analysis_id, analysis) in analyses_pull_data.items()
        if analysis['pull_data']
    }
    return analyses_to_pull
Example #7
def read_line():
    i = 0
    reader = csv.reader(f, delimiter='\t')
    db = DB(dbname="ngram", user="******", port=5432)
    for row in reader:
        # ngram, year, match_count, page_count, volume_count
        sql = (sqlp1 + " " + tb_name + " " + sqlp2 + " (" + str(i) + ", "
               + "'" + row[0] + "'" + ", " + row[1] + ", " + row[2] + ", " + row[3]
               + ", " + row[4] + ");")
        print(sql)
        i = i + 1
        # call insert(sql)
        db.query(sql)
    db.close()
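
A hedged alternative sketch: instead of concatenating an INSERT statement per row, the same tab-separated file could be bulk-loaded with pg.DB.inserttable (as Example #3 does); the table name, column layout and file path are assumptions.

import csv
from pg import DB

def load_ngrams(path, tb_name='ngram_counts'):
    # Read the whole TSV and hand it to inserttable in one call.
    db = DB(dbname="ngram", port=5432)
    with open(path) as tsv:
        reader = csv.reader(tsv, delimiter='\t')
        rows = [(i, r[0], int(r[1]), int(r[2]), int(r[3]), int(r[4]))
                for i, r in enumerate(reader)]
    db.inserttable(tb_name, rows)
    db.close()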
Example #8
def get_tables():
    db = DB(dbname = vDatabase, host = vHost)
    table_list = []
    if options.usertables:
        table_list = db.get_tables()
    else:
        table_list = db.get_tables('system')
    db.close()

    if vSchema:
        tables = []
        regex = "^" + vSchema + "\."
        for table in table_list:
            if re.match(regex, table, re.I):
                tables.append(table)
    else:
        tables = table_list
    return tables
Example #9
def main(**kwargs):

    CONFIG = configparser.ConfigParser()
    CONFIG.read('db.cfg')
    dbset = CONFIG['DBSETTINGS']

    logger.info('Connecting to Database')
    db = DB(dbname=dbset['database'],
            host=dbset['host'],
            user=dbset['user'],
            passwd=dbset['password'])

    proxies = {'http': kwargs.get('proxy', None)}

    logger.info('Requesting data')
    r = requests.get(
        'http://app.toronto.ca/cc_sr_v1_app/data/edc_eventcal_APR',
        proxies=proxies)

    events = r.json()

    global CURID, ODID
    CURID = db.query('SELECT max(id) FROM city.venues').getresult()[0][0]
    ODID = db.query('SELECT max(od_id) FROM city.od_venues').getresult()[0][0]

    logger.info('Processing events')
    inserted_events, inserted_venues, updated_venues = 0, 0, 0

    for i, entry0 in enumerate(events):
        try:
            inserted_venue, updated_venue = process_event(
                i, entry0['calEvent'], db)
            inserted_events += 1
            inserted_venues += inserted_venue
            updated_venues += updated_venue
        except KeyError as key_error:
            logger.error('Key error with event: %s, key %s, skipping',
                         entry0['calEvent'].get('eventName',
                                                ''), key_error.args[0])

    logger.info('%s events processed, %s venues inserted, %s venues updated',
                inserted_events, inserted_venues, updated_venues)
    logger.info('closing connection to DB')
    db.close()
Example #10
File: util.py  Project: xunzhang/happyday
class PGDB(object):
    def __init__(self):
        pass

    def connect(self, dbname, host, port, user):
        self.db = DB(dbname=dbname, host=host, port=port, user=user)

    def connect_default(self):
        self.db = DB(dbname=DBNAME, host=HOST, port=PORT, user=USER)

    def close(self):
        self.db.close()
    
    def execute(self, sql):
        return self.db.query(sql)
    
    def drop_table(self, tbl):
        self.execute('DROP TABLE if exists %s' % tbl)

    def create_init_table(self, tbl):
        self.execute('CREATE TABLE %s (uid INT, iid INT, rating REAL)' % tbl)

    def copy(self, tbl, path, delimiter):
        self.execute("COPY %s FROM '%s' delimiter '%s'" % (tbl, path, delimiter))
Example #11
# Getting command line options

if options.database:
    vDatabase = options.database
else:
    logging.error("database not supplied... exiting...")
    sys.exit()

con = DB(dbname='template1', host=options.host)
if vDatabase not in con.get_databases():
    logging.error("Database doesn't exist... exiting")
    sys.exit()
con.close()

vProcesses = int(options.parallel)
vHost = options.host
vSchema = options.schema

# Function to get the list of tables

def get_tables():
    db = DB(dbname = vDatabase, host = vHost)
    table_list = []
    if options.usertables:
        table_list = db.get_tables()
    else:
        table_list = db.get_tables('system')
    db.close()
Example #12
class PostGreDBConnector:
    """PostGreDBConnector opens a PostGre DB connection. Different functions allow you to add, delete or update
    documents in PostGre DB."""

    def __init__(self):
        """Connecting to localhost (default host and port [localhost, 4532]) PostGre DB and initializing needed data
            base and tables."""
        try:
            print("Connecting to PostGre DB...")
            self.__db = DB(dbname='testdb', host='localhost', port=5432, user='******', passwd='superuser')
            print("PostGre DB connection successfully built.")
        except ConnectionError:
            print("PostGre DB connection could not be built.")

        self.delete_all_data()
        self.drop_all_tables()

    def close_connection(self):
        self.__db.close()

    def create_schema(self, schema_name):
        self.__db.query("CREATE SCHEMA " + schema_name)
        self.__create_tables(schema_name)
        self.__create_functions(schema_name)

    def __create_tables(self, schema):
        """Create needed tables for RDF parsing."""
        schema += "."
        self._add_table("CREATE TABLE " + schema + "texts (id serial primary key, title text)")
        self._add_table(
            "CREATE TABLE " + schema + "bscale (id serial primary key, bscale text, nominal bool, ordinal bool, interval bool)")
        self._add_table("CREATE TABLE " + schema + "bsort (id serial primary key, bsort text)")
        self._add_table("CREATE TABLE " + schema + "pattern (id serial primary key, pattern text)")
        self._add_table("CREATE TABLE " + schema + "single_pattern (id serial primary key, single_pattern text)")
        self._add_table("CREATE TABLE " + schema + "snippets (id serial primary key, snippet text)")

        # relations
        self._add_table("CREATE TABLE " + schema + "has_attribute (bsort_id int, bscale_id integer[], aggregation int)")
        self._add_table("CREATE TABLE " + schema + "has_object (bscale_id int, pattern_id integer[], aggregation int)")
        self._add_table(
            "CREATE TABLE " + schema + "pattern_single_pattern (pattern_id int, single_pattern_id integer[], aggregation int)")
        self._add_table("CREATE TABLE " + schema + "texts_snippets (text_id int primary key, snippet_id integer[], aggregation int)")
        self._add_table(
            "CREATE TABLE " + schema + "snippet_offsets (id serial primary key,"
            " single_pattern_id int, snippet_id int, offsets integer[][], aggregation int)")

        # adjective and verb extractions
        self._add_table("CREATE TABLE " + schema + "subject_occ (id serial primary key, subject text, count int)")
        self._add_table("CREATE TABLE " + schema + "adjective_occ (id serial primary key, adjective text, count int)")
        self._add_table("CREATE TABLE " + schema + "verb_occ (id serial primary key, verb text, count int)")
        self._add_table("CREATE TABLE " + schema + "object_occ (id serial primary key, object text, count int)")
        self._add_table("CREATE TABLE " + schema + "subject_adjective_occ (id serial primary key, subject int, adjective int, count int, pmi float)")
        self._add_table("CREATE TABLE " + schema + "subject_object_occ (id serial primary key, subject int, object int, count int, pmi float)")
        self._add_table("CREATE TABLE " + schema + "object_verb_occ (id serial primary key, object int, verb int, count int, pmi float)")
        self._add_table("CREATE TABLE " + schema + "subject_verb_occ (id serial primary key, subject int, verb int, count int, pmi float)")

        # correlating pattern
        self._add_table("CREATE TABLE " + schema + "bscale_single_pattern (id serial primary key, bscale_id int, single_pattern_id int, single_pattern text, count int)")
        self._add_table(
            "CREATE TABLE " + schema + "correlating_pattern (id serial primary key, pattern_a int, pattern_b int, count int, pmi float)")

    def __create_functions(self, schema):
        """Create all necessary functions to aggregate the results saved in the database."""
        schema += "."
        self.add_function(schema + "aggregate_texts_snippets", "SELECT text_id, array_length(snippet_id, 1) FROM " + schema + "texts_snippets")
        self.add_function(schema + "aggregate_snippet_offsets", "SELECT id, array_length(offsets, 1) FROM " + schema + "snippet_offsets")

    def add_function(self, name, function):
        """Create a new function in the db."""
        create_function = "CREATE FUNCTION "
        returns = "() RETURNS SETOF RECORD AS "
        lang = " LANGUAGE SQL"
        query = create_function + name + returns + add_quotes(function) + lang
        self.__db.query(query)

    def _add_table(self, query):
        """Create a new table with a query."""
        self.__db.query(query)

    def add_table(self, schema, name, rows):
        """Create a new table with a name and rows given in query form."""
        create_table = "CREATE TABLE "
        query = create_table + schema + "." + name + rows
        self.__db.query(query)

    def insert(self, schema, table, row):
        """Insert a new row element into a specified table."""
        return self.__db.insert(schema + "." + table, row)

    def is_in_table(self, schema, table, where_clause):
        """Returns whether a row already exists in a table or not."""
        select = "SELECT * FROM "
        where = " WHERE "
        q = select + schema + "." + table + where + where_clause
        result = self.__db.query(q).dictresult()
        if len(result) > 0:
            return True
        else:
            return False

    def update(self, schema, table, values, where_clause):
        """Update an entry in a specified table."""
        UPDATE = "UPDATE "
        SET = " SET "
        WHERE = " WHERE "
        query = UPDATE + schema + "." + table + SET + values + WHERE + where_clause
        self.query(query)

    def get(self, schema, table, where_clause, key):
        """Return the key of a specific item in a table."""
        select = "SELECT "
        _from = " FROM "
        where = " WHERE "
        q = select + key + _from + schema + "." + table + where + where_clause
        result = self.__db.query(q).dictresult()
        if len(result) > 0:
            return result[0][key]
        else:
            return None

    def get_data_from_table(self, schema, table):
        """Gets all data available in a specific table."""
        return self.__db.query("SELECT * FROM " + schema + "." + table).dictresult()

    def get_id(self, schema, table, where_clause):
        """Return the id of an item in a table. If found return id number of found item, else None."""
        select = "SELECT id FROM "
        where = " WHERE "
        q = select + schema + "." + table + where + where_clause
        result = self.__db.query(q).dictresult()
        if len(result) > 0:
            return result[0]['id']
        else:
            return None

    def delete_from_table(self, schema, table, row):
        """Delete a row element form a specific table."""
        return self.__db.delete(schema + "." + table, row)

    def delete_data_in_table(self, schema, table):
        """Delete all data in a specific table."""
        self.__db.truncate(schema + "." + table, restart=True, cascade=True, only=False)

    def delete_all_data(self):
        """Deletes all data from all existing tables."""
        tables = self.get_tables()
        for table in tables:
            table_name = str(table)
            self.__db.truncate(table_name, restart=True, cascade=True, only=False)

    def get_tables(self):
        """Get all available tables in the database."""
        return self.__db.get_tables()

    def get_attributes(self, schema, table):
        """Get all attributes of a specified table."""
        return self.__db.get_attnames(schema + "." + table)

    def drop_table(self, schema, table):
        """Drops a specified table."""
        query = "DROP TABLE "
        self.__db.query(query + schema + "." + table)

    def drop_all_tables(self):
        """Drops all existing tables."""
        tables = self.get_tables()
        table_names = ""
        if len(tables) > 0 :
            for ind, table in enumerate(tables):
                if ind == 0:
                    table_names = str(table)
                else:
                    table_names = table_names + ", " + str(table)
            self.__db.query("DROP TABLE " + table_names)
        else:
            print("Nothing to delete.")

    def get_all(self, schema, table, attribute):
        """Gets one or more attributes of all entries from a specified table."""
        select = "SELECT "
        _from = " FROM "
        query = select + attribute + _from + schema + "." + table
        return self.__db.query(query).dictresult()

    def query(self, query):
        """Sends a query to the database."""
        result = self.__db.query(query)
        if result is not None:
            if not isinstance(result, str):
                return result.dictresult()
        else:
            return result
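
A hypothetical walk-through of the connector; the schema name and row contents are assumptions, and note that the constructor above wipes any existing tables before you start.

connector = PostGreDBConnector()
connector.create_schema('doc1')
connector.insert('doc1', 'texts', {'title': 'first document'})
text_id = connector.get_id('doc1', 'texts', "title = 'first document'")
print(text_id, connector.get_data_from_table('doc1', 'texts'))
connector.close_connection()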
Example #13
class BotDB:
    def __init__(self, db_url):
        urlparse.uses_netloc.append("postgres")
        self.__db_url = db_url
        url = urlparse.urlparse(db_url)
        self.__db = DB(
            dbname=url.path[1:],
            user=url.username,
            passwd=url.password,
            host=url.hostname,
            port=url.port
        )

    def insertThesis(self, init_id, chat_id, user_id, body):
        ts = time.time()
        timestamp = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
        print("inserting thesis")
        print(init_id, chat_id, user_id, body, timestamp)
        self.__db.insert("theses", row={"init_id": init_id, "chat_id": chat_id, "user_id": user_id, "body": body})
        print("done")
        self.__db.commit()

    def getThesisByIds(self, init_id, chat_id):
        query = self.__db.query("SELECT * FROM theses WHERE init_id = %d AND chat_id = %d;" % (init_id, chat_id))
        dict_res = query.dictresult()
        if len(dict_res) == 0:
            return False
        else:
            return dict_res[0]

    def getThesisByBody(self, body):
        query = self.__db.query("SELECT * FROM theses WHERE body = '%s';" % body)
        dict_res = query.dictresult()
        if len(dict_res) == 0:
            return False
        else:
            return dict_res[0]

    def getLastThesesByTime(self, chat_id, interval):
        query = self.__db.query("SELECT * FROM theses WHERE chat_id = %s AND creation_time > current_timestamp - interval '%s';" % (chat_id, interval))
        dict_res = query.dictresult()
        if len(dict_res) == 0:
            return False
        else:
            return dict_res

    def getTodayTheses(self, chat_id):
        query = self.__db.query("SELECT * FROM theses WHERE chat_id = %s AND creation_time > current_date;" % chat_id)
        dict_res = query.dictresult()
        if len(dict_res) == 0:
            return False
        else:
            return dict_res

    def insertUser(self, user_id, username, first_name, last_name):
        # ts = time.time()
        row = {"user_id":user_id}
        if username:
            row["username"] = username
        if first_name:
            row["first_name"] = first_name
        if last_name:
            row["last_name"] = last_name
        self.__db.insert('users', row=row)
        self.__db.commit()

    def getUserById(self, user_id):
        query = self.__db.query("SELECT * FROM users WHERE user_id = %d;" % user_id)
        dict_res = query.dictresult()
        if len(dict_res) == 0:
            return False
        else:
            return dict_res[0]

    def insertBotMessage(self, chat_id, message_id, owner_id):
        row = {"chat_id": chat_id, "message_id": message_id, "owner_id": owner_id}
        self.__db.insert('bot_messages', row=row)
        self.__db.commit()

    def getBotMessage(self, chat_id, message_id):
        query = self.__db.query("SELECT * FROM bot_messages WHERE chat_id = %d AND message_id = %d;" % (chat_id, message_id))
        dict_res = query.dictresult()
        if len(dict_res) == 0:
            return False
        else:
            return dict_res[0]

    def close(self):
        self.__db.close()
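
Hypothetical usage of BotDB; the postgres:// URL is an assumption (a Heroku-style DATABASE_URL), and the ids are made up.

bot_db = BotDB('postgres://bot:secret@localhost:5432/botdb')
bot_db.insertUser(42, 'alice', 'Alice', None)
bot_db.insertThesis(init_id=1, chat_id=-1001, user_id=42, body='Example thesis')
print(bot_db.getTodayTheses(-1001))
bot_db.close()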
Example #14
class SteadyPgConnection:
    """Class representing steady connections to a PostgreSQL database.

    Underlying the connection is a classic PyGreSQL pg API database
    connection which is reset if the connection is lost or used too often.
    Thus the resulting connection is steadier ("tough and self-healing").

    If you want the connection to be persistent in a threaded environment,
    then you should not deal with this class directly, but use either the
    PooledPg module or the PersistentPg module to get the connections.

    """

    version = __version__

    def __init__(self, maxusage=None, setsession=None, closeable=True,
            *args, **kwargs):
        """Create a "tough" PostgreSQL connection.

        maxusage: maximum usage limit for the underlying PyGreSQL connection
            (number of uses, 0 or None means unlimited usage)
            When this limit is reached, the connection is automatically reset.
        setsession: optional list of SQL commands that may serve to prepare
            the session, e.g. ["set datestyle to ...", "set time zone ..."]
        closeable: if this is set to false, then closing the connection will
            be silently ignored, but by default the connection can be closed
        args, kwargs: the parameters that shall be used to establish
            the PostgreSQL connections with PyGreSQL using pg.DB()

        """
        # basic initialization to make finalizer work
        self._con = None
        self._closed = True
        # proper initialization of the connection
        if maxusage is None:
            maxusage = 0
        if not isinstance(maxusage, int):
            raise TypeError("'maxusage' must be an integer value.")
        self._maxusage = maxusage
        self._setsession_sql = setsession
        self._closeable = closeable
        self._con = PgConnection(*args, **kwargs)
        self._transaction = False
        self._closed = False
        self._setsession()
        self._usage = 0

    def _setsession(self):
        """Execute the SQL commands for session preparation."""
        if self._setsession_sql:
            for sql in self._setsession_sql:
                self._con.query(sql)

    def _close(self):
        """Close the tough connection.

        You can always close a tough connection with this method
        and it will not complain if you close it more than once.

        """
        if not self._closed:
            try:
                self._con.close()
            except Exception:
                pass
            self._transaction = False
            self._closed = True

    def close(self):
        """Close the tough connection.

        You are allowed to close a tough connection by default
        and it will not complain if you close it more than once.

        You can disallow closing connections by setting
        the closeable parameter to something false. In this case,
        closing tough connections will be silently ignored.

        """
        if self._closeable:
            self._close()
        elif self._transaction:
            self.reset()

    def reopen(self):
        """Reopen the tough connection.

        It will not complain if the connection cannot be reopened.

        """
        try:
            self._con.reopen()
        except Exception:
            if self._transaction:
                self._transaction = False
                try:
                    self._con.query('rollback')
                except Exception:
                    pass
        else:
            self._transaction = False
            self._closed = False
            self._setsession()
            self._usage = 0

    def reset(self):
        """Reset the tough connection.

        If a reset is not possible, tries to reopen the connection.
        It will not complain if the connection is already closed.

        """
        try:
            self._con.reset()
            self._transaction = False
            self._setsession()
            self._usage = 0
        except Exception:
            try:
                self.reopen()
            except Exception:
                try:
                    self.rollback()
                except Exception:
                    pass

    def begin(self, sql=None):
        """Begin a transaction."""
        self._transaction = True
        try:
            begin = self._con.begin
        except AttributeError:
            return self._con.query(sql or 'begin')
        else:
            # use existing method if available
            if sql:
                return begin(sql=sql)
            else:
                return begin()

    def end(self, sql=None):
        """Commit the current transaction."""
        self._transaction = False
        try:
            end = self._con.end
        except AttributeError:
            return self._con.query(sql or 'end')
        else:
            if sql:
                return end(sql=sql)
            else:
                return end()

    def commit(self, sql=None):
        """Commit the current transaction."""
        self._transaction = False
        try:
            commit = self._con.commit
        except AttributeError:
            return self._con.query(sql or 'commit')
        else:
            if sql:
                return commit(sql=sql)
            else:
                return commit()

    def rollback(self, sql=None):
        """Rollback the current transaction."""
        self._transaction = False
        try:
            rollback = self._con.rollback
        except AttributeError:
            return self._con.query(sql or 'rollback')
        else:
            if sql:
                return rollback(sql=sql)
            else:
                return rollback()

    def _get_tough_method(self, method):
        """Return a "tough" version of a connection class method.

        The tough version checks whether the connection is bad (lost)
        and automatically and transparently tries to reset the connection
        if this is the case (for instance, the database has been restarted).

        """
        def tough_method(*args, **kwargs):
            transaction = self._transaction
            if not transaction:
                try: # check whether connection status is bad
                    if not self._con.db.status:
                        raise AttributeError
                    if self._maxusage: # or connection used too often
                        if self._usage >= self._maxusage:
                            raise AttributeError
                except Exception:
                    self.reset() # then reset the connection
            try:
                result = method(*args, **kwargs) # try connection method
            except Exception: # error in query
                if transaction: # inside a transaction
                    self._transaction = False
                    raise # propagate the error
                elif self._con.db.status: # if it was not a connection problem
                    raise # then propagate the error
                else: # otherwise
                    self.reset() # reset the connection
                    result = method(*args, **kwargs) # and try one more time
            self._usage += 1
            return result
        return tough_method

    def __getattr__(self, name):
        """Inherit the members of the standard connection class.

        Some methods are made "tougher" than in the standard version.

        """
        if self._con:
            attr = getattr(self._con, name)
            if (name in ('query', 'get', 'insert', 'update', 'delete')
                    or name.startswith('get_')):
                attr = self._get_tough_method(attr)
            return attr
        else:
            raise InvalidConnection

    def __del__(self):
        """Delete the steady connection."""
        try:
            self._close() # make sure the connection is closed
        except Exception:
            pass
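
A minimal sketch of using the class directly, assuming PgConnection is pg.DB (as the docstring suggests) and a local database; in a threaded application you would normally go through PooledPg or PersistentPg instead.

con = SteadyPgConnection(
    maxusage=1000,                                         # reset the connection after 1000 uses
    setsession=["SET datestyle TO ISO", "SET TIME ZONE 'UTC'"],
    dbname='mydb', host='localhost', user='postgres')
print(con.query('SELECT version()').getresult())           # query() is the "tough" wrapper
con.close()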
Example #15
class TPCDS:

    STORAGE = [
        ("small_storage", "appendonly=true, orientation=column"),
        ("medium_storage", "appendonly=true, orientation=column, compresstype=zstd"),
        ("large_storage", "appendonly=true, orientation=column, compresstype=zstd"),
        ("e9_medium_storage", "appendonly=true, compresstype=zstd"),
        ("e9_large_storage", "appendonly=true, orientation=column, compresstype=zstd")
    ]

    RAW_DATA_PATH = "/data/generated_source_data/data"

    def __init__(self, info_dir, dbname, port, host, data_path):
        self.info_dir = info_dir
        self.db = DB(dbname=dbname, port=port, host=host)
        self.dist_info = self.parse_dist_info()
        self.data_path = data_path

    def create_schema(self):
        sqls = [
            "DROP SCHEMA IF EXISTS tpcds CASCADE;",
            "DROP SCHEMA IF EXISTS ext_tpcds CASCADE;",
            "CREATE SCHEMA tpcds;",
            "CREATE SCHEMA ext_tpcds;"
        ]
        for sql in sqls:
            self.db.query(sql)

    def create_table(self):
        ddl_top_path = os.path.join(self.info_dir, "ddl")
        print("creating norm tables...")
        for fn in os.listdir(ddl_top_path):
            if not fn.endswith(".sql"): continue
            if fn == "000.e9.tpcds.sql": continue
            if "ext_" in fn: continue
            self.create_normal_table(os.path.join(ddl_top_path, fn))
        print("norm tables created.")

        print("creating ext tables...")        
        for fn in os.listdir(ddl_top_path):
            if not fn.endswith(".sql"): continue
            if fn == "000.e9.tpcds.sql": continue
            if "ext_" not in fn: continue
            self.create_ext_table(os.path.join(ddl_top_path, fn))
        print("ext tables created.")

    def create_normal_table(self, ddlpath):
        assert("ext_" not in ddlpath)
        with open(ddlpath) as f:
            sql = f.read().lower()
        tabname = self.get_tabname_from_path(ddlpath)
        sql = self.patch_dist_info(sql, tabname)
        sql = self.patch_storage_info(sql)
        self.db.query(sql)

    def create_ext_table(self, ddlpath):
        assert("ext_" in ddlpath)
        with open(ddlpath) as f:
            sql = f.read().lower()
        tabname = self.get_tabname_from_path(ddlpath)
        sql = self.patch_gpfdist_local(sql, tabname)
        self.db.query(sql)

    def patch_dist_info(self, sql, tabname):
        distkeys = self.dist_info[tabname]
        sql = sql.replace(":distributed_by",
                          "distributed by (%s)" % distkeys)
        return sql

    def patch_storage_info(self, sql):
        for storage_type, storage_option in self.STORAGE:
            replace_key = ":" + storage_type
            sql = sql.replace(replace_key, storage_option)
        return sql

    def patch_gpfdist_local(self, sql, tabname):
        url = "'gpfdist://mdw:2223/%s*.dat'" % tabname
        sql = sql.replace(":location", url)
        return sql
    
    def get_tabname_from_path(self, path):
        filename = os.path.basename(path)
        return filename.split(".")[-2]

    def parse_dist_info(self):
        dist_info = {}
        with open(os.path.join(self.info_dir, "ddl", "distribution.txt")) as f:
            for line in f:
                _, tabname, distkeys = line.strip().split("|")
                dist_info[tabname] = distkeys
        return dist_info

    def start_gpfdist(self, port, logfile):
        cmd = ["gpfdist",
               "-p", str(port),
               "-l", logfile]
        proc = subprocess.Popen(cmd)
        sleep(3)
        return proc

    def load_all_tables(self):
        ddl_top_path = os.path.join(self.info_dir, "ddl")
        start_time = time()
        print("load tables...")
        print("===================================")
        for fn in os.listdir(ddl_top_path):
            if not fn.endswith(".sql"): continue
            if fn == "000.e9.tpcds.sql": continue
            if "ext_" in fn: continue
            self.load_table(os.path.join(ddl_top_path, fn))
        end_time = time()
        cost = end_time - start_time
        print("all tables finished in %s seconds" % cost)

    def load_table(self, ddlpath):
        tabname = self.get_tabname_from_path(ddlpath)
        tab_data_path = self.get_tab_data_path(tabname)
        os.chdir(tab_data_path)
        proc = self.start_gpfdist("2223", "/data/gpfdist.log")
        sql = ("insert into tpcds.%s "
               "select * from ext_tpcds.%s") % (tabname, tabname)
        start_time = time()
        self.db.query(sql)
        end_time = time()
        proc.terminate()
        proc.wait()
        cost = end_time - start_time
        print("load %s cost time %s seconds" % (tabname, cost))
        print("===================================")

    def get_tab_data_path(self, tabname):
        return os.path.join(self.data_path, tabname)

    def move_all_tables(self):
        ddl_top_path = os.path.join(self.info_dir, "ddl")
        for fn in os.listdir(ddl_top_path):
            if not fn.endswith(".sql"): continue
            if fn == "000.e9.tpcds.sql": continue
            if "ext_" in fn: continue
            self.move_table(os.path.join(ddl_top_path, fn))

    def move_table(self, ddlpath):
        tabname = self.get_tabname_from_path(ddlpath)
        fns = self.findall_tab_data(tabname)
        newplace = os.path.join(self.data_path, tabname)
        os.makedirs(newplace, exist_ok=True)
        for fn in fns:
            shutil.move(fn, newplace)
        print("moving %s (%d dat files)" % (tabname, len(fns)))

    def findall_tab_data(self, tabname):
        data_fns = []
        pt_regstr = tabname + r"_\d+_\d+.dat"
        pt = re.compile(pt_regstr)
        for fn in os.listdir(self.RAW_DATA_PATH):
            if pt.search(fn):
                data_fns.append(os.path.join(self.RAW_DATA_PATH, fn))
        return data_fns

    def close_db(self):
        self.db.close()
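
A hypothetical driver for the TPCDS loader; the paths, port and host (the "mdw" master assumed by patch_gpfdist_local) are assumptions.

tpcds = TPCDS(info_dir='/data/tpcds_info', dbname='gpadmin', port=5432,
              host='mdw', data_path='/data/tpcds')
tpcds.create_schema()
tpcds.create_table()
tpcds.move_all_tables()   # sort the raw .dat files into per-table directories
tpcds.load_all_tables()   # gpfdist-backed INSERT ... SELECT for each table
tpcds.close_db()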
Example #16
class SolidPgConnection:
    """Class representing solid connections to a PostgreSQL database.

	Underlying the connection is a classic PyGreSQL pg API database
	connection which is reset if the connection is lost or used too often.
	Thus the resulting connection is more solid ("tough and self-healing").

	If you want the connection to be persistent in a threaded environment,
	then you should not deal with this class directly, but use either the
	PooledPg module or the PersistentPg module to get the connections.

	"""
    def __init__(self, maxusage=0, setsession=None, *args, **kwargs):
        """Create a "tough" PostgreSQL connection.

		maxusage: maximum usage limit for the underlying PygreSQL connection
			(number of uses, 0 or False means unlimited usage)
			When this limit is reached, the connection is automatically reset.
		setsession: optional list of SQL commands that may serve to prepare
			the session, e.g. ["set datestyle to ...", "set time zone ..."]
		args, kwargs: the parameters that shall be used to establish
			the PostgreSQL connections with PyGreSQL using pg.DB()
		"""
        self._maxusage = maxusage
        self._setsession_sql = setsession
        self._usage = 0
        self._con = PgConnection(*args, **kwargs)
        self._setsession()

    def _setsession(self):
        """Execute the SQL commands for session preparation."""
        if self._setsession_sql:
            for sql in self._setsession_sql:
                self._con.query(sql)

    def close(self):
        """Close the tough connection.

		You are allowed to close a tough connection.
		It will not complain if you close it more than once.
		"""
        try:
            self._con.close()
            self._usage = 0
        except:
            pass

    def reopen(self):
        """Reopen the tough connection.

		It will not complain if the connection cannot be reopened."""
        try:
            self._con.reopen()
            self._setsession()
            self._usage = 0
        except:
            pass

    def reset(self):
        """Reset the tough connection.

		If a reset is not possible, tries to reopen the connection.
		It will not complain if the connection is already closed.
		"""
        try:
            self._con.reset()
            self._setsession()
            self._usage = 0
        except:
            self.reopen()

    def _get_tough_method(self, method):
        """Return a "tough" version of a connection class method.

		The tough version checks whether the connection is bad (lost)
		and automatically and transparently tries to reset the connection
		if this is the case (for instance, the database has been restarted).
		"""
        def tough_method(*args, **kwargs):
            try:  # check whether connection status is bad
                if not self._con.db.status:
                    raise AttributeError
                if self._maxusage:  # or connection used too often
                    if self._usage >= self._maxusage:
                        raise AttributeError
            except:
                self.reset()  # then reset the connection
            try:
                r = method(*args, **kwargs)  # try connection method
            except:  # error in query
                if self._con.db.status:  # if it was not a connection problem
                    raise  # then propagate the error
                else:  # otherwise
                    self.reset()  # reset the connection
                    r = method(*args, **kwargs)  # and try one more time
            self._usage += 1
            return r

        return tough_method

    def __getattr__(self, name):
        """Inherit the members of the standard connection class.

		Some methods are made "tougher" than in the standard version.
		"""
        attr = getattr(self._con, name)
        if name in ('query', 'get', 'insert', 'update', 'delete') \
         or name.startswith('get_'):
            attr = self._get_tough_method(attr)
        return attr
Example #17
def migrateNow():
    snsClient = None
    print(
        'If you wish to receive notifications of the various stages in the process, please enter the ARN of an SNS topic below (leave it empty if you do not want any notifications).\nARN:'
    )
    snsArn = input().strip()
    snsProvided = False

    if (snsArn != '' and isSNSValid(x=snsArn)):
        snsProvided = True

    if (snsProvided):
        snsClient = boto3.client('sns', snsArn.split(':')[3])
        snsClient.get_topic_attributes(TopicArn=snsArn)
        snsClient.publish(
            TopicArn=snsArn,
            Message=
            "Dear User, you will now receive notifications for the migration process."
        )

    print("\n\t\t\t\t\t\tPlease enter source details (unencrypted cluster) : ")
    print("\t\t\t\t\t\t===============================================")

    source = Credentials()
    source.readCreds()

    print(
        "\n\t\t\t\t\t\tPlease enter Destination details (encrypted cluster) : "
    )
    print("\t\t\t\t\t\t===============================================")

    destination = Credentials()
    destination.readCreds()
    print("Validating Endpoints")

    srv = isEndpointValid(source.ename)

    dsv = isEndpointValid(destination.ename)

    print('\n Fetching required scripts for migration; do not press any key')
    # Get the AdminViews sql scripts in the current working directory
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_generate_schema_ddl.sql',
        'v_generate_schema_ddl.sql')
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_generate_group_ddl.sql',
        'v_generate_group_ddl.sql')
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_get_users_in_group.sql',
        'v_get_users_in_group.sql')
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_generate_tbl_ddl.sql',
        'v_generate_tbl_ddl.sql')
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_generate_view_ddl.sql',
        'v_generate_view_ddl.sql')
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_get_schema_priv_by_user.sql',
        'v_get_schema_priv_by_user.sql')
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_get_tbl_priv_by_user.sql',
        'v_get_tbl_priv_by_user.sql')
    urllib.request.urlretrieve(
        'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_get_view_priv_by_user.sql',
        'v_get_view_priv_by_user.sql')

    temp_views = [
        'v_generate_tbl_ddl.sql', 'v_generate_view_ddl.sql',
        'v_generate_group_ddl.sql', 'v_get_users_in_group.sql',
        'v_get_schema_priv_by_user.sql', 'v_get_view_priv_by_user.sql',
        'v_generate_schema_ddl.sql', 'v_get_tbl_priv_by_user.sql'
    ]

    rolename = 'MigrationPolicy' + randint(1, 10000000000000).__str__()

    try:
        if (srv and dsv):

            print(
                'Enter bucket name where you want to store and load your backup from:'
            )
            s3bucket = input()

            print("\nCommon Key\n--------------")
            print(
                "\nPlease enter a common password to assign to all users initially in the encrypted cluster, as we cannot migrate passwords"
            )
            print(
                "\nNote:\n=====\nMust contain at least 8 ASCII characters with one upper case and one lower case letter\n"
            )
            print("Enter:")
            commonkey = getpass.getpass()

            print('Creating permissions for migration using role: ' + rolename)
            client = boto3.client('iam')
            RoleTrustpolicy = '''{
             "Version": "2012-10-17",
             "Statement": {
             "Effect": "Allow",
             "Principal":{"Service": "redshift.amazonaws.com"},
             "Action": "sts:AssumeRole"
             }
             }'''

            response = client.create_role(
                AssumeRolePolicyDocument=RoleTrustpolicy,
                Path='/',
                RoleName=rolename,
            )
            roleArn = response["Role"]["Arn"]
            # print(response["Role"]["Arn"])

            accessPolicy = '''{
                 "Version": "2012-10-17",
                 "Statement": [
                     {
                         "Effect": "Allow",
                         "Action": "s3:*",
                         "Resource": "*"
                     }
                 ]
             }'''
            accesspolicyname = 's3fullonreadwriteandsavecopyunloadmyredshiftdata' + rolename
            response = client.create_policy(
                PolicyName=accesspolicyname,
                Path='/',
                PolicyDocument=accessPolicy,
                Description='to see and check s3 data')
            policyarn = response['Policy']['Arn']
            # print(policyarn)

            response = client.attach_role_policy(RoleName=rolename,
                                                 PolicyArn=policyarn)

            src = boto3.client('redshift', source.ename.split('.')[-4])

            response = src.modify_cluster_iam_roles(
                ClusterIdentifier=source.ename.split('.')[0],
                AddIamRoles=[roleArn])

            dst = boto3.client('redshift', destination.ename.split('.')[-4])

            response = dst.modify_cluster_iam_roles(
                ClusterIdentifier=destination.ename.split('.')[0],
                AddIamRoles=[roleArn])

            while (src.describe_clusters(ClusterIdentifier=source.ename.split(
                    '.')[0])['Clusters'][0]['ClusterStatus'] != 'available'
                   or dst.describe_clusters(
                       ClusterIdentifier=destination.ename.split('.')[0])
                   ['Clusters'][0]['ClusterStatus'] != 'available'):
                print('One or more Clusters in Modifying State')
                sleep(15)
                print('One or more Clusters in Modifying State')

            print(
                'Attached role with S3 access policy to both Redshift clusters; role name '
                + rolename +
                '. Please do not modify it; it will be deleted automatically once the process is completed'
            )
            role = roleArn

            if (isBucketValid(s3bucket)):

                # Connect to source unencrypted database
                db = DB(dbname=source.dname,
                        host=source.ename,
                        port=source.port,
                        user=source.uname,
                        passwd=source.pwd)

                if (snsProvided):
                    snsClient.publish(TopicArn=snsArn,
                                      Message="Logged in successfully in" +
                                      source.ename.split('.')[0] + " cluster ")

                # Create the admin views using the above scripts
                # Below redundant statements can be replaced by a simple for loop

                # Create admin schema (if running on newly created source cluster)
                db.query("CREATE SCHEMA IF NOT EXISTS admin;"
                         )  # already present error only once

                # Creating the schemas, views, tables and groups using ddl statements is very important
                # Otherwise how and where would you copy the data ?
                # TODO: generate the DDL for schemas, tables and views

                print("\nRunnning backup scripts ......")
                for filename in temp_views:
                    with open(filename, 'r') as script:
                        view_sql = script.read()
                    db.query('drop view if exists admin.{0};'.format(
                        filename.split('.')[0]))
                    db.query(view_sql)

                print("\nFetching Details from Source ......")

                dqueries = []

                # Get DDL for Schema
                q = db.query(
                    "SELECT schemaname FROM admin.v_generate_schema_ddl where schemaname <>'admin';"
                )
                # print(q.getresult())

                for queries in q.getresult():
                    for query in queries:
                        dqueries.append("CREATE SCHEMA IF NOT EXISTS " +
                                        query + ";")

                # Get DDL for user tables
                q = db.query(
                    "SELECT DISTINCT tablename as table, schemaname FROM admin.v_generate_tbl_ddl WHERE schemaname <> 'information_schema' AND schemaname <> 'pg_catalog';"
                )

                print(" \n Tables in Source Cluster \n:")
                # print(q.getresult())

                user_tables = []
                tables = []

                for table in q.getresult():
                    tables.append(table[0])
                    user_tables.append(table[1] + "." + table[0])

                # print(user_tables)

                print("\nuploading source tables to S3 ......")
                # upload user tables to s3
                for table in user_tables:
                    print('Uploading Table : ' + table + " .......")
                    db.query("UNLOAD('select * from " + table +
                             "') TO 's3://" + s3bucket + "/" + table +
                             "/part_1' iam_role '" + role +
                             "' manifest ALLOWOVERWRITE;")

                print("Upload Complete")
                if (snsProvided):
                    snsClient.publish(TopicArn=snsArn,
                                      Message="Copied data into s3 bucket" +
                                      s3bucket)

                for table in tables:
                    q = db.query(
                        "SELECT ddl FROM admin.v_generate_tbl_ddl WHERE tablename='"
                        + table + "'")
                    # print(q.getresult())
                    create_table_query = []
                    for query in q.getresult():
                        for ddl in query:
                            create_table_query.append(ddl)
                    dqueries.append(''.join(create_table_query[1:]))

                # q = db.query('CREATE TABLE IF NOT EXISTS public.fruits(\tid INTEGER NOT NULL  ENCODE lzo\t,name VARCHAR(256)   ENCODE lzo\t,PRIMARY KEY (id))DISTSTYLE EVEN;')

                # Get DDL for views
                # CREATE OR REPLACE VIEW admin.v_generate_tbl_ddl  ---- not executing properly, else all executing properly
                print("\n Fectching Views in Source")
                q = db.query(
                    "SELECT DISTINCT viewname FROM admin.v_generate_view_ddl WHERE schemaname <> 'information_schema' AND schemaname <> 'pg_catalog' AND  schemaname <> 'admin';"
                )

                #print("\n Views in Source \n:")

                user_views = []

                for views in q.getresult():
                    for viewname in views:
                        user_views.append(viewname)

                # print(user_views)

                for view in user_views:
                    q = db.query(
                        "SELECT ddl FROM admin.v_generate_view_ddl WHERE viewname='"
                        + view + "'")
                    create_view_query = []
                    for query in q.getresult():
                        for ddl in query:
                            create_view_query = ddl.splitlines()

                    # print(' '.join(create_view_query[1:]))
                    dqueries.append(' '.join(create_view_query[1:]))

                print("\nGetting Users in Source")
                # Get the users
                # Maybe prompt for setting the password for users on the encrypted redshift cluster
                q = db.query(
                    "SELECT usename  FROM pg_user WHERE usename <> 'rdsdb' and usename<>'"
                    + source.uname + "';").getresult()
                # print(q)
                users = []
                for u in q:
                    for name in u:
                        users.append(name)
                # print(users)
                q = db.query(
                    "SELECT 'CREATE USER '|| usename || '' FROM pg_user WHERE usename <> 'rdsdb' and usename <> '"
                    + source.uname + "';")

                for queries in q.getresult():
                    for query in queries:
                        dqueries.append(query + " with password '" +
                                        commonkey + "';")

                print("\nGetting groups from Source")
                # Get the groups
                q = db.query("select groname FROM pg_group;").getresult()
                groups = []
                for g in q:
                    for name in g:
                        groups.append(name)

                q = db.query(
                    "SELECT 'CREATE GROUP  '|| groname  ||';' FROM pg_group;")

                for queries in q.getresult():
                    for query in queries:
                        dqueries.append(query)

                print("\nGetting Users and Group correlations from Source")
                # Get users in the groups
                q = db.query(
                    "SELECT 'ALTER GROUP ' ||groname||' ADD USER '||usename||';' FROM admin.v_get_users_in_group;"
                )
                # print(q.getresult())

                for queries in q.getresult():
                    for query in queries:
                        dqueries.append(query)

                # Python udf used to generate permissions for schema as concatenated strings
                udf = """create or replace function 
               f_schema_priv_granted(cre boolean, usg boolean) returns varchar
               STABLE 
               AS $$
                  priv = ''
                  prev = False;
                  if cre:
                      priv = str('create')
                      prev = True
                  if usg:
                   if prev:
                       priv = priv + str(', usage')
                   else :
                       priv = priv + str(' usage')

                  return priv
               $$LANGUAGE plpythonu;"""

                db.query(udf)

                print("Fetching Schema previleges")
                # Get schema privileges per user
                q = db.query(
                    """SELECT 'GRANT '|| f_schema_priv_granted(cre, usg) ||' ON schema '|| schemaname || ' TO ' || usename || ';' 
               FROM admin.v_get_schema_priv_by_user 
               WHERE schemaname NOT LIKE 'pg%' 
               AND schemaname <> 'information_schema'
               AND schemaname <> 'admin'
               AND usename <> 'rdsdb' AND usename <> '""" + source.uname + "';")

                for queries in q.getresult():
                    for query in queries:
                        dqueries.append(query)
                print("\nFetching Permissions Per USER for tables and views ")
                # Python udf used to generate permissions for table and views as concatenated strings
                udf = """create or replace function 
               f_table_priv_granted(sel boolean, ins boolean, upd boolean, delc boolean, ref boolean) returns varchar
               STABLE 
               AS $$
                  priv = ''
                  prev = False;
                  if sel:
                       priv = str('select')
                       prev = True;

                  if ins:
                   if prev :
                       priv = priv + str(', insert')
                   else :
                       priv = priv + str(' insert')
                       prev = True
                  if upd:
                   if prev:
                       priv = priv + str(', update')
                   else:
                       priv = priv + str(' update')
                       prev = True

                  if delc:
                   if prev:
                       priv = priv + str(', delete')
                   else :
                       priv = priv + str(' delete')
                       prev = True

                  if ref:
                   if prev:
                       priv = priv + str(', references ')
                   else:
                       priv = priv+str(' references')
                       prev = True
                  return priv
               $$LANGUAGE plpythonu;"""

                db.query(udf)

                # Get table privileges per user
                q = db.query(
                    """SELECT 'GRANT '|| f_table_priv_granted(sel, ins, upd, del, ref) || ' ON '|| 
               schemaname||'.'||tablename ||' TO '|| usename || ';' FROM admin.v_get_tbl_priv_by_user 
               WHERE schemaname NOT LIKE 'pg%' 
               AND schemaname <> 'information_schema'
               AND schemaname <> 'admin'
               AND usename <> 'rdsdb'
               AND usename <>'""" + source.uname + "';")

                for queries in q.getresult():
                    for query in queries:
                        dqueries.append(query)

                # Get view privileges per user
                q = db.query(
                    """SELECT 'GRANT '|| f_table_priv_granted(sel, ins, upd, del, ref) || ' ON '|| 
               schemaname||'.'||viewname ||' TO '|| usename || ';' FROM admin.v_get_view_priv_by_user 
               WHERE schemaname NOT LIKE 'pg%' 
               AND schemaname <> 'information_schema'
               AND schemaname <> 'admin'
               AND usename <> 'rdsdb' AND usename <>'""" + source.uname + "';")
                # print(q)

                for queries in q.getresult():
                    for query in queries:
                        dqueries.append(query)

                db.close()

                print(
                    "\nClosing connection to source and connecting to destination")
                print('\nLoading')
                if (snsProvided):
                    snsClient.publish(TopicArn=snsArn,
                                      Message="Fetching data from " +
                                      source.ename.split('.')[0] +
                                      " complete, closing connection with " +
                                      source.ename.split('.')[0])
                # now in destination

                db = DB(dbname=destination.dname,
                        host=destination.ename,
                        port=destination.port,
                        user=destination.uname,
                        passwd=destination.pwd)
                print("\n\n\n\nExecuting Queries in Encrypted cluster:")
                print('======================================')
                db.query("CREATE SCHEMA IF NOT EXISTS admin;")
                urllib.request.urlretrieve(
                    'https://raw.githubusercontent.com/awslabs/amazon-redshift-utils/master/src/AdminViews/v_generate_user_grant_revoke_ddl.sql',
                    'v_generate_user_grant_revoke_ddl.sql')
                filename = 'v_generate_user_grant_revoke_ddl.sql'
                file = open(filename, 'r')

                view_ddl = file.read()
                file.close()
                db.query('drop view if exists admin.{0};'.format(
                    filename.split('.')[0]))
                db.query(view_ddl)
                os.remove(filename)

                # for u in users:
                #   if destination.uname!=u:
                #     db.query("drop user IF Exists "+u+" ;")
                # for g in groups:
                #   db.query("drop group "+g+" ;")

                if (snsProvided):
                    snsClient.publish(
                        TopicArn=snsArn,
                        Message=
                        "Successfully logged in to destination cluster (" +
                        destination.ename.split('.')[0] +
                        "), executing required queries")

                q = db.query(
                    "SELECT usename  FROM pg_user WHERE usename <> 'rdsdb' and usename<>'"
                    + destination.uname + "';").getresult()
                # print(q)
                users1 = []
                for u in q:
                    for name in u:
                        users1.append(name)

                commonUsers = []
                if (len(users1) > 0):
                    commonUsers = commonList(users, users1)
                # deleting common users.
                if (len(commonUsers) > 0):
                    for user in commonUsers:
                        q = db.query(
                            "select ddl from admin.v_generate_user_grant_revoke_ddl where (grantee='"
                            + user + "' or grantor='" + user +
                            "') and ddltype='revoke' order by ddl;")

                        for x in q.getresult():
                            for query in x:
                                db.query(query)
                        db.query("drop user \"" + user + "\";")

                # delete groups
                groups1 = []
                q = db.query("select groname from pg_group")
                for g in q.getresult():
                    for group in g:
                        groups1.append(group)

                commonGroups = []
                if (len(groups1) > 0):
                    commonGroups = commonList(groups, groups1)

                if (len(commonGroups) > 0):
                    for group in commonGroups:
                        db.query("drop group \"" + group + "\";")

                count = 1

                # print(dqueries)
                try:

                    for query in dqueries:
                        print("\t\t\t\tQuery No: " + count.__str__() +
                              "\t\t\t\t\t")
                        print('\t\t\t\t-=-=-=-=-=-=-=-\t\t\t\t\t\n')
                        print('\n')
                        print(query)
                        db.query(query)
                        print('\n\n\n')
                        print('===successs====')
                        print('\n')
                        count += 1

                    if (snsProvided):
                        snsClient.publish(
                            TopicArn=snsArn,
                            Message=
                            "Executing queries complete, now copying data from S3: "
                            + s3bucket + " to destination cluster: " +
                            destination.ename.split('.')[0])

                    print("Loading data from s3 into tables")

                    for table in user_tables:
                        if table != 'lineorder':
                            db.query("Copy " + table + " from 's3://" +
                                     s3bucket + "/" + table +
                                     "/part_1manifest' iam_role '" + role +
                                     "' manifest;")

                    if (snsProvided):
                        snsClient.publish(TopicArn=snsArn,
                                          Message="Fetching data from S3 completed")

                except Exception as error:
                    #print("Inside The excpet block for internal errors")
                    for file in temp_views:
                        os.remove(file)
                    response = src.modify_cluster_iam_roles(
                        ClusterIdentifier=source.ename.split('.')[0],
                        RemoveIamRoles=[roleArn])

                    response = dst.modify_cluster_iam_roles(
                        ClusterIdentifier=destination.ename.split('.')[0],
                        RemoveIamRoles=[roleArn])

                    while (src.describe_clusters(
                            ClusterIdentifier=source.ename.split('.')
                        [0])['Clusters'][0]['ClusterStatus'] != 'available'
                           or dst.describe_clusters(
                               ClusterIdentifier=destination.ename.split(
                                   '.')[0])['Clusters'][0]['ClusterStatus'] !=
                           'available'):
                        sleep(15)
                        print(
                            'detaching temporary role from clusters and restoring them to their previous state'
                        )

                    response = client.detach_role_policy(RoleName=rolename,
                                                         PolicyArn=policyarn)
                    print('detaching access policy from role')

                    response = client.delete_policy(PolicyArn=policyarn)
                    print('deleting access policy')
                    response = client.delete_role(RoleName=rolename)
                    print('deleting role')
                    print(error)
                    db.close()
                    print('connections closed')
                    print('Migration Failure')
                    sys.exit()

                if (snsProvided):
                    snsClient.publish(
                        TopicArn=snsArn,
                        Message=
                        "Executing queries complete, now copying data from S3: "
                        + s3bucket + " to destination cluster: " +
                        destination.ename.split('.')[0])

                print("Loading data from s3 into tables")
                for table in user_tables:
                    if table != 'lineorder':
                        db.query("Copy " + table + " from 's3://" + s3bucket +
                                 "/" + table + "/part_1manifest' iam_role '" +
                                 role + "' manifest;")

                if (snsProvided):
                    snsClient.publish(TopicArn=snsArn,
                                      Message="Fetching data from S3 completed")

                c = 0
                countTables = db.query(
                    "SELECT count(*) FROM pg_catalog.pg_tables WHERE schemaname <> 'information_schema' AND schemaname <> 'pg_catalog' ;"
                ).getresult()

                for i in countTables:
                    for j in i:
                        c = j
                db.close()

                print('\nDeleting dependencies')
                print('\nCleanup in progress\n')

                for file in temp_views:
                    os.remove(file)

                response = src.modify_cluster_iam_roles(
                    ClusterIdentifier=source.ename.split('.')[0],
                    RemoveIamRoles=[roleArn])

                response = dst.modify_cluster_iam_roles(
                    ClusterIdentifier=destination.ename.split('.')[0],
                    RemoveIamRoles=[roleArn])

                while (src.describe_clusters(
                        ClusterIdentifier=source.ename.split('.')
                    [0])['Clusters'][0]['ClusterStatus'] != 'available'
                       or dst.describe_clusters(
                           ClusterIdentifier=destination.ename.split('.')[0])
                       ['Clusters'][0]['ClusterStatus'] != 'available'):
                    sleep(15)
                    print(
                        'detaching temporary role from clusters and restoring them to their previous state'
                    )

                response = client.detach_role_policy(RoleName=rolename,
                                                     PolicyArn=policyarn)
                print('detaching access policy from role')

                response = client.delete_policy(PolicyArn=policyarn)
                print('deleting access policy')
                response = client.delete_role(RoleName=rolename)

                print('deleting role')
                print('dependent files deleted')
                print('Closing open connections')

                if (c == len(tables)):
                    print(
                        "==========================ENCRYPTION COMPLETED========================="
                    )
                    print(
                        "\n NOTICE:\n========\nPlease verify the data in both the new and the old cluster; if everything looks correct, feel free to delete the old cluster after taking a snapshot of it.\n Please also change the users' default passwords with the ALTER USER command.\n As a best practice we suggest enabling audit logging on the "
                        + destination.ename.split('.')[0] +
                        " cluster if it is not already enabled.\n Thank you :)")

                    if (snsProvided):
                        snsClient.publish(
                            TopicArn=snsArn,
                            Message="Migration Process Completed")
                else:

                    print("XXXXXXXXXX PROCESS FAILED XXXXXXXXXXXXXXXXX")

                    if (snsProvided):
                        snsClient.publish(TopicArn=snsArn,
                                          Message="Migration Process Failed ")
            else:
                print("S3 Bucket name invalid")

        else:
            if (not srv):

                for file in temp_views:
                    os.remove(file)
                print("Source Endpoint not valid")

            elif (not dsv):

                for file in temp_views:
                    os.remove(file)
                print("Destination Endpoint not valid")

    except Exception as error:
        print(
            "\n\n Please delete the IAM role from the AWS console (and disassociate it if it is still associated with a Redshift cluster): "
            + rolename)
        print(error)
        print('Migration Failure')
示例#18
0
from pg import DB
PG = DB(dbname='VM', host='192.168.1.112', port=5432, user='******', passwd='123456')
sql = "select A.sample_time, A.stat_name, B.stat_rollup, B.unit, A.stat_group, A.entity, A.stat_value from hist_stat_daily A, stat_counters B where A.stat_id = B.id"
q = PG.query(sql)
rows = q.getresult()
for row in rows:
    PG.insert("vm_state", time=row[0], stat_name=row[1], stat_rollup_type=row[2],
              unit=row[3], stat_group=row[4], entity=row[5], stat_value=row[6])
print("Success")
PG.close()
示例#19
0
class vol_utils(object):
    def __init__(self):
        self.logger = logging.getLogger('volume_project.sql_utilities')
        self.db_connect()

    def db_connect(self):
        CONFIG = configparser.ConfigParser()
        CONFIG.read('db.cfg')
        dbset = CONFIG['DBSETTINGS']
        self.db = DB(dbname=dbset['database'],
                     host=dbset['host'],
                     user=dbset['user'],
                     passwd=dbset['password'])
        self.logger.info('Database connected.')

    def exec_file(self, filename):
        '''Locate a Python file by name and execute its contents.'''
        f = None
        try:
            f = open(filename)
        except IOError:
            for root_f, folders, files in os.walk('.'):
                if filename in files:
                    f = open(root_f + '/' + filename)
                    break

        if f is None:
            self.logger.error('File %s not found!', filename)
            raise Exception('File %s not found!' % filename)

        self.logger.info('Running %s', filename)
        exec(f.read())
        f.close()

    def execute_sql(self, filename):
        f = None
        try:
            f = open(filename)
        except:
            for root_f, folders, files in os.walk('.'):
                if filename in files:
                    f = open(root_f + '/' + filename)
        if f is None:
            self.logger.error('File %s not found!', filename)
            raise Exception('File not found!')

        sql = f.read()
        reconnect = 0
        while True:
            try:
                self.db.query(sql)
                self.db.commit()
                return
            except ProgrammingError as pe:
                self.logger.error('Error in SQL', exc_info=True)
                self.db_connect()
                reconnect += 1
            if reconnect > 5:
                raise Exception('Check DB connection. Cannot connect')

    def get_sql_results(self,
                        filename,
                        columns,
                        replace_columns=None,
                        parameters=None):
        '''
        Input:
            filename: path to a .sql file, or a SELECT statement passed directly as a string
            columns: a list of column names for the returned dataframe
            replace_columns: a dictionary of {placeholder: replacement string} substituted into the SQL
            parameters: list of parameter values passed to the query
        Output:
            dataframe of results

        A brief usage sketch follows this class.
        '''

        f = None
        try:
            f = open(filename)
        except:
            for root_f, folders, files in os.walk('.'):
                if filename in files:
                    f = open(root_f + '/' + filename)

        if f is None:
            # Also accepts SQL queries passed directly in string form
            if filename[:6] == 'SELECT':
                sql = filename
            else:
                self.logger.error('File %s not found!', filename)
                raise Exception('File not found!')
        else:
            sql = f.read()

        if replace_columns is not None:
            for key, value in replace_columns.items():
                sql = sql.replace(key, str(value))

        reconnect = 0
        while True:
            try:
                if parameters is not None:
                    return pd.DataFrame(self.db.query(sql,
                                                      parameters).getresult(),
                                        columns=columns)
                else:
                    return pd.DataFrame(self.db.query(sql).getresult(),
                                        columns=columns)
            except ProgrammingError as pe:
                self.logger.error('Error in SQL', exc_info=True)
                self.db_connect()
                reconnect += 1
            if reconnect > 5:
                raise Exception('Check Error Message')

    def load_pkl(self, filename):
        f = None
        try:
            f = open(filename, "rb")
        except:
            for root_f, folders, files in os.walk('.'):
                if filename in files:
                    f = open(root_f + '/' + filename, "rb")
        if f is None:
            self.logger.error('File %s not found!', filename)
            raise Exception('File not found!')

        return pickle.load(f)

    def truncatetable(self, tablename):
        reconnect = 0
        while True:
            try:
                self.db.truncate(tablename)
                self.db.commit()
                self.logger.info('%s truncated', tablename)
                return
            except ProgrammingError as pe:
                print(pe)
                self.db_connect()
                reconnect += 1
            if reconnect > 5:
                self.logger.error('Error in SQL', exc_info=True)
                raise Exception('Check Error Message')

    def inserttable(self, tablename, content):
        reconnect = 0
        while True:
            try:
                self.db.inserttable(tablename, content)
                self.db.commit()
                self.logger.info('Inserted table: %s', tablename)
                break
            except ProgrammingError:
                self.db_connect()
                reconnect += 1
            if reconnect > 5:
                self.logger.error('Error in SQL', exc_info=True)
                raise Exception('Check Error Message')

    def __exit__(self):
        self.db.close()
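
# A minimal usage sketch for the vol_utils class above (not part of the original
# snippet). It assumes a db.cfg file with a [DBSETTINGS] section is present, as
# db_connect() expects; the SELECT text and column names are placeholders chosen
# purely for illustration.
if __name__ == '__main__':
    utils = vol_utils()
    # get_sql_results() accepts either a .sql filename or a SELECT string directly
    df = utils.get_sql_results(
        'SELECT centreline_id, count(*) FROM prj_volume.centreline_groups GROUP BY centreline_id',
        columns=['centreline_id', 'n'])
    print(df.head())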
示例#20
0
MS = DBManage.DBManage(host="192.168.1.200",
                       port="1433",
                       user="******",
                       password="******",
                       database="VIM_VCDB",
                       charset="UTF-8")
PG = DB(dbname='VM', host='localhost', port=5432, user='******', passwd='123456')
MS.getMSSConnetion()

sql1 = "SELECT [ID] ,[STAT_ROLLUP] ,[NAME]  ,[GROUP_NAME]  ,[TYPE]  ,[UNIT]  ,[ASSOCIATE_IDS]   ,[STATS_LEVEL]    ,[FIXED_COLLECTION_INTERVAL]  FROM [VIM_VCDB].[dbo].[VPXV_STAT_COUNTERS]"

cursor1 = MS.QuerySql(sql1)
Tuple = cursor1.fetchone()
while Tuple:
    PG.insert('stat_counters',
              id=Tuple[0],
              stat_rollup=Tuple[1],
              name=Tuple[2],
              group_name=Tuple[3],
              type=Tuple[4],
              unit=Tuple[5],
              associate_ids=Tuple[6],
              stats_level=Tuple[7],
              fixed_collection_interval=Tuple[8])
    Tuple = cursor1.fetchone()

print("成功!")
MS.closeConn()
PG.close()
#!/usr/bin/env python

import sys
import yaml
import json
from pg import DB

bosh_manifest_path = sys.argv[1]
with open(bosh_manifest_path, 'r') as f:
    bosh_manifest = yaml.safe_load(f.read())

settings_path = sys.argv[2]
with open(settings_path, 'r') as f:
    cf_ip = json.loads(f.read())['cf-ip']

postgres_properties = bosh_manifest['jobs'][0]['properties']['postgres']
dbname = postgres_properties.get('database')
host = postgres_properties.get('host')
port = postgres_properties.get('port', 5432)
user = postgres_properties.get('user')
passwd = postgres_properties.get('password')

db = DB(dbname=dbname, host=host, port=port, user=user, passwd=passwd)

domain_id = db.insert('domains', name='xip.io', type='NATIVE')['id']
db.insert('records', domain_id=domain_id, name='{0}.xip.io'.format(cf_ip), content='localhost [email protected] 1', type='SOA', ttl=86400, prio=None)
db.insert('records', domain_id=domain_id, name='{0}.xip.io'.format(cf_ip), content='dns-us1.powerdns.net', type='NS', ttl=86400, prio=None)
db.insert('records', domain_id=domain_id, name='*.{0}.xip.io'.format(cf_ip), content=cf_ip, type='A', ttl=120, prio=None)

db.close()
示例#22
0
File: sassy_bot.py  Project: pndaly/SASSy
def sassy_bot_read(_radius=RADIUS,
                   _begin=BEGIN_ISO,
                   _end=END_ISO,
                   _rb_min=RB_MIN,
                   _rb_max=RB_MAX,
                   _logger=None):

    # check input(s)
    _radius = _radius / 3600.0 if (isinstance(_radius, float)
                                   and 0.0 <= _radius) else RADIUS / 3600.0
    _begin_iso = _begin if (
        re.match(DATE_RULE, _begin) is not None
        and not math.isnan(iso_to_jd(_begin))) else BEGIN_ISO
    _end_iso = _end if (re.match(DATE_RULE, _end) is not None
                        and not math.isnan(iso_to_jd(_end))) else END_ISO
    _rb_min = _rb_min if (isinstance(_rb_min, float)
                          and 0.0 <= _rb_min <= 1.0) else RB_MIN
    _rb_max = _rb_max if (isinstance(_rb_max, float)
                          and 0.0 <= _rb_max <= 1.0) else RB_MAX

    # entry
    if _logger:
        _logger.info(f"_radius = {_radius}")
        _logger.info(f"_begin_iso = {_begin_iso}")
        _logger.info(f"_end_iso = {_end_iso}")
        _logger.info(f"_rb_min = {_rb_min}")
        _logger.info(f"_rb_max = {_rb_max}")

    # set default(s)
    _alerce = Alerce(_logger)
    _res = None
    _results = []
    _begin_jd = iso_to_jd(_begin_iso)
    _end_jd = iso_to_jd(_end_iso)
    if _logger:
        _logger.info(f"_begin_jd = {_begin_jd}")
        _logger.info(f"_end_jd = {_end_jd}")

    # connect to database
    if _logger:
        _logger.info(f"Connecting to database")
    try:
        db = DB(dbname=SASSY_DB_NAME,
                host=SASSY_DB_HOST,
                port=int(SASSY_DB_PORT),
                user=SASSY_DB_USER,
                passwd=SASSY_DB_PASS)
    except Exception as e:
        if _logger:
            _logger.error(f"Failed connecting to database, error={e}")
        return
    if _logger:
        _logger.info(f"Connected to database OK")

    # drop any existing view
    _cmd_drop = 'DROP VIEW IF EXISTS sassy_bot;'
    if _logger:
        _logger.info(f'Executing {_cmd_drop}')
    try:
        db.query(_cmd_drop)
    except Exception as e:
        if _logger:
            _logger.error(f'Failed to execute {_cmd_drop}, e={e}')
        if db is not None:
            db.close()
        return
    if _logger:
        _logger.info(f'Executed {_cmd_drop} OK')

    # create new view
    _cmd_view = f'CREATE OR REPLACE VIEW sassy_bot ("objectId", jd, drb, rb, sid, candid, ssnamenr, ra, dec) ' \
                f'AS WITH e AS (SELECT "objectId", jd, rb, drb, id, candid, ssnamenr, ' \
                f'(CASE WHEN ST_X(ST_AsText(location)) < 0.0 THEN ST_X(ST_AsText(location))+360.0 ELSE ' \
                f'ST_X(ST_AsText(location)) END), ST_Y(ST_AsText(location)) FROM alert WHERE ' \
                f'(("objectId" LIKE \'%ZTF2%\') AND (jd BETWEEN {_begin_jd} AND {_end_jd}) AND ' \
                f'((rb BETWEEN {_rb_min} AND {_rb_max}) OR (drb BETWEEN {_rb_min} AND {_rb_max})))) SELECT * FROM e;'
    if _logger:
        _logger.info(f'Executing {_cmd_view}')
    try:
        db.query(_cmd_view)
    except Exception as e:
        if _logger:
            _logger.error(f'Failed to execute {_cmd_view}, e={e}')
    if _logger:
        _logger.info(f'Executed {_cmd_view} OK')

    # select
    _cmd_select = f"WITH x AS (SELECT * FROM sassy_bot), y AS (SELECT x.*, " \
                  f"(g.id, g.ra, g.dec, g.z, g.dist, q3c_dist(x.ra, x.dec, g.ra, g.dec)) " \
                  f"FROM x, glade_q3c AS g WHERE q3c_join(x.ra, x.dec, g.ra, g.dec, {_radius:.5f})), z AS " \
                  f"(SELECT * FROM y LEFT OUTER JOIN tns_q3c AS t ON " \
                  f"q3c_join(y.ra, y.dec, t.ra, t.dec, {_radius:.5f})) SELECT * FROM z WHERE tns_id IS null;"
    if _logger:
        _logger.info(f'Executing {_cmd_select}')
    try:
        _res = db.query(_cmd_select)
    except Exception as e:
        if _logger:
            _logger.error(f'Failed to execute {_cmd_select}, e={e}')
    if _logger:
        _logger.info(f'Executed {_cmd_select} OK')

    # close and exit
    if db is not None:
        db.close()

    # create output(s)
    if _res is None:
        return _results
    for _e in _res:
        if _logger:
            _logger.info(f'_e={_e}')
        _gid, _gra, _gdec, _gz, _gdist, _gdelta = _e[9][1:-1].split(",")
        _d = {"objectId": f"{_e[0]}"}

        try:
            _classifier = _alerce.get_classifier(oid=_d['objectId'],
                                                 classifier='early')
            _d['early_classifier'] = _classifier[1]
            _d['early_percent'] = _classifier[2] * 100.0
        except Exception:
            _d['early_classifier'] = 'n/a'
            _d['early_percent'] = math.nan

        try:
            _classifier = _alerce.get_classifier(oid=_d['objectId'],
                                                 classifier='late')
            _d['late_classifier'] = _classifier[1]
            _d['late_percent'] = _classifier[2] * 100.0
        except Exception:
            _d['late_classifier'] = 'n/a'
            _d['late_percent'] = math.nan

        try:
            _d["jd"] = float(f"{_e[1]}")
        except Exception:
            _d["jd"] = float(math.nan)
        try:
            _d["drb"] = float(f"{_e[2]}")
        except Exception:
            _d["drb"] = float(math.nan)
        try:
            _d["rb"] = float(f"{_e[3]}")
        except Exception:
            _d["rb"] = float(math.nan)
        try:
            _d["sid"] = int(f"{_e[4]}")
        except Exception:
            _d["sid"] = -1
        try:
            _d["candid"] = int(f"{_e[5]}")
        except Exception:
            _d["candid"] = -1
        try:
            _d["ssnamenr"] = '' if f"{_e[6]}".lower() == 'null' else f"{_e[6]}"
        except Exception:
            _d["ssnamenr"] = ''
        try:
            _d["RA"] = float(f"{_e[7]}")
        except Exception:
            _d["RA"] = float(math.nan)
        try:
            _d["Dec"] = float(f"{_e[8]}")
        except Exception:
            _d["Dec"] = float(math.nan)
        try:
            _d["gid"] = int(f"{_gid}")
        except Exception:
            _d["gid"] = -1
        try:
            _d["gRA"] = float(f"{_gra}")
        except Exception:
            _d["gRA"] = float(math.nan)
        try:
            _d["gDec"] = float(f"{_gdec}")
        except Exception:
            _d["gDec"] = float(math.nan)
        try:
            _d["gDist"] = float(f"{_gdist}")
        except Exception:
            _d["gDist"] = float(math.nan)
        try:
            _d["gRedshift"] = float(f"{_gz}")
        except Exception:
            _d["gRedshift"] = float(math.nan)
        try:
            _d["gDelta"] = float(f"{_gdelta}") * 3600.0
        except Exception:
            _d["gDelta"] = float(math.nan)

        try:
            _d["file"] = get_avro_filename(_d["jd"], _d["candid"])
        except Exception:
            _d["file"] = ""

        try:
            _d["png"] = avro_plot(_d["file"], True)[0]
        except Exception:
            _d["png"] = ""

        if _logger:
            _logger.info(f'_d={_d}')

        try:
            if _d["ssnamenr"] != '':
                if _logger:
                    _logger.debug(f'ignoring solar system object, _d={_d}')
            else:
                _results.append(_d)
        except Exception:
            continue

    # return
    return _results
示例#23
0
            root.remove(current)
            to_visit.extend(list(pairs.groupby('c1').get_group(current)['c2']))
            visited.append(current)

    chains.append(chain)

groups = {}
count = 1
table = []
for group in chains:
    for tcl in group:
        table.append([tcl, count])
    count = count + 1

db.truncate('prj_volume.centreline_groups_l2')
db.inserttable('prj_volume.centreline_groups_l2', table)

group_no_merge = [
    x for t in db.query(
        'SELECT DISTINCT group_number FROM prj_volume.centreline_groups LEFT JOIN prj_volume.centreline_groups_l2 ON (group_number=l1_group_number) WHERE l2_group_number IS NULL'
    ).getresult() for x in t
]

for tcl in group_no_merge:
    table.append([tcl, count])
    count = count + 1

db.truncate('prj_volume.centreline_groups_l2')
db.inserttable('prj_volume.centreline_groups_l2', table)
db.close()
示例#24
0
class pgcon(object):
    def __init__(self):
        self.pgsql = {}
        self.result = None
        self.status = "Unknown"
        self.db = None

    def query_consul(self):
        '''
        First try to get the Consul coordinates from the environment; these
        should be set by Docker through container linking. Then fetch the
        PostgreSQL container coordinates from Consul. (A brief usage sketch
        follows this class.)
        '''

        try:
            g.consul_server = os.environ['CONSUL_PORT_8500_TCP_ADDR']
            g.consul_port = os.environ['CONSUL_PORT_8500_TCP_PORT']
        except:
            if app.debug:
                g.consul_server = '172.17.0.2'
                g.consul_port = '8500'
            else:
                raise EnvironmentError(
                    'No consul environment variables available')

        try:
            c = consul.Consul(host=g.consul_server, port=g.consul_port)
            cresponse = c.kv.get('postgresql', recurse=True)[1]
        except:
            raise LookupError('Error in connecting to the Consul server')

        try:
            for d in cresponse:
                v = d['Value']
                k = d['Key'].split('/')[-1]
                self.pgsql[k] = v
        except:
            raise AttributeError('Something is wrong with Consuls response')

    def connect(self):
        if not self.pgsql:
            raise ValueError('No coordinates to connect to the db.')
        try:
            self.db = DB(dbname=self.pgsql['user'],
                         host=self.pgsql['host'],
                         port=int(self.pgsql['port']),
                         user=self.pgsql['user'],
                         passwd=self.pgsql['password'])
        except:
            raise IOError('Could not connect to the db.')
        self.status = "Connected"

    def query(self, p=None):
        '''
        connect to the db and retrieve something
        '''

        if not self.db:
            self.connect()

        if not self.db:
            self.result = {'postgresql db: ': 'not connected'}
            return

        if p:
            self.result = self.db.query(p)

    def call(self, m):
        if not self.db:
            self.connect()
        self.result = getattr(self.db, m)()

    def disconnect(self):
        try:
            self.db.close()
        except:
            pass
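
# A minimal usage sketch for the pgcon helper above (added for illustration, not
# part of the original snippet). It presumes the Consul/Docker environment that
# query_consul() describes is reachable; the query text is a placeholder.
def example_pgcon_usage():
    con = pgcon()
    con.query_consul()   # discover PostgreSQL coordinates via Consul
    con.connect()        # open the PyGreSQL connection
    con.query('SELECT version();')
    if con.result is not None:
        print(con.status, con.result.getresult())
    con.disconnect()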
示例#25
0
class SteadyPgConnection:
	"""Class representing steady connections to a PostgreSQL database.

	Underlying the connection is a classic PyGreSQL pg API database
	connection which is reset if the connection is lost or used too often.
	Thus the resulting connection is steadier ("tough and self-healing").

	If you want the connection to be persistent in a threaded environment,
	then you should not deal with this class directly, but use either the
	PooledPg module or the PersistentPg module to get the connections.
	A brief usage sketch follows the class definition.

	"""

	def __init__(self, maxusage=0, setsession=None, *args, **kwargs):
		"""Create a "tough" PostgreSQL connection.

		maxusage: maximum usage limit for the underlying PygreSQL connection
			(number of uses, 0 or False means unlimited usage)
			When this limit is reached, the connection is automatically reset.
		setsession: optional list of SQL commands that may serve to prepare
			the session, e.g. ["set datestyle to ...", "set time zone ..."]
		args, kwargs: the parameters that shall be used to establish
			the PostgreSQL connections with PyGreSQL using pg.DB()

		"""
		self._maxusage = maxusage
		self._setsession_sql = setsession
		self._closeable = 1
		self._usage = 0
		self._con = PgConnection(*args, **kwargs)
		self._setsession()

	def _setsession(self):
		"""Execute the SQL commands for session preparation."""
		if self._setsession_sql:
			for sql in self._setsession_sql:
				self._con.query(sql)

	def _close(self):
		"""Close the tough connection.

		You can always close a tough connection with this method
		and it will not complain if you close it more than once.

		"""
		try:
			self._con.close()
			self._usage = 0
		except:
			pass

	def close(self):
		"""Close the tough connection.

		You are allowed to close a tough connection by default
		and it will not complain if you close it more than once.

		You can disallow closing connections by setting
		the _closeable attribute to 0 or False. In this case,
		closing a connection will be silently ignored.

		"""
		if self._closeable:
			self._close()

	def reopen(self):
		"""Reopen the tough connection.

		It will not complain if the connection cannot be reopened."""
		try:
			self._con.reopen()
			self._setsession()
			self._usage = 0
		except:
			pass

	def reset(self):
		"""Reset the tough connection.

		If a reset is not possible, tries to reopen the connection.
		It will not complain if the connection is already closed.

		"""
		try:
			self._con.reset()
			self._setsession()
			self._usage = 0
		except:
			self.reopen()

	def _get_tough_method(self, method):
		"""Return a "tough" version of a connection class method.

		The tough version checks whether the connection is bad (lost)
		and automatically and transparently tries to reset the connection
		if this is the case (for instance, the database has been restarted).

		"""
		def tough_method(*args, **kwargs):
			try: # check whether connection status is bad
				if not self._con.db.status:
					raise AttributeError
				if self._maxusage: # or connection used too often
					if self._usage >= self._maxusage:
						raise AttributeError
			except:
				self.reset() # then reset the connection
			try:
				r = method(*args, **kwargs) # try connection method
			except: # error in query
				if self._con.db.status: # if it was not a connection problem
					raise # then propagate the error
				else: # otherwise
					self.reset() # reset the connection
					r = method(*args, **kwargs) # and try one more time
			self._usage += 1
			return r
		return tough_method

	def __getattr__(self, name):
		"""Inherit the members of the standard connection class.

		Some methods are made "tougher" than in the standard version.

		"""
		attr = getattr(self._con, name)
		if name in ('query', 'get', 'insert', 'update', 'delete') \
			or name.startswith('get_'):
			attr = self._get_tough_method(attr)
		return attr
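
# A brief usage sketch for SteadyPgConnection (added for illustration, not part of
# the original module). The connection parameters and the setsession command are
# placeholders; in practice this class is usually obtained through the PooledPg or
# PersistentPg modules rather than instantiated directly.
if __name__ == '__main__':
	con = SteadyPgConnection(
		maxusage=1000,                        # reset the connection after 1000 uses
		setsession=['set datestyle to ISO'],  # run on every (re)connect
		dbname='VM', host='localhost', user='postgres', passwd='123456')
	# query() is wrapped by _get_tough_method(), so a lost connection is
	# transparently reset and the query retried once.
	print(con.query('SELECT 1').getresult())
	con.close()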