def _open(self):
    """
    DO NOT USE THIS UNLESS YOU close() FIRST

    CONNECT TO THE DATABASE DESCRIBED BY self.settings.  settings.host MAY BE
    A URI OF THE FORM mysql://username:password@host:optional_port/database_name
    WHOSE PARTS OVERRIDE THE MATCHING settings.  CREDENTIALS MAY BE
    PERCENT-ENCODED (eg A ':' IN THE PASSWORD WRITTEN AS %3A).
    NON-LOCAL HOSTS REQUIRE settings.ssl.pem, A URL OF THE CA FILE TO USE.
    """
    if self.settings.host.startswith("mysql://"):
        # DECODE THE URI: mysql://username:password@host:optional_port/database_name
        remainder = self.settings.host[len("mysql://"):]
        # THE LAST "@" SEPARATES CREDENTIALS FROM LOCATION; NO "@" MEANS NO CREDENTIALS
        credentials, _, location = remainder.rpartition("@")
        if credentials:
            if ":" in credentials:
                # SPLIT BEFORE UNQUOTING SO ENCODED ':' IN THE PASSWORD SURVIVE
                username, password = credentials.split(":", 1)
                self.settings.username = unquote(username)
                self.settings.password = unquote(password)
            else:
                self.settings.username = unquote(credentials)

        host_port, _, schema = location.partition("/")
        if schema:
            self.settings.schema = schema
        if ":" in host_port:
            host, port = host_port.split(":", 1)
            self.settings.host = host
            self.settings.port = int(port)
        else:
            self.settings.host = host_port

    # SSL PEM
    if self.settings.host in ("localhost", "mysql", '127.0.0.1'):
        ssl_context = None
    else:
        if not self.settings.ssl.pem:
            # FAIL CLEARLY INSTEAD OF CRASHING WHILE FETCHING A MISSING URL
            Log.error("Expecting 'pem' property in ssl")
        filename = File(".pem") / URL(self.settings.ssl.pem).host
        filename.write_bytes(http.get(self.settings.ssl.pem).content)
        ssl_context = {"ca": filename.abspath}

    try:
        self.db = connect(
            host=self.settings.host,
            port=self.settings.port,
            user=coalesce(self.settings.username, self.settings.user),
            passwd=coalesce(self.settings.password, self.settings.passwd),
            db=coalesce(self.settings.schema, self.settings.db),
            read_timeout=coalesce(
                self.settings.read_timeout,
                (EXECUTE_TIMEOUT / 1000) - 10 if EXECUTE_TIMEOUT else None,
                5 * 60,
            ),
            charset=u"utf8",
            use_unicode=True,
            ssl=ssl_context,
            cursorclass=cursors.SSCursor,
        )
    except Exception as e:
        if self.settings.host.find("://") == -1:
            Log.error(
                u"Failure to connect to {{host}}:{{port}}",
                host=self.settings.host,
                port=self.settings.port,
                cause=e,
            )
        else:
            Log.error(u"Failure to connect. PROTOCOL PREFIX IS PROBABLY BAD", e)

    self.cursor = None
    self.partial_rollback = False
    self.transaction_level = 0
    self.backlog = []  # accumulate the write commands so they are sent at once
    if self.readonly:
        self.begin()
class MySQL(object):
    """
    Parameterize SQL by name rather than by position.  Return records as objects
    rather than tuples.
    """

    @override
    def __init__(
        self,
        host,
        username,
        password,
        port=3306,
        debug=False,
        schema=None,
        preamble=None,
        readonly=False,
        kwargs=None
    ):
        """
        OVERRIDE THE settings.schema WITH THE schema PARAMETER
        preamble WILL BE USED TO ADD COMMENTS TO THE BEGINNING OF ALL SQL
        THE INTENT IS TO HELP ADMINISTRATORS ID THE SQL RUNNING ON THE DATABASE

        schema - NAME OF DEFAULT database/schema IN QUERIES

        preamble - A COMMENT TO BE ADDED TO EVERY SQL STATEMENT SENT

        readonly - USED ONLY TO INDICATE IF A TRANSACTION WILL BE OPENED UPON
        USE IN with CLAUSE, YOU CAN STILL SEND UPDATES, BUT MUST OPEN A
        TRANSACTION BEFORE YOU DO
        """
        all_db.append(self)

        self.settings = kwargs
        self.cursor = None
        self.query_cursor = None
        # TRANSACTION STATE; _open() RESETS THESE, BUT THEY MUST EXIST EVEN
        # WHEN NO host IS GIVEN AND _open() IS NOT CALLED
        self.partial_rollback = False
        self.transaction_level = 0
        self.backlog = []
        if preamble == None:
            self.preamble = ""
        else:
            self.preamble = indent(preamble, "# ").strip() + "\n"

        self.readonly = readonly
        self.debug = coalesce(debug, DEBUG)
        if host:
            self._open()

    def _open(self):
        """ DO NOT USE THIS UNLESS YOU close() FIRST"""
        ssl_ca = self.settings.ssl.ca
        if ssl_ca and ssl_ca.startswith("https://"):
            # DOWNLOAD THE REMOTE CA FILE AND POINT THE ssl SETTINGS AT THE LOCAL COPY
            self.pemfile_url = ssl_ca
            self.pemfile = File("./resources/pem") / self.settings.host
            self.pemfile.write_bytes(http.get(self.pemfile_url).content)
            self.settings.ssl.ca = self.pemfile.abspath

        try:
            self.db = connect(
                host=self.settings.host,
                port=self.settings.port,
                user=coalesce(self.settings.username, self.settings.user),
                passwd=coalesce(self.settings.password, self.settings.passwd),
                db=coalesce(self.settings.schema, self.settings.db),
                read_timeout=coalesce(
                    self.settings.read_timeout,
                    (EXECUTE_TIMEOUT / 1000) - 10 if EXECUTE_TIMEOUT else None,
                    5 * 60,
                ),
                charset=u"utf8",
                use_unicode=True,
                ssl=coalesce(self.settings.ssl, None),
                cursorclass=cursors.SSCursor,
            )
        except Exception as e:
            if self.settings.host.find("://") == -1:
                Log.error(
                    u"Failure to connect to {{host}}:{{port}}",
                    host=self.settings.host,
                    port=self.settings.port,
                    cause=e,
                )
            else:
                Log.error(u"Failure to connect. PROTOCOL PREFIX IS PROBABLY BAD", e)
        self.cursor = None
        self.partial_rollback = False
        self.transaction_level = 0
        self.backlog = []  # accumulate the write commands so they are sent at once
        if self.readonly:
            self.begin()

    def __enter__(self):
        # READONLY CONNECTIONS ALREADY OPENED THEIR TRANSACTION IN _open()
        if not self.readonly:
            self.begin()
        return self

    def __exit__(self, type, value, traceback):
        """
        READONLY: JUST close()
        ON EXCEPTION: rollback() THEN close()
        OTHERWISE: commit() THEN close()
        """
        if self.readonly:
            self.close()
            return

        if isinstance(value, BaseException):
            try:
                if self.cursor:
                    self.cursor.close()
                self.cursor = None
                self.rollback()
            except Exception as e:
                Log.warning(u"can not rollback()", cause=[value, e])
            finally:
                self.close()
            return

        try:
            self.commit()
        except Exception as e:
            Log.warning(u"can not commit()", e)
        finally:
            self.close()

    def transaction(self):
        """
        return not-started transaction (for with statement)
        """
        return Transaction(self)

    def begin(self):
        """
        START A (POSSIBLY NESTED) TRANSACTION; ONLY THE OUTERMOST ONE OPENS A CURSOR
        """
        if self.transaction_level == 0:
            self.cursor = self.db.cursor()
        self.transaction_level += 1
        self.execute("SET TIME_ZONE='+00:00'")
        if EXECUTE_TIMEOUT:
            try:
                self.execute("SET MAX_EXECUTION_TIME=" + text(EXECUTE_TIMEOUT))
                self._execute_backlog()
            except Exception as e:
                e = Except.wrap(e)
                if "Unknown system variable 'MAX_EXECUTION_TIME'" in e:
                    globals()['EXECUTE_TIMEOUT'] = 0  # THIS VERSION OF MYSQL DOES NOT HAVE SESSION LEVEL VARIABLE
                else:
                    raise e

    def close(self):
        """
        CLOSE THE CONNECTION AND REMOVE self FROM all_db
        A READONLY CONNECTION AUTO-COMMITS; OTHERWISE AN OPEN TRANSACTION IS AN ERROR
        """
        if self.transaction_level > 0:
            if self.readonly:
                self.commit()  # AUTO-COMMIT
            else:
                Log.error("expecting commit() or rollback() before close")
        self.cursor = None  # NOT NEEDED
        try:
            self.db.close()
        except Exception as e:
            e = Except.wrap(e)
            if "Already closed" in e:
                return

            Log.warning("can not close()", e)
        finally:
            try:
                all_db.remove(self)
            except Exception as e:
                Log.error("not expected", cause=e)

    def commit(self):
        """
        SEND ANY PENDING WRITES, THEN END ONE LEVEL OF TRANSACTION
        ONLY THE OUTERMOST LEVEL ACTUALLY COMMITS TO THE DATABASE
        """
        try:
            self._execute_backlog()
        except Exception as e:
            with suppress_exception:
                self.rollback()
            Log.error("Error while processing backlog", e)

        if self.transaction_level == 0:
            Log.error("No transaction has begun")
        elif self.transaction_level == 1:
            if self.partial_rollback:
                with suppress_exception:
                    self.rollback()

                Log.error("Commit after nested rollback is not allowed")
            else:
                if self.cursor:
                    self.cursor.close()
                self.cursor = None
                self.db.commit()

        self.transaction_level -= 1

    def flush(self):
        """
        COMMIT WHAT WE HAVE, THEN IMMEDIATELY START A NEW TRANSACTION
        """
        try:
            self.commit()
        except Exception as e:
            Log.error("Can not flush", e)

        try:
            self.begin()
        except Exception as e:
            Log.error("Can not flush", e)

    def rollback(self):
        """
        DISCARD PENDING WRITES AND END ONE LEVEL OF TRANSACTION
        NESTED LEVELS CAN NOT BE ROLLED BACK ALONE; THEY ONLY MARK THE OUTER
        TRANSACTION SO IT CAN NOT BE COMMITTED
        """
        self.backlog = []  # YAY! FREE!
        if self.transaction_level == 0:
            Log.error("No transaction has begun")
        elif self.transaction_level == 1:
            self.transaction_level -= 1
            # THE WHOLE TRANSACTION IS GONE, SO AN EARLIER NESTED ROLLBACK NO
            # LONGER MATTERS; CLEAR THE FLAG SO FUTURE TRANSACTIONS CAN COMMIT
            self.partial_rollback = False
            if self.cursor != None:
                self.cursor.close()
            self.cursor = None
            self.db.rollback()
        else:
            self.transaction_level -= 1
            self.partial_rollback = True
            Log.warning("Can not perform partial rollback!")

    def call(self, proc_name, params):
        """
        CALL STORED PROCEDURE proc_name WITH params, THEN RENEW THE CURSOR
        """
        self._execute_backlog()
        params = [unwrap(v) for v in params]
        try:
            self.cursor.callproc(proc_name, params)
            self.cursor.close()
            self.cursor = self.db.cursor()
        except Exception as e:
            Log.error("Problem calling procedure " + proc_name, e)

    def query(self, sql, param=None, stream=False, row_tuples=False):
        """
        RETURN LIST OF dicts

        stream - RETURN A GENERATOR (OR THE RAW CURSOR WHEN row_tuples) INSTEAD OF A LIST
        row_tuples - RETURN ROWS AS TUPLES INSTEAD OF dicts
        """
        if not self.cursor:
            # QUERIES NEED THE CURSOR OF AN OPEN TRANSACTION
            Log.error("must perform all queries inside a transaction")
        self._execute_backlog()

        try:
            if param:
                sql = expand_template(sql, quote_param(param))
            sql = self.preamble + outdent(sql)
            self.debug and Log.note("Execute SQL:\n{{sql}}", sql=indent(sql))

            self.cursor.execute(sql)
            if row_tuples:
                if stream:
                    result = self.cursor
                else:
                    result = wrap(list(self.cursor))
            else:
                columns = [
                    utf8_to_unicode(d[0])
                    for d in coalesce(self.cursor.description, [])
                ]
                if stream:
                    result = (
                        wrap({c: utf8_to_unicode(v) for c, v in zip(columns, row)})
                        for row in self.cursor
                    )
                else:
                    result = wrap([
                        {c: utf8_to_unicode(v) for c, v in zip(columns, row)}
                        for row in self.cursor
                    ])

            return result
        except Exception as e:
            e = Except.wrap(e)
            if "InterfaceError" in e:
                Log.error("Did you close the db connection?", e)
            Log.error("Problem executing SQL:\n{{sql|indent}}", sql=sql, cause=e, stack_depth=1)

    def _discard_cursor(self):
        """CLOSE AND FORGET self.cursor, IGNORING ERRORS (FOR CLEANUP AFTER A FAILURE)"""
        cursor, self.cursor = self.cursor, None
        if cursor:
            with suppress_exception:
                cursor.close()

    def column_query(self, sql, param=None):
        """
        RETURN RESULTS IN [column][row_num] GRID

        IF NO TRANSACTION IS OPEN, A TEMPORARY CURSOR IS USED AND CLOSED AGAIN
        """
        self._execute_backlog()
        old_cursor = self.cursor
        try:
            if not old_cursor:  # ALLOW NON-TRANSACTIONAL READS
                self.cursor = self.db.cursor()
                self.cursor.execute("SET TIME_ZONE='+00:00'")
                self.cursor.close()
                self.cursor = self.db.cursor()

            if param:
                sql = expand_template(sql, quote_param(param))
            sql = self.preamble + outdent(sql)
            self.debug and Log.note("Execute SQL:\n{{sql}}", sql=indent(sql))

            self.cursor.execute(sql)
            grid = [[utf8_to_unicode(c) for c in row] for row in self.cursor]
            result = transpose(*grid)

            if not old_cursor:  # CLEANUP AFTER NON-TRANSACTIONAL READS
                self.cursor.close()
                self.cursor = None

            return result
        except Exception as e:
            if not old_cursor:
                # A LEFTOVER TEMPORARY CURSOR WOULD LOOK LIKE AN OPEN TRANSACTION
                self._discard_cursor()
            # Exception.message IS PYTHON-2 ONLY; MATCH ON THE WRAPPED EXCEPTION INSTEAD
            if isinstance(e, InterfaceError) or "InterfaceError" in Except.wrap(e):
                Log.error("Did you close the db connection?", e)
            Log.error("Problem executing SQL:\n{{sql|indent}}", sql=sql, cause=e, stack_depth=1)

    def forall(self, sql, param=None, _execute=None):
        """
        EXECUTE GIVEN METHOD FOR ALL ROWS RETURNED

        _execute - CALLED WITH EACH ROW (AS A WRAPPED dict)
        RETURN THE NUMBER OF ROWS PROCESSED
        """
        assert _execute
        num = 0

        self._execute_backlog()
        old_cursor = self.cursor
        try:
            if not old_cursor:  # ALLOW NON-TRANSACTIONAL READS
                self.cursor = self.db.cursor()

            if param:
                sql = expand_template(sql, quote_param(param))
            sql = self.preamble + outdent(sql)
            self.debug and Log.note("Execute SQL:\n{{sql}}", sql=indent(sql))
            self.cursor.execute(sql)

            columns = tuple([utf8_to_unicode(d[0]) for d in self.cursor.description])
            for r in self.cursor:
                num += 1
                _execute(wrap(dict(zip(columns, [utf8_to_unicode(c) for c in r]))))

            if not old_cursor:  # CLEANUP AFTER NON-TRANSACTIONAL READS
                self.cursor.close()
                self.cursor = None

        except Exception as e:
            if not old_cursor:
                # A LEFTOVER TEMPORARY CURSOR WOULD LOOK LIKE AN OPEN TRANSACTION
                self._discard_cursor()
            Log.error("Problem executing SQL:\n{{sql|indent}}", sql=sql, cause=e, stack_depth=1)

        return num

    def execute(self, sql, param=None):
        """
        QUEUE A WRITE COMMAND; THE BACKLOG IS SENT WHEN FULL (OR ALWAYS, IN debug)
        """
        if self.transaction_level == 0:
            Log.error("Expecting transaction to be started before issuing queries")

        if param:
            sql = expand_template(text(sql), quote_param(param))
        sql = outdent(sql)
        self.backlog.append(sql)
        if self.debug or len(self.backlog) >= MAX_BATCH_SIZE:
            self._execute_backlog()

    def _execute_backlog(self):
        """
        SEND THE QUEUED COMMANDS IN BLOCKS OF AT MOST MAX_BATCH_SIZE STATEMENTS
        """
        if not self.backlog:
            return

        backlog, self.backlog = self.backlog, []
        for _, g in jx.chunk(backlog, size=MAX_BATCH_SIZE):
            sql = self.preamble + ";\n".join(g)
            try:
                self.debug and Log.note("Execute block of SQL:\n{{sql|indent}}", sql=sql)
                self.cursor.execute(sql)
                self.cursor.close()
                self.cursor = self.db.cursor()
            except Exception as e:
                Log.error("Problem executing SQL:\n{{sql|indent}}", sql=sql, cause=e, stack_depth=1)

    def insert(self, table_name, record):
        """
        Insert dictionary of values into table
        """
        keys = list(record.keys())

        try:
            command = (
                SQL_INSERT + quote_column(table_name) +
                sql_iso(sql_list([quote_column(k) for k in keys])) +
                SQL_VALUES +
                sql_iso(sql_list([quote_value(record[k]) for k in keys]))
            )
            self.execute(command)
        except Exception as e:
            Log.error("problem with record: {{record}}", record=record, cause=e)

    def insert_new(self, table_name, candidate_key, new_record):
        """
        ONLY INSERT IF THE candidate_key DOES NOT EXIST YET

        candidate_key IS LIST OF COLUMNS THAT CAN BE USED AS UID (USUALLY PRIMARY KEY)
        """
        candidate_key = listwrap(candidate_key)

        condition = sql_eq(**{k: new_record[k] for k in candidate_key})
        # INSERT ... SELECT FROM (the new row) LEFT JOIN (any existing match)
        # WHERE NO MATCH WAS FOUND
        command = (
            SQL_INSERT + quote_column(table_name) +
            sql_iso(sql_list(quote_column(k) for k in new_record.keys())) +
            SQL_SELECT + "a.*" + SQL_FROM +
            sql_iso(
                SQL_SELECT +
                sql_list([quote_value(v) + " " + quote_column(k) for k, v in new_record.items()]) +
                SQL_FROM + "DUAL"
            ) + " a" +
            SQL_LEFT_JOIN +
            sql_iso(
                SQL_SELECT + "'dummy' exist " +
                SQL_FROM + quote_column(table_name) +
                SQL_WHERE + condition +
                SQL_LIMIT + SQL_ONE
            ) + " b ON " + SQL_TRUE + SQL_WHERE + " exist " + SQL_IS_NULL
        )
        self.execute(command, {})

    def insert_newlist(self, table_name, candidate_key, new_records):
        """
        ONLY INSERT IF THE candidate_key DOES NOT EXIST YET (ONE COMMAND PER RECORD)
        """
        for r in new_records:
            self.insert_new(table_name, candidate_key, r)

    def insert_list(self, table_name, records):
        """
        INSERT ALL records WITH ONE COMMAND; COLUMNS ARE THE UNION OF ALL RECORD KEYS
        """
        if not records:
            return

        keys = set()
        for r in records:
            keys |= set(r.keys())
        keys = jx.sort(keys)

        try:
            command = (
                SQL_INSERT + quote_column(table_name) +
                sql_iso(sql_list([quote_column(k) for k in keys])) +
                SQL_VALUES +
                sql_list(
                    sql_iso(sql_list([quote_value(r[k]) for k in keys]))
                    for r in records
                )
            )
            self.execute(command)
        except Exception as e:
            Log.error("problem with record: {{record}}", record=records, cause=e)

    def update(self, table_name, where_slice, new_values):
        """
        where_slice - A Data WHICH WILL BE USED TO MATCH ALL IN table
                      eg {"id": 42}
        new_values  - A dict WITH COLUMN NAME, COLUMN VALUE PAIRS TO SET
        """
        new_values = quote_param(new_values)

        where_clause = sql_eq(**where_slice)

        command = (
            SQL_UPDATE + quote_column(table_name) + SQL_SET +
            sql_list([quote_column(k) + "=" + v for k, v in new_values.items()]) +
            SQL_WHERE + where_clause
        )
        self.execute(command, {})

    def sort2sqlorderby(self, sort):
        """
        CONVERT A jx sort PARAMETER INTO THE SQL ORDER BY COLUMN LIST
        """
        sort = jx.normalize_sort_parameters(sort)
        return sql_list([
            quote_column(s.field) + (SQL_DESC if s.sort == -1 else SQL_ASC)
            for s in sort
        ])
def get_ssl_pem_file(url):
    """
    DOWNLOAD THE FILE AT url INTO THE LOCAL .pem DIRECTORY, NAMED AFTER THE url HOST

    :param url: WHERE TO FETCH THE PEM FILE FROM
    :return: {"cafile": ABSOLUTE PATH OF THE LOCAL COPY}
    """
    pem_dir = File(".pem")
    local_copy = pem_dir / URL(url).host
    response = http.get(url)
    local_copy.write_bytes(response.content)
    return {"cafile": local_copy.abspath}