class FilamentManager(object):
    """Persistence layer for filament profiles, spools and tool selections.

    Supports two backends, SQLite and PostgreSQL, behind a single shared
    connection that is serialized with ``self.lock``.  On PostgreSQL,
    triggers emit ``pg_notify`` events consumed via ``self.notify``.
    """

    DIALECT_SQLITE = "sqlite"
    DIALECT_POSTGRESQL = "postgresql"

    def __init__(self, config):
        """Connect to the database described by *config*.

        :param config: dict with at least the keys ``uri``, ``name``,
            ``user`` and ``password`` (the latter three are only used for
            PostgreSQL).
        :raises ValueError: if the config is incomplete or the URI scheme
            is not a supported dialect.
        """
        if not set(("uri", "name", "user", "password")).issubset(config):
            raise ValueError("Incomplete config dictionary")

        # QUESTION thread local connection (pool) vs sharing a serialized
        # connection, pro/cons?
        # from sqlalchemy.orm import sessionmaker, scoped_session
        # Session = scoped_session(sessionmaker(bind=engine))
        # when using a connection pool how do we prevent notifying ourself
        # on database changes?
        self.lock = Lock()
        self.notify = None

        uri_parts = urisplit(config["uri"])
        if self.DIALECT_SQLITE == uri_parts.scheme:
            # check_same_thread=False: the single shared connection is
            # serialized by self.lock instead of sqlite's own thread check
            self.engine = create_engine(
                config["uri"],
                connect_args={"check_same_thread": False})
            self.conn = self.engine.connect()
            # foreign key enforcement is off by default in sqlite
            self.conn.execute(
                text("PRAGMA foreign_keys = ON").execution_options(
                    autocommit=True))
        elif self.DIALECT_POSTGRESQL == uri_parts.scheme:
            uri = uricompose(scheme=uri_parts.scheme,
                             host=uri_parts.host,
                             port=uri_parts.getport(default=5432),
                             path="/{}".format(config["name"]),
                             userinfo="{}:{}".format(config["user"],
                                                     config["password"]))
            self.engine = create_engine(uri)
            self.conn = self.engine.connect()
            self.notify = PGNotify(uri)
        else:
            raise ValueError("Engine '{engine}' not supported".format(
                engine=uri_parts.scheme))

    def close(self):
        """Close the shared database connection."""
        self.conn.close()

    def initialize(self):
        """Create tables and change-tracking triggers if they don't exist."""
        metadata = MetaData()

        self.profiles = Table(
            "profiles", metadata,
            Column("id", INTEGER, primary_key=True, autoincrement=True),
            Column("vendor", VARCHAR(255), nullable=False, server_default=""),
            Column("material", VARCHAR(255), nullable=False, server_default=""),
            Column("density", REAL, nullable=False, server_default="0"),
            Column("diameter", REAL, nullable=False, server_default="0"))

        self.spools = Table(
            "spools", metadata,
            Column("id", INTEGER, primary_key=True, autoincrement=True),
            Column("profile_id", INTEGER, nullable=False),
            Column("name", VARCHAR(255), nullable=False, server_default=""),
            Column("cost", REAL, nullable=False, server_default="0"),
            Column("weight", REAL, nullable=False, server_default="0"),
            Column("used", REAL, nullable=False, server_default="0"),
            Column("temp_offset", INTEGER, nullable=False, server_default="0"),
            ForeignKeyConstraint(["profile_id"], ["profiles.id"],
                                 ondelete="RESTRICT"))

        self.selections = Table(
            "selections", metadata,
            Column("tool", INTEGER),
            Column("client_id", VARCHAR(36)),
            Column("spool_id", INTEGER),
            # named pkey so the postgresql upsert can reference it
            PrimaryKeyConstraint("tool", "client_id", name="selections_pkey"),
            ForeignKeyConstraint(["spool_id"], ["spools.id"],
                                 ondelete="CASCADE"))

        self.versioning = Table(
            "versioning", metadata,
            Column("schema_id", INTEGER, primary_key=True, autoincrement=False))

        # one row per table, updated by triggers; used for cheap
        # lastmodified polling by clients
        self.modifications = Table(
            "modifications", metadata,
            Column("table_name", VARCHAR(255), nullable=False, primary_key=True),
            Column("action", VARCHAR(255), nullable=False),
            Column("changed_at", TIMESTAMP, nullable=False,
                   server_default=text("CURRENT_TIMESTAMP")))

        if self.DIALECT_POSTGRESQL == self.engine.dialect.name:
            # postgresql has no CREATE TRIGGER/FUNCTION IF NOT EXISTS, so
            # probe the system catalogs first (bound params, no string
            # interpolation into SQL)
            def should_create_function(name):
                row = self.conn.execute(
                    text("select proname from pg_proc where proname = :name"),
                    name=name).scalar()
                return not bool(row)

            def should_create_trigger(name):
                row = self.conn.execute(
                    text("select tgname from pg_trigger where tgname = :name"),
                    name=name).scalar()
                return not bool(row)

            trigger_function = DDL("""
                CREATE FUNCTION update_lastmodified()
                RETURNS TRIGGER AS $func$
                BEGIN
                    INSERT INTO modifications (table_name, action, changed_at)
                    VALUES(TG_TABLE_NAME, TG_OP, CURRENT_TIMESTAMP)
                    ON CONFLICT (table_name) DO UPDATE
                    SET action=TG_OP, changed_at=CURRENT_TIMESTAMP
                    WHERE modifications.table_name=TG_TABLE_NAME;
                    PERFORM pg_notify(TG_TABLE_NAME, TG_OP);
                    RETURN NULL;
                END;
                $func$ LANGUAGE plpgsql;
                """)

            if should_create_function("update_lastmodified"):
                event.listen(metadata, "after_create", trigger_function)

            for table in [self.profiles.name, self.spools.name]:
                for action in ["INSERT", "UPDATE", "DELETE"]:
                    name = "{table}_on_{action}".format(table=table,
                                                        action=action.lower())
                    trigger = DDL("""
                        CREATE TRIGGER {name} AFTER {action} on {table}
                        FOR EACH ROW EXECUTE PROCEDURE update_lastmodified()
                        """.format(name=name, table=table, action=action))
                    if should_create_trigger(name):
                        event.listen(metadata, "after_create", trigger)
        elif self.DIALECT_SQLITE == self.engine.dialect.name:
            for table in [self.profiles.name, self.spools.name]:
                for action in ["INSERT", "UPDATE", "DELETE"]:
                    name = "{table}_on_{action}".format(table=table,
                                                        action=action.lower())
                    trigger = DDL("""
                        CREATE TRIGGER IF NOT EXISTS {name} AFTER {action} on {table}
                        FOR EACH ROW BEGIN
                            REPLACE INTO modifications (table_name, action) VALUES ('{table}','{action}');
                        END
                        """.format(name=name, table=table, action=action))
                    event.listen(metadata, "after_create", trigger)

        metadata.create_all(self.conn, checkfirst=True)

    def execute_script(self, script):
        """Run a multi-statement SQL script inside one transaction.

        NOTE: statements are separated by a naive split on ';', so the
        script must not contain semicolons inside string literals.
        """
        with self.lock, self.conn.begin():
            for stmt in script.split(";"):
                self.conn.execute(text(stmt))

    # versioning

    def get_schema_version(self):
        """Return the highest recorded schema version, or None if unset."""
        with self.lock, self.conn.begin():
            return self.conn.execute(
                select([func.max(self.versioning.c.schema_id)])).scalar()

    def set_schema_version(self, version):
        """Record *version* and drop all older version rows."""
        with self.lock, self.conn.begin():
            self.conn.execute(insert(self.versioning).values((version, )))
            self.conn.execute(
                delete(self.versioning).where(
                    self.versioning.c.schema_id < version))

    # profiles

    def get_all_profiles(self):
        """Return all profiles as a list of dicts, sorted by material/vendor."""
        with self.lock, self.conn.begin():
            stmt = select([self.profiles]).order_by(self.profiles.c.material,
                                                    self.profiles.c.vendor)
            result = self.conn.execute(stmt)
            return self._result_to_dict(result)

    def get_profiles_lastmodified(self):
        """Return the timestamp of the last change to the profiles table."""
        with self.lock, self.conn.begin():
            stmt = select([
                self.modifications.c.changed_at
            ]).where(self.modifications.c.table_name == "profiles")
            return self.conn.execute(stmt).scalar()

    def get_profile(self, identifier):
        """Return the profile with the given id as a dict, or None."""
        with self.lock, self.conn.begin():
            stmt = select([self.profiles]).where(self.profiles.c.id == identifier)\
                .order_by(self.profiles.c.material, self.profiles.c.vendor)
            result = self.conn.execute(stmt)
            return self._result_to_dict(result, one=True)

    def create_profile(self, data):
        """Insert a new profile; returns *data* with its new ``id`` set."""
        with self.lock, self.conn.begin():
            stmt = insert(self.profiles)\
                .values(vendor=data["vendor"], material=data["material"],
                        density=data["density"], diameter=data["diameter"])
            result = self.conn.execute(stmt)
            # NOTE(review): lastrowid is well-defined for sqlite; on
            # postgresql DBAPI lastrowid is not the generated id — TODO
            # confirm whether a RETURNING clause is needed there
            data["id"] = result.lastrowid
            return data

    def update_profile(self, identifier, data):
        """Update the profile with the given id from *data*; returns *data*."""
        with self.lock, self.conn.begin():
            stmt = update(self.profiles).where(self.profiles.c.id == identifier)\
                .values(vendor=data["vendor"], material=data["material"],
                        density=data["density"], diameter=data["diameter"])
            self.conn.execute(stmt)
            return data

    def delete_profile(self, identifier):
        """Delete the profile with the given id (RESTRICTed by spools FK)."""
        with self.lock, self.conn.begin():
            stmt = delete(
                self.profiles).where(self.profiles.c.id == identifier)
            self.conn.execute(stmt)

    # spools

    def _build_spool_dict(self, row, column_names):
        """Fold a joined spools+profiles row into a nested spool dict.

        The first len(self.spools.columns) values belong to the spool, the
        remainder to its profile (nested under ``"profile"``).
        """
        spool = dict(profile=dict())
        for i, value in enumerate(row):
            if i < len(self.spools.columns):
                spool[column_names[i]] = value
            else:
                spool["profile"][column_names[i]] = value
        # the FK column is redundant once the profile is nested
        del spool["profile_id"]
        return spool

    def get_all_spools(self):
        """Return all spools (with nested profile dicts), sorted by name."""
        with self.lock, self.conn.begin():
            j = self.spools.join(
                self.profiles, self.spools.c.profile_id == self.profiles.c.id)
            stmt = select([self.spools, self.profiles
                           ]).select_from(j).order_by(self.spools.c.name)
            result = self.conn.execute(stmt)
            return [
                self._build_spool_dict(row, row.keys())
                for row in result.fetchall()
            ]

    def get_spools_lastmodified(self):
        """Return the latest change timestamp across spools and profiles."""
        with self.lock, self.conn.begin():
            stmt = select([func.max(self.modifications.c.changed_at)])\
                .where(self.modifications.c.table_name.in_(["spools", "profiles"]))
            return self.conn.execute(stmt).scalar()

    def get_spool(self, identifier):
        """Return the spool with the given id (nested profile), or None."""
        with self.lock, self.conn.begin():
            j = self.spools.join(
                self.profiles, self.spools.c.profile_id == self.profiles.c.id)
            stmt = select([self.spools, self.profiles]).select_from(j)\
                .where(self.spools.c.id == identifier).order_by(self.spools.c.name)
            result = self.conn.execute(stmt)
            row = result.fetchone()
            return self._build_spool_dict(row, row.keys()) if row is not None else None

    def create_spool(self, data):
        """Insert a new spool; returns *data* with its new ``id`` set."""
        with self.lock, self.conn.begin():
            stmt = insert(self.spools)\
                .values(name=data["name"], cost=data["cost"],
                        weight=data["weight"], used=data["used"],
                        temp_offset=data["temp_offset"],
                        profile_id=data["profile"]["id"])
            result = self.conn.execute(stmt)
            # NOTE(review): see create_profile about lastrowid on postgresql
            data["id"] = result.lastrowid
            return data

    def update_spool(self, identifier, data):
        """Update the spool with the given id from *data*; returns *data*."""
        with self.lock, self.conn.begin():
            stmt = update(self.spools).where(self.spools.c.id == identifier)\
                .values(name=data["name"], cost=data["cost"],
                        weight=data["weight"], used=data["used"],
                        temp_offset=data["temp_offset"],
                        profile_id=data["profile"]["id"])
            self.conn.execute(stmt)
            return data

    def delete_spool(self, identifier):
        """Delete the spool with the given id (CASCADEs to selections)."""
        with self.lock, self.conn.begin():
            stmt = delete(self.spools).where(self.spools.c.id == identifier)
            self.conn.execute(stmt)

    # selections

    def _build_selection_dict(self, row, column_names):
        """Fold a joined selections+spools+profiles row into a nested dict."""
        sel = dict(spool=dict(profile=dict()))
        for i, value in enumerate(row):
            if i < len(self.selections.columns):
                sel[column_names[i]] = value
            elif i < len(self.selections.columns) + len(self.spools.columns):
                sel["spool"][column_names[i]] = value
            else:
                sel["spool"]["profile"][column_names[i]] = value
        # FK columns are redundant once the referenced rows are nested
        del sel["spool_id"]
        del sel["spool"]["profile_id"]
        return sel

    def get_all_selections(self, client_id):
        """Return all selections for *client_id*, ordered by tool."""
        with self.lock, self.conn.begin():
            j1 = self.selections.join(
                self.spools, self.selections.c.spool_id == self.spools.c.id)
            j2 = j1.join(self.profiles,
                         self.spools.c.profile_id == self.profiles.c.id)
            stmt = select([self.selections, self.spools, self.profiles]).select_from(j2)\
                .where(self.selections.c.client_id == client_id).order_by(self.selections.c.tool)
            result = self.conn.execute(stmt)
            return [
                self._build_selection_dict(row, row.keys())
                for row in result.fetchall()
            ]

    def get_selection(self, identifier, client_id):
        """Return the selection for (tool, client); empty selection if none."""
        with self.lock, self.conn.begin():
            j1 = self.selections.join(
                self.spools, self.selections.c.spool_id == self.spools.c.id)
            j2 = j1.join(self.profiles,
                         self.spools.c.profile_id == self.profiles.c.id)
            stmt = select([self.selections, self.spools, self.profiles]).select_from(j2)\
                .where((self.selections.c.tool == identifier)
                       & (self.selections.c.client_id == client_id))
            result = self.conn.execute(stmt)
            row = result.fetchone()
            return self._build_selection_dict(
                row, row.keys()) if row is not None else dict(tool=identifier,
                                                              spool=None)

    def update_selection(self, identifier, client_id, data):
        """Upsert the spool selection for (tool, client); returns it."""
        with self.lock, self.conn.begin():
            if self.engine.dialect.name == self.DIALECT_SQLITE:
                stmt = insert(self.selections).prefix_with("OR REPLACE")\
                    .values(tool=identifier, client_id=client_id,
                            spool_id=data["spool"]["id"])
            elif self.engine.dialect.name == self.DIALECT_POSTGRESQL:
                stmt = pg_insert(self.selections)\
                    .values(tool=identifier, client_id=client_id,
                            spool_id=data["spool"]["id"])\
                    .on_conflict_do_update(constraint="selections_pkey",
                                           set_=dict(spool_id=data["spool"]["id"]))
            self.conn.execute(stmt)
        # outside the with-block: get_selection re-acquires the
        # non-reentrant lock
        return self.get_selection(identifier, client_id)

    def export_data(self, dirpath):
        """Export the profiles and spools tables as CSV files into *dirpath*."""
        def to_csv(table):
            with self.lock, self.conn.begin():
                result = self.conn.execute(select([table]))
                filepath = os.path.join(dirpath, table.name + ".csv")
                with io.open(filepath, mode="w", encoding="utf-8") as csv_file:
                    csv_writer = csv.writer(csv_file)
                    csv_writer.writerow(table.columns.keys())
                    csv_writer.writerows(result)

        tables = [self.profiles, self.spools]
        for t in tables:
            to_csv(t)

    def import_data(self, dirpath):
        """Import profiles and spools from CSV files in *dirpath* (upsert)."""
        def from_csv(table):
            filepath = os.path.join(dirpath, table.name + ".csv")
            with io.open(filepath, mode="r", encoding="utf-8") as csv_file:
                csv_reader = csv.reader(csv_file)
                header = next(csv_reader)
                with self.lock, self.conn.begin():
                    for row in csv_reader:
                        values = dict(zip(header, row))
                        if self.engine.dialect.name == self.DIALECT_SQLITE:
                            # BUGFIX: values is keyed by the CSV header
                            # strings, not by Column objects
                            identifier = values["id"]
                            # try to update entry
                            stmt = update(table).values(values).where(
                                table.c.id == identifier)
                            if self.conn.execute(stmt).rowcount == 0:
                                # identifier doesn't match any => insert new entry
                                stmt = insert(table).values(values)
                                self.conn.execute(stmt)
                        elif self.engine.dialect.name == self.DIALECT_POSTGRESQL:
                            stmt = pg_insert(table).values(values)\
                                .on_conflict_do_update(index_elements=[table.c.id],
                                                       set_=values)
                            self.conn.execute(stmt)
                    if self.DIALECT_POSTGRESQL == self.engine.dialect.name:
                        # update sequences so future inserts don't collide
                        # with imported ids
                        self.conn.execute(
                            text(
                                "SELECT setval('profiles_id_seq', max(id)) FROM profiles"
                            ))
                        self.conn.execute(
                            text(
                                "SELECT setval('spools_id_seq', max(id)) FROM spools"
                            ))

        tables = [self.profiles, self.spools]
        for t in tables:
            from_csv(t)

    # helper

    def _result_to_dict(self, result, one=False):
        """Convert a result proxy into a dict (one=True) or list of dicts."""
        if one:
            row = result.fetchone()
            return dict(row) if row is not None else None
        else:
            return [dict(row) for row in result.fetchall()]
# ### Join
# join(right, onclause=None, isouter=False, full=False)
# Example

# In[83]:
from sqlalchemy import join

# join() automatically inspects the foreign keys to connect the two tables

# In[84]:
print(users.join(addresses))

# In[85]:
print(users.join(addresses, users.c.id == addresses.c.user_id))

# In[87]:
query = select([users.c.id, users.c.fullname, addresses.c.email_address]).select_from(users.join(addresses))

# In[89]:
result = conn.execute(query).fetchall()

# In[90]: