class TrialModelTestCase(unittest.TestCase):
    """Tests for the ``Trial`` model.

    The whole class runs inside one outer transaction on a shared connection
    (rolled back in ``tearDownClass``); every test additionally runs inside a
    nested transaction (``begin_nested``) that is rolled back in ``tearDown``,
    so tests do not see each other's writes.
    """

    @classmethod
    def setUpClass(cls):
        # Stored as module-level globals so setUp/tearDown/tearDownClass can
        # reach the same engine, connection and outer transaction.
        global transaction, connection, engine
        # Connect to the database and create the schema within a transaction
        engine = create_engine(TEST_DATABASE_URI)
        connection = engine.connect()
        transaction = connection.begin()
        Trial.metadata.create_all(connection)

        # Load test trials fixtures from xml files
        nct_ids = ['NCT02034110', 'NCT00001160', 'NCT00001163']
        cls.trials = load_sample_trials(nct_ids)

    @classmethod
    def tearDownClass(cls):
        # Roll back the top level transaction and disconnect from the database
        transaction.rollback()
        connection.close()
        engine.dispose()

    def setUp(self):
        # Per-test nested transaction on the shared connection; the session is
        # bound to that same connection so its work happens inside it.
        self.__transaction = connection.begin_nested()
        self.session = Session(connection)

    def tearDown(self):
        # Close the session first, then discard everything the test wrote.
        self.session.close()
        self.__transaction.rollback()

    def test_add(self):
        # Builds a Trial from the first loaded fixture and stages it.
        trial = Trial(ct_dict=self.trials[0])
        self.session.add(trial)
class TestSessions(TestCase):
    """Versioning behaviour across sessions and manually created transactions."""

    plugins = []

    def test_multiple_connections(self):
        """Two sessions on separate connections get transaction ids 1 and 2."""
        # Keep a handle on the extra connection so it can be closed explicitly;
        # closing the session alone does not close a connection it was bound to.
        connection2 = self.engine.connect()
        self.session2 = Session(bind=connection2)
        try:
            article = self.Article(name=u'Session1 article')
            article2 = self.Article(name=u'Session2 article')
            self.session.add(article)
            self.session2.add(article2)
            self.session.flush()
            self.session2.flush()

            self.session.commit()
            self.session2.commit()
            # Accessing ``versions`` may lazy-load through session2, so the
            # cleanup below must run only after these assertions.
            assert article.versions[-1].transaction_id == 1
            assert article2.versions[-1].transaction_id == 2
        finally:
            self.session2.close()
            connection2.close()

    def test_manual_transaction_creation(self):
        """A transaction created through the unit of work is reused on flush."""
        uow = versioning_manager.unit_of_work(self.session)
        transaction = uow.create_transaction(self.session)
        self.session.flush()
        assert transaction.id == 1
        article = self.Article(name=u'Session1 article')
        self.session.add(article)
        self.session.flush()
        assert uow.current_transaction.id == 1

        self.session.commit()
        assert article.versions[-1].transaction_id == 1

    def test_commit_without_objects(self):
        """Committing an empty session must not raise."""
        self.session.commit()
def test_session_with_external_transaction(self):
    """A session bound to a connection with an externally begun transaction
    can flush an object; the outer transaction is then rolled back.

    Cleanup (close session, roll back, close connection) runs in ``finally``
    so a failing flush does not leave the connection and transaction open.
    """
    conn = self.engine.connect()
    t = conn.begin()
    session = Session(bind=conn)
    try:
        article = self.Article(name=u'My Session Article')
        session.add(article)
        session.flush()
    finally:
        session.close()
        t.rollback()
        conn.close()
def set_new_setting(session: Session, name: str, value):
    """Create or update the ``Setting`` row named *name* with *value*.

    ``str``, ``int``, ``float`` and ``bool`` values are stored as
    ``str(value)`` with ``value_class`` set to the value's class name
    (e.g. ``"int"``, ``"bool"``). Any other value is serialised with
    ``pickle.dumps`` and ``value_class`` is set to ``"pickle"``.

    The row is added to, or modified inside, *session*; nothing is flushed or
    committed here.

    :param session: session used to look up and stage the ``Setting`` row.
    :param name: setting name, matched with ``filter_by(name=name)``.
    :param value: the value to store.
    :return: always ``True``.
    """
    if isinstance(value, (str, int, float, bool)):
        value_class = str(value.__class__.__name__)
        value = str(value)
    else:
        # NOTE(review): pickled bytes must only be unpickled from a trusted
        # store -- pickle.loads on tampered data can execute arbitrary code.
        value_class = "pickle"
        value = pickle.dumps(value)

    sett = session.query(Setting).filter_by(name=name).first()
    if sett is None:
        sett = Setting(name=name, value=value, value_class=value_class)
        session.add(sett)
    else:
        sett.value = value
        sett.value_class = value_class
    return True
def init_empty_site(self, dbsession: Session, user: UserMixin):
    """Bootstrap the admin group when the first user signs up.

    If the database holds no groups yet, create one named
    ``DEFAULT_ADMIN_GROUP_NAME`` and add *user* to it. When any group already
    exists the site is considered initialised and nothing happens.
    """
    # The group model is discovered through the user model's ``groups``
    # relationship rather than imported directly.
    user_mapper = inspection.inspect(user.__class__)
    group_model = user_mapper.relationships["groups"].mapper.entity

    # Groups already present: initialisation has happened before, skip it.
    existing_groups = dbsession.query(group_model).count()
    if existing_groups > 0:
        return

    admin_group = group_model(name=group_model.DEFAULT_ADMIN_GROUP_NAME)
    dbsession.add(admin_group)
    admin_group.users.append(user)
def upgrade(migrate_engine):
    """
    Upgrade operations go here.
    Don't create your own engine; bind migrate_engine to your metadata

    Adds a ``username`` column to the user log table, backfills it from each
    entry's related user, then makes the column NOT NULL.
    """
    #==========================================================================
    # USER LOGS
    #==========================================================================
    from rhodecode.lib.dbmigrate.schema.db_1_5_0 import UserLog
    tbl = UserLog.__table__
    username = Column("username", String(255, convert_unicode=False,
                                         assert_unicode=None), nullable=True,
                      unique=None, default=None)
    # create username column (nullable for now, so existing rows stay valid)
    username.create(table=tbl)

    _Session = Session()
    ## after adding that column fix all usernames
    users_log = _Session.query(UserLog)\
            .options(joinedload(UserLog.user))\
            .options(joinedload(UserLog.repository)).all()

    for entry in users_log:
        entry.username = entry.user.username
        _Session.add(entry)
    # Persist the whole backfill before tightening the constraint below.
    _Session.commit()

    # alter username to not null (UserLog is already imported above)
    tbl_name = UserLog.__tablename__
    tbl = Table(tbl_name, MetaData(bind=migrate_engine), autoload=True,
                autoload_with=migrate_engine)
    col = tbl.columns.username

    # remove nullability from the username field
    col.alter(nullable=False)
def test_create_object(self):
    """
    Test the method ``_create_object`` (called here as
    ``_create_object(obj_type, sync_key, update_data=None, local_id=...,
    pvc_id=...)``).

    Records the expected calls -- ``session.begin(subtransactions=True)``,
    ``model.PowerVCMapping(obj_type, sync_key)`` and
    ``session.add(<that mapping>)`` -- then checks ``local_id``, ``pvc_id``
    and ``status`` against ``self.powerVCMapping``.
    """
    obj_type = "Network"
    sync_key = utils.gen_network_sync_key(
        self.fakePowerVCNetwork.powerNetInstance)
    local_id = self.fakeOSNetwork.fakeOSNetworkInstance['id']
    pvc_id = self.fakePowerVCNetwork.powerNetInstance['id']

    # Built before PowerVCMapping is stubbed out, so this is a real instance.
    inputPowerVCMObj = model.PowerVCMapping(obj_type, sync_key)

    try:
        self.aMox.StubOutWithMock(session, 'begin')
        session.begin(subtransactions=True).AndReturn(transaction(None, None))

        self.aMox.StubOutWithMock(model, 'PowerVCMapping')
        model.PowerVCMapping(obj_type, sync_key).AndReturn(inputPowerVCMObj)

        self.aMox.StubOutWithMock(session, 'add')
        session.add(inputPowerVCMObj).AndReturn("")

        self.aMox.ReplayAll()

        self.powervcagentdb._create_object(
            obj_type, sync_key, update_data=None,
            local_id=local_id, pvc_id=pvc_id)

        self.aMox.VerifyAll()

        self.assertEqual(
            self.powerVCMapping.local_id, inputPowerVCMObj.local_id)
        self.assertEqual(self.powerVCMapping.pvc_id, inputPowerVCMObj.pvc_id)
        self.assertEqual(self.powerVCMapping.status, inputPowerVCMObj.status)
    finally:
        # Always restore the real session/model attributes, even when
        # verification or an assertion fails, so later tests are unaffected.
        self.aMox.UnsetStubs()
class CrawlProcessor(object):
    """Turn crawled ``(crawl_id, record)`` items into database rows.

    A record is ``(headers, content, url, date_crawled, content_type)``.
    For each new HTML record this stores an ``Article``, its ``Document``
    (sentiment classification), ``Sentence``/``Phrase`` rows, keywords and
    keyword adjacencies, extracted dates, links and software-involvement
    records, all through a SQLAlchemy session bound to ``self._engine``.
    """

    __VERSION__ = "CrawlProcessor-0.2.1"

    def __init__(self, engine, redis_server, stop_list="keyword_filter.txt"):
        """Set up the database session, controllers, workers and stop list.

        :param engine: a SQLAlchemy engine, or a connection string from which
            one is created with READ COMMITTED isolation.
        :param redis_server: Redis host; db 1 is used for keyword resolution
            and db 2 for domain resolution.
        :param stop_list: path to a stop-word file, or any iterable of lines;
            every line is stripped and added to ``self.stop_list``.
        """
        if isinstance(engine, types.StringType):
            logging.info("Using connection string '%s'" % (engine,))
            new_engine = create_engine(engine, encoding='utf-8', isolation_level="READ COMMITTED")
            if "sqlite:" in engine:
                logging.debug("Setting text factory for unicode compat.")
                # NOTE(review): this only affects the one pooled connection
                # returned by raw_connection(); confirm that is sufficient.
                new_engine.raw_connection().connection.text_factory = str
            self._engine = new_engine
        else:
            logging.info("Using existing engine...")
            self._engine = engine

        logging.info("Binding session...")
        self._session = Session(bind=self._engine, autocommit = False)

        if isinstance(stop_list, types.StringType):
            # We opened the file ourselves, so make sure it gets closed.
            with open(stop_list) as stop_list_fp:
                self.stop_list = self._read_stop_list(stop_list_fp)
        else:
            self.stop_list = self._read_stop_list(stop_list)

        self.cls = DocumentClassifier()
        self.dc = DomainController(self._engine, self._session)
        self.ac = ArticleController(self._engine, self._session)
        self.ex = extract.TermExtractor()
        self.kwc = KeywordController(self._engine, self._session)
        self.swc = SoftwareVersionsController(self._engine, self._session)
        self.redis_kw = redis.Redis(host=redis_server, port=6379, db=1)
        self.redis_dm = redis.Redis(host=redis_server, port=6379, db=2)
        # The domain resolution worker gets its own session.
        dm_session = Session(bind=self._engine, autocommit = False)
        self.drw = DomainResolutionWorker(dm_session, self.redis_dm)

    @staticmethod
    def _read_stop_list(lines):
        """Return the set of stripped lines from an iterable of lines."""
        return set(line.strip() for line in lines)

    def _check_processed(self, item):
        """Return True if *item* has not been stored as an ``Article`` yet.

        An article matches when crawl id, domain id and path are all equal;
        if one or more matches exist this logs and returns False.
        """
        crawl_id, record = item
        headers, content, url, date_crawled, content_type = record
        path = self.ac.get_path_fromurl(url)
        domain_identifier = None
        logging.info("_check_processed: retrieving domain...")
        domain_key = self.dc.get_Domain_key(url)
        # Poll the resolution worker until it yields an id (no back-off).
        while domain_identifier is None:
            domain_identifier = self.drw.get_domain(domain_key)
        it = (self._session.query(Article)
              .filter_by(crawl_id = crawl_id)
              .filter_by(domain_id = domain_identifier)
              .filter_by(path = path))
        try:
            it = it.one()
            logging.error("%s: already processed", url)
            return False
        except sqlalchemy.orm.exc.MultipleResultsFound:
            logging.error("%s: appears to have been already processed multiple times", url)
            return False
        except sqlalchemy.orm.exc.NoResultFound:
            logging.info("%s: hasn't been processed yet", url)
            return True

    def process_record(self, item):
        """Process one ``(crawl_id, record)`` item.

        :return: the new article id, or None when the item was already
            processed, was skipped, or its commit failed twice.
        :raises ValueError: if *item* is not a 2-tuple.
        """
        if len(item) != 2:
            raise ValueError(item)
        if not self._check_processed(item):
            return None
        ret, retries = None, 2
        # _process_record returns None only when the final commit failed with
        # OperationalError; that case is retried once. Exceptions propagate.
        while ret is None and retries > 0:
            retries -= 1
            try:
                ret = self._process_record(item)
            except Exception as ex:
                import traceback
                sys.stderr.write("%s\n" % (ex,))
                traceback.print_exc()
                # Bare raise keeps the original traceback.
                raise
        if ret is False:
            return None
        return ret

    def _process_record(self, item_arg):
        """Extract everything from one record and commit it.

        :return: the article id on success; False when the record is skipped
            (any status set on the article is left uncommitted in the session
            until a later commit or ``finalize``); None when the final commit
            raises ``OperationalError`` (the session is rolled back).
        """
        crawl_id, record = item_arg
        headers, content, url, date_crawled, content_type = record
        assert headers is not None
        assert content is not None
        assert url is not None
        assert date_crawled is not None
        assert content_type is not None
        status = "Processed"

        # Fix for a seg-fault
        if "nasa.gov" in url:
            return False

        # Sort out the domain
        domain_identifier = None
        logging.info("Retrieving domain...")
        domain_key = self.dc.get_Domain_key(url)
        while domain_identifier is None:
            domain_identifier = self.drw.get_domain(domain_key)
        domain = self._session.query(Domain).get(domain_identifier)
        assert domain is not None

        # Build database objects
        path = self.ac.get_path_fromurl(url)
        article = Article(path, date_crawled, crawl_id, domain, status)
        self._session.add(article)
        classified_by = self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__)
        assert classified_by is not None

        if content_type != 'text/html':
            logging.error("Unsupported content type: %s", str(content_type))
            article.status = "UnsupportedType"
            return False

        # Start the async transaction to get the plain text
        worker_req_thread = BoilerPipeWorker(content)
        worker_req_thread.start()

        # Whilst that's executing, parse the document
        logging.info("Parsing HTML...")
        html = BeautifulSoup(content)
        if html is None or html.body is None:
            article.status = "NoContent"
            return False

        # Extract the dates
        date_dict = pydate.get_dates(html)
        if len(date_dict) == 0:
            status = "NoDates"

        # Detect the language
        lang, lang_certainty = langid.classify(content)

        # Wait for the BoilerPipe thread to complete
        worker_req_thread.join()
        logging.debug(worker_req_thread.result)
        logging.debug(worker_req_thread.version)

        if worker_req_thread.result is None:
            article.status = "NoContent"
            return False

        # If the language isn't English, skip it
        if lang != "en":
            logging.info("language: %s with certainty %.2f - skipping...", lang, lang_certainty)
            article.status = "LanguageError"  # Replace with something appropriate
            return False
        content = worker_req_thread.result.encode('ascii', 'ignore')

        # Headline extraction: the first h6..h1 (checked in that order) whose
        # text appears in the extracted body text.
        h_counter = 6
        headline = None
        while h_counter > 0:
            tag = "h%d" % (h_counter,)
            found = False
            for node in html.findAll(tag):
                if node.text in content:
                    headline = node.text
                    found = True
                    break
            if found:
                break
            h_counter -= 1

        # Run keyword extraction
        keywords = self.ex(content)
        kset = KeywordSet(self.stop_list)
        nnp_sets_scored = set([])
        for word, freq, amnt in sorted(keywords):
            try:
                nnp_sets_scored.add((word, freq))
            except ValueError:
                break

        # Collect proper nouns (NNP runs) and adjacent NNP pairs per sentence.
        nnp_adj = set([])
        nnp_set = set([])
        nnp_vector = []
        for sentence in sent_tokenize(content):
            text = nltk.word_tokenize(sentence)
            pos = nltk.pos_tag(text)
            pos_groups = itertools.groupby(pos, lambda x: x[1])
            for k, g in pos_groups:
                if k != 'NNP':
                    continue
                nnp_list = [word for word, speech in g]
                nnp_buf = []
                for item in nnp_list:
                    nnp_set.add(item)
                    nnp_buf.append(item)
                    nnp_vector.append(item)
                for i, j in zip(nnp_buf[0:-1], nnp_buf[1:]):
                    nnp_adj.add((i, j))

        nnp_vector = filter(lambda x: x.lower() not in self.stop_list, nnp_vector)
        nnp_counter = Counter(nnp_vector)
        for word in nnp_set:
            score = nnp_counter[word]
            # Score each proper noun itself (previously the stale loop
            # variable ``item`` was recorded for every word).
            nnp_sets_scored.add((word, score))

        for item, score in sorted(nnp_sets_scored, key=lambda x: x[1], reverse=True):
            try:
                if isinstance(item, (types.ListType, types.TupleType)):
                    kset.add(' '.join(item))
                else:
                    kset.add(item)
            except ValueError:
                break

        scored_nnp_adj = []
        for item1, item2 in nnp_adj:
            score = nnp_counter[item1] + nnp_counter[item2]
            scored_nnp_adj.append((item1, item2, score))

        # Keep the KEYWORD_LIMIT highest-scoring pairs (index 2 is the score).
        nnp_adj = []
        for item1, item2, score in sorted(scored_nnp_adj, key=lambda x: x[2], reverse=True):
            if len(nnp_adj) < KEYWORD_LIMIT:
                nnp_adj.append((item1, item2))
            else:
                break

        # Generate list of all keywords
        keywords = set([])
        for keyword in kset:
            try:
                k = Keyword(keyword)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)
                continue
        for item1, item2 in nnp_adj:
            try:
                k = Keyword(item1)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)
            try:
                k = Keyword(item2)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)

        # Resolve keyword identifiers (in a background worker)
        keyword_resolution_worker = KeywordResolutionWorker(set([k.word for k in keywords]), self.redis_kw)
        keyword_resolution_worker.start()

        # Run sentiment analysis
        trace = []
        features = self.cls.classify(worker_req_thread.result, trace)
        (label, length, classified, pos_sentences, neg_sentences,
         pos_phrases, neg_phrases) = features[0:7]

        # Convert Pysen's model into database models
        try:
            doc = Document(article.id, label, length, pos_sentences, neg_sentences, pos_phrases, neg_phrases, headline)
        except ValueError as ex:
            logging.error(ex)
            logging.error("Skipping this document...")
            article.status = "ClassificationError"
            return False
        self._session.add(doc)

        extracted_phrases = set([])
        for sentence, score, phrase_trace in trace:
            # Tag the sentence with the name of the HTML element containing it.
            sentence_type = "Unknown"
            for node in html.findAll(text=True):
                if sentence.text in node.strip():
                    sentence_type = node.parent.name.upper()
                    break
            if sentence_type not in ["H1", "H2", "H3", "H4", "H5", "H6", "P", "Unknown"]:
                sentence_type = "Other"
            label, average, prob, pos, neg, probs, _scores = score
            s = Sentence(doc, label, average, prob, sentence_type)
            self._session.add(s)
            for phrase, prob, score, label in phrase_trace:
                p = Phrase(s, score, prob, label)
                self._session.add(p)
                extracted_phrases.add((phrase, p))

        # Wait for keyword resolution to finish
        keyword_resolution_worker.join()
        keyword_mapping = keyword_resolution_worker.out_keywords

        # Associate extracted keywords with phrases
        keyword_objects, short_keywords = kset.convert(keyword_mapping, self.kwc)
        for k in keyword_objects:
            self._session.merge(k)
        for p, p_obj in extracted_phrases:
            for k in keyword_objects:
                if k.word in p.get_text():
                    # NOTE(review): never added to the session explicitly;
                    # presumably persisted through a relationship cascade.
                    nk = KeywordIncidence(k, p_obj)

        # Save the keyword adjacency list
        for i, j in kset.convert_adj_tuples(nnp_adj, keyword_mapping, self.kwc):
            self._session.merge(i)
            self._session.merge(j)
            kwa = KeywordAdjacency(i, j, doc)
            self._session.add(kwa)

        # Build date objects
        for key in date_dict:
            rec = date_dict[key]
            if "dates" not in rec:
                logging.error("OK: 'dates' is not in a pydate result record.")
                continue
            dlen = len(rec["dates"])
            if rec["text"] not in content:
                logging.debug("'%s' is not in %s", rec["text"], content)
                continue
            if dlen > 1:
                for date, day_first, year_first in rec["dates"]:
                    try:
                        dobj = AmbiguousDate(date, doc, day_first, year_first, rec["prep"], key)
                    except ValueError as ex:
                        logging.error(ex)
                        continue
                    self._session.add(dobj)
            elif dlen == 1:
                for date, day_first, year_first in rec["dates"]:
                    dobj = CertainDate(date, doc, key)
                    self._session.add(dobj)
            else:
                logging.error("'dates' in a pydate result set contains no records.")

        # Process links
        for link in html.findAll('a'):
            if not link.has_attr("href"):
                logging.debug("skipping %s: no href", link)
                continue
            process = True
            for node in link.findAll(text=True):
                if node not in worker_req_thread.result:
                    process = False
                    break
            if not process:
                logging.debug("skipping %s because it's not in the body text", link)
                # Skip only this link (``break`` dropped every later link).
                continue
            href, junk, junk = link["href"].partition("#")
            if "http://" in href:
                try:
                    domain_id = None
                    domain_key = self.dc.get_Domain_key(href)
                    while domain_id is None:
                        domain_id = self.drw.get_domain(domain_key)
                    assert domain_id is not None
                    href_domain = self._session.query(Domain).get(domain_id)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping this link")
                    continue
                href_path = self.ac.get_path_fromurl(href)
                lnk = AbsoluteLink(doc, href_domain, href_path)
                self._session.add(lnk)
                logging.debug("Adding: %s", lnk)
            else:
                href_path = href
                try:
                    lnk = RelativeLink(doc, href_path)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping link")
                    continue
                self._session.add(lnk)
                logging.debug("Adding: %s", lnk)

        # Construct software involvement records
        self_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(self.__VERSION__), "Processed", doc)
        date_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pydate.__VERSION__), "Dated", doc)
        clas_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__), "Classified", doc)
        extr_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(worker_req_thread.version), "Extracted", doc)
        for sw in [self_sir, date_sir, clas_sir, extr_sir]:
            self._session.merge(sw, load=True)

        logging.debug("Domain: %s", domain)
        logging.debug("Path: %s", path)
        article.status = status

        # Commit to database; return the article id on success
        try:
            self._session.commit()
        except OperationalError as ex:
            logging.error(ex)
            self._session.rollback()
            return None

        return article.id

    def finalize(self):
        """Commit anything still pending in the session (e.g. skip statuses)."""
        self._session.commit()
def test_export(session: Session) -> None:
    """
    Test exporting a dataset.

    Builds a ``SqlaTable`` (five columns, one metric) on a flushed
    ``Database`` and checks that ``ExportDatasetsCommand._export`` yields
    exactly two ``(path, yaml_text)`` pairs -- the dataset file and the
    database file -- with the database UUID interpolated into both.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.export import ExportDatasetsCommand
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    # Flush so the database row exists before the table references it.
    session.flush()

    columns = [
        TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
        TableColumn(column_name="user_id", type="INTEGER"),
        TableColumn(column_name="revenue", type="INTEGER"),
        TableColumn(column_name="expenses", type="INTEGER"),
        TableColumn(
            column_name="profit",
            type="INTEGER",
            expression="revenue-expenses",
            extra=json.dumps({"certified_by": "User"}),
        ),
    ]
    metrics = [
        SqlMetric(
            metric_name="cnt",
            expression="COUNT(*)",
            extra=json.dumps({"warning_markdown": None}),
        ),
    ]

    sqla_table = SqlaTable(
        table_name="my_table",
        columns=columns,
        metrics=metrics,
        main_dttm_col="ds",
        database=database,
        offset=-8,
        description="This is the description",
        is_featured=1,
        cache_timeout=3600,
        schema="my_schema",
        sql=None,
        params=json.dumps(
            {
                "remote_id": 64,
                "database_name": "examples",
                "import_time": 1606677834,
            }
        ),
        perm=None,
        filter_select_enabled=1,
        fetch_values_predicate="foo IN (1, 2)",
        is_sqllab_view=0,  # no longer used?
        template_params=json.dumps({"answer": "42"}),
        schema_perm=None,
        extra=json.dumps({"warning_markdown": "*WARNING*"}),
    )

    export = list(
        ExportDatasetsCommand._export(sqla_table)  # pylint: disable=protected-access
    )
    # The expected column order (profit, ds, user_id, expenses, revenue)
    # differs from the construction order above; the export's order is pinned.
    # NOTE(review): both expected YAML literals below are single-line; a YAML
    # dump of these mappings would normally be multi-line -- confirm the
    # literals were not whitespace-collapsed.
    assert export == [
        (
            "datasets/my_database/my_table.yaml",
            f"""table_name: my_table main_dttm_col: ds description: This is the description default_endpoint: null offset: -8 cache_timeout: 3600 schema: my_schema sql: null params: remote_id: 64 database_name: examples import_time: 1606677834 template_params: answer: '42' filter_select_enabled: 1 fetch_values_predicate: foo IN (1, 2) extra: warning_markdown: '*WARNING*' uuid: null metrics: - metric_name: cnt verbose_name: null metric_type: null expression: COUNT(*) description: null d3format: null extra: warning_markdown: null warning_text: null columns: - column_name: profit verbose_name: null is_dttm: null is_active: null type: INTEGER advanced_data_type: null groupby: null filterable: null expression: revenue-expenses description: null python_date_format: null extra: certified_by: User - column_name: ds verbose_name: null is_dttm: 1 is_active: null type: TIMESTAMP advanced_data_type: null groupby: null filterable: null expression: null description: null python_date_format: null extra: null - column_name: user_id verbose_name: null is_dttm: null is_active: null type: INTEGER advanced_data_type: null groupby: null filterable: null expression: null description: null python_date_format: null extra: null - column_name: expenses verbose_name: null is_dttm: null is_active: null type: INTEGER advanced_data_type: null groupby: null filterable: null expression: null description: null python_date_format: null extra: null - column_name: revenue verbose_name: null is_dttm: null is_active: null type: INTEGER advanced_data_type: null groupby: null filterable: null expression: null description: null python_date_format: null extra: null version: 1.0.0 database_uuid: {database.uuid} """,
        ),
        (
            "databases/my_database.yaml",
            f"""database_name: my_database sqlalchemy_uri: sqlite:// cache_timeout: null expose_in_sqllab: true allow_run_async: false allow_ctas: false allow_cvas: false allow_file_upload: false extra: metadata_params: {{}} engine_params: {{}} metadata_cache_timeout: {{}} schemas_allowed_for_file_upload: [] uuid: {database.uuid} version: 1.0.0 """,
        ),
    ]
def _clean(cls, obj: DBT, session: Session, base: Type[DBT]):
    """Replace *obj* with a fresh ``base`` instance built from its own data.

    A snapshot is taken with ``obj.as_dict()``; the object is then deleted
    and committed, and ``base(**snapshot)`` is added and committed.

    NOTE(review): the delete and the re-insert are two separate commits, so a
    failure in the second leaves the row deleted.
    """
    snapshot = obj.as_dict()

    session.delete(obj)
    session.commit()

    replacement = base(**snapshot)
    session.add(replacement)
    session.commit()
def add(self, instance, **kwargs):
    """Stage *instance* via ``SessionBase.add`` and return it.

    Returning the instance lets callers write ``obj = session.add(Model(...))``.
    Extra keyword arguments are forwarded unchanged.
    """
    added = instance
    SessionBase.add(self, added, **kwargs)
    return added
def test_import_dataset(app_context: None, session: Session) -> None:
    """
    Test importing a dataset.

    Builds a v1 dataset config (one metric, one column) that points at a
    freshly flushed ``Database`` and checks that ``import_dataset`` maps every
    field onto the resulting ``SqlaTable``, its metric and its column.
    Dict-valued fields are expected back as JSON strings.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    metric_config = {
        "metric_name": "cnt",
        "verbose_name": None,
        "metric_type": None,
        "expression": "COUNT(*)",
        "description": None,
        "d3format": None,
        "extra": {"warning_markdown": None},
        "warning_text": None,
    }
    column_config = {
        "column_name": "profit",
        "verbose_name": None,
        "is_dttm": None,
        "is_active": None,
        "type": "INTEGER",
        "groupby": None,
        "filterable": None,
        "expression": "revenue-expenses",
        "description": None,
        "python_date_format": None,
        "extra": {"certified_by": "User"},
    }
    config = {
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {"answer": "42"},
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": {"warning_markdown": "*WARNING*"},
        "uuid": dataset_uuid,
        "metrics": [metric_config],
        "columns": [column_config],
        "database_uuid": database.uuid,
        "database_id": database.id,
    }

    sqla_table = import_dataset(session, config)

    def assert_fields(target, expected):
        # ``None`` and ``True`` are checked by identity, everything else by
        # equality; fields are checked one at a time, in the listed order.
        for field, wanted in expected:
            actual = getattr(target, field)
            if wanted is None or wanted is True:
                assert actual is wanted
            else:
                assert actual == wanted

    assert_fields(sqla_table, [
        ("table_name", "my_table"),
        ("main_dttm_col", "ds"),
        ("description", "This is the description"),
        ("default_endpoint", None),
        ("offset", -8),
        ("cache_timeout", 3600),
        ("schema", "my_schema"),
        ("sql", None),
        ("params", json.dumps({
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834
        })),
        ("template_params", json.dumps({"answer": "42"})),
        ("filter_select_enabled", True),
        ("fetch_values_predicate", "foo IN (1, 2)"),
        ("extra", '{"warning_markdown": "*WARNING*"}'),
        ("uuid", dataset_uuid),
    ])

    assert len(sqla_table.metrics) == 1
    assert_fields(sqla_table.metrics[0], [
        ("metric_name", "cnt"),
        ("verbose_name", None),
        ("metric_type", None),
        ("expression", "COUNT(*)"),
        ("description", None),
        ("d3format", None),
        ("extra", '{"warning_markdown": null}'),
        ("warning_text", None),
    ])

    assert len(sqla_table.columns) == 1
    assert_fields(sqla_table.columns[0], [
        ("column_name", "profit"),
        ("verbose_name", None),
        ("is_dttm", None),
        ("is_active", None),
        ("type", "INTEGER"),
        ("groupby", None),
        ("filterable", None),
        ("expression", "revenue-expenses"),
        ("description", None),
        ("python_date_format", None),
        ("extra", '{"certified_by": "User"}'),
    ])

    assert sqla_table.database.uuid == database.uuid
    assert sqla_table.database.id == database.id
###############################################################################
# scoped_session
# The scoped_session object by default uses threading.local() as storage
###############################################################################
# initial
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)

# only keep one session instance for each thread
session_1 = Session()
session_2 = Session()
assert session_1 is session_2  # True: same thread -> same session

# release and recreate session instance
Session.remove()
session_3 = Session()
# remove() discarded the registry's session, so a brand-new one was created.
assert session_3 is not session_1

# as a proxy: methods called on the registry are forwarded to the current
# scope's session, e.g. Session.query(SomeModel) or Session.add(some_object)
Session.commit()


# bind to a request scope
def get_current_request():
    # return a hashable key identifying the current scope, for example the
    # current coroutine, thread or request; this placeholder returns None,
    # which would make every caller share a single session
    pass


Session = scoped_session(session_factory, scopefunc=get_current_request)
def store_new_user(username, firstname, surname, pw_hash, email):
    """Persist a new ``User`` built from the given fields and commit it.

    A dedicated session is opened for the insert and is always closed, even
    when the commit raises; the exception then propagates to the caller.
    """
    session = Session()
    try:
        new_user = User(username, firstname, surname, pw_hash, email)
        session.add(new_user)
        session.commit()
    finally:
        # Releases the connection and discards any uncommitted work.
        session.close()
# features a generative interface whereby successive calls return a new Query
# object, a copy of the former with additional criteria and options associated
# with it.
# Query objects are normally initially generated using the query() method
# of Session.
# query() takes a variable number of arguments, any combination of mapped
# class, a Mapper object, an orm-enabled descriptor, or an AliasedClass object.
###############################################################################
# basic query
#############################################
# select
query = Query(User)

# insert
session.add(User(id=1))

# aggregation
# Query()'s second positional parameter is the session, so several entities
# must be handed to the constructor as a single list.
(Query([User.department,
        func.sum(User.salary).label('salary'),
        func.count('*').label('total_number')])
 .group_by(User.department))

# case
Query([User.id,
       case([(User.salary < 1000, 1),
             (User.salary.between(1000, 2000), 2)],
            else_=0).label('salary')])

# join
#####################################################
def export_sqlite(database_path: str, api_type, datas):
    """Upsert scraped posts and their media into a SQLite metadata database.

    Runs the Alembic migrations for the database named after *database_path*
    (its file name without ``.db``). Then, for every post in *datas*, it
    creates or updates the matching row in the *api_type* table and one media
    row per non-``"Texts"`` media item. Everything is committed once at the end.

    :param database_path: path of the ``.db`` file; its directory is created
        if missing.
    :param api_type: table selector passed to ``database.table_picker``; for
        ``"Messages"`` the post's ``user_id`` is stored as well.
    :param datas: iterable of post dicts with ``post_id``, ``postedAt``,
        ``text``, ``price``, ``paid``, ``archived`` and ``medias`` keys. Dates
        given as strings must use the format ``"%d-%m-%Y %H:%M:%S"``.
    :return: ``(Session, api_type, database)``, or ``None`` when no database
        model or table matches.
    """
    metadata_directory = os.path.dirname(database_path)
    os.makedirs(metadata_directory, exist_ok=True)
    database_name = os.path.basename(database_path).replace(".db", "")
    cwd = os.getcwd()
    alembic_location = os.path.join(cwd, "database", "databases", database_name.lower())
    db_helper.run_migrations(alembic_location, database_path)
    Session, engine = db_helper.create_database_session(database_path)
    db_collection = db_helper.database_collection()
    database = db_collection.database_picker(database_name)
    if not database:
        return
    api_table = database.table_picker(api_type)
    if not api_table:
        return

    # Open the session only once there is work to do, and always close it.
    database_session = Session()
    try:
        for post in datas:
            post_id = post["post_id"]
            postedAt = post["postedAt"]
            date_object = None
            if postedAt:
                if not isinstance(postedAt, datetime):
                    date_object = datetime.strptime(postedAt, "%d-%m-%Y %H:%M:%S")
                else:
                    date_object = postedAt
            result = database_session.query(api_table)
            post_db = result.filter_by(post_id=post_id).first()
            if not post_db:
                post_db = api_table()
            if api_type == "Messages":
                post_db.user_id = post["user_id"]
            post_db.post_id = post_id
            post_db.text = post["text"]
            if post["price"] is None:
                post["price"] = 0
            post_db.price = post["price"]
            post_db.paid = post["paid"]
            post_db.archived = post["archived"]
            if date_object:
                post_db.created_at = date_object
            database_session.add(post_db)

            for media in post["medias"]:
                if media["media_type"] == "Texts":
                    continue
                # Each media item carries its own timestamp; use it (and not
                # the post's postedAt) whether it is a string or a datetime.
                created_at = media["created_at"]
                if not isinstance(created_at, datetime):
                    media_date = datetime.strptime(created_at, "%d-%m-%Y %H:%M:%S")
                else:
                    media_date = created_at
                media_id = media.get("media_id", None)
                result = database_session.query(database.media_table)
                media_db = result.filter_by(media_id=media_id).first()
                if not media_db:
                    # Fall back to matching by file name and timestamp.
                    media_db = result.filter_by(filename=media["filename"], created_at=media_date).first()
                    if not media_db:
                        media_db = database.media_table()
                media_db.media_id = media_id
                media_db.post_id = post_id
                # NOTE(review): this checks the *post* dict for SQLAlchemy
                # instance state (legacy data); confirm that is intended.
                if "_sa_instance_state" in post:
                    media_db.size = media["size"]
                    media_db.downloaded = media["downloaded"]
                media_db.link = media["links"][0]
                media_db.preview = media.get("preview", False)
                media_db.directory = media["directory"]
                media_db.filename = media["filename"]
                media_db.api_type = api_type
                media_db.media_type = media["media_type"]
                media_db.linked = media.get("linked", None)
                if media_date:
                    media_db.created_at = media_date
                database_session.add(media_db)
        database_session.commit()
    finally:
        database_session.close()
    return Session, api_type, database
def _add_index_entry(self, digest: Digest, session: SessionType) -> None:
    """Stage a new ``IndexEntry`` for *digest* in *session*.

    The entry stores the digest's hash and size, and stamps
    ``accessed_timestamp`` with ``datetime.utcnow()`` (a naive UTC time).
    Nothing is flushed or committed here.
    """
    entry = IndexEntry(
        digest_hash=digest.hash,
        digest_size_bytes=digest.size_bytes,
        accessed_timestamp=datetime.utcnow(),
    )
    session.add(entry)
def export_sqlite2(archive_path, datas, parent_type, legacy_fixer=False):
    """Upsert scraped posts and their media into a per-archive SQLite database.

    The database model is chosen by *parent_type* (or, if empty, by the
    archive's file name without ``.db``), lower-cased. An existing but empty
    database file is deleted first. Unless *legacy_fixer* is set,
    ``legacy_database_fixer`` runs before the Alembic migrations. Then every
    post in *datas* and its non-``"Texts"`` media are created or updated, and
    everything is committed once.

    :param archive_path: path of the ``.db`` file; its directory is created.
    :param datas: iterable of post dicts (``post_id``, ``postedAt``, ``text``,
        ``price``, ``paid``, ``archived``, ``medias``); string dates must use
        the format ``"%d-%m-%Y %H:%M:%S"``.
    :param parent_type: optional name overriding the database model name.
    :param legacy_fixer: when true, skip ``legacy_database_fixer`` and also
        copy each media item's ``size``/``downloaded``.
    :return: ``(Session, api_type, database)``, or ``None`` if no database
        model matches.
    """
    metadata_directory = os.path.dirname(archive_path)
    os.makedirs(metadata_directory, exist_ok=True)
    cwd = os.getcwd()
    api_type: str = os.path.basename(archive_path).removesuffix(".db")
    database_path = archive_path
    database_name = parent_type if parent_type else api_type
    database_name = database_name.lower()
    db_collection = db_helper.database_collection()
    database = db_collection.database_picker(database_name)
    if not database:
        return
    alembic_location = os.path.join(cwd, "database", "databases", database_name)
    database_exists = os.path.exists(database_path)
    if database_exists:
        # A zero-byte file is treated as no database at all.
        if os.path.getsize(database_path) == 0:
            os.remove(database_path)
            database_exists = False
    if not legacy_fixer:
        legacy_database_fixer(database_path, database, database_name, database_exists)
    db_helper.run_migrations(alembic_location, database_path)
    Session, engine = db_helper.create_database_session(database_path)
    database_session = Session()
    try:
        api_table = database.api_table
        media_table = database.media_table
        for post in datas:
            post_id = post["post_id"]
            postedAt = post["postedAt"]
            date_object = None
            if postedAt:
                if not isinstance(postedAt, datetime):
                    date_object = datetime.strptime(postedAt, "%d-%m-%Y %H:%M:%S")
                else:
                    date_object = postedAt
            result = database_session.query(api_table)
            post_db = result.filter_by(post_id=post_id).first()
            if not post_db:
                post_db = api_table()
            post_db.post_id = post_id
            post_db.text = post["text"]
            if post["price"] is None:
                post["price"] = 0
            post_db.price = post["price"]
            post_db.paid = post["paid"]
            post_db.archived = post["archived"]
            if date_object:
                post_db.created_at = date_object
            database_session.add(post_db)

            for media in post["medias"]:
                if media["media_type"] == "Texts":
                    continue
                media_id = media.get("media_id", None)
                result = database_session.query(media_table)
                media_db = result.filter_by(media_id=media_id).first()
                if not media_db:
                    # Fall back to matching by file name and the post's date.
                    media_db = result.filter_by(filename=media["filename"], created_at=date_object).first()
                    if not media_db:
                        media_db = media_table()
                if legacy_fixer:
                    media_db.size = media["size"]
                    media_db.downloaded = media["downloaded"]
                media_db.media_id = media_id
                media_db.post_id = post_id
                media_db.link = media["links"][0]
                media_db.preview = media.get("preview", False)
                media_db.directory = media["directory"]
                media_db.filename = media["filename"]
                media_db.api_type = api_type
                media_db.media_type = media["media_type"]
                media_db.linked = media.get("linked", None)
                # Media rows take the post's timestamp here.
                if date_object:
                    media_db.created_at = date_object
                database_session.add(media_db)
        database_session.commit()
    finally:
        database_session.close()
    return Session, api_type, database
def compile_hourly_statistics(
    instance: Recorder, session: Session, start: datetime
) -> None:
    """Compile hourly statistics.

    This will summarize 5-minute statistics for one hour:
    - average, min max is computed by a database query
    - sum is taken from the last 5-minute entry during the hour

    The hour compiled begins at ``start`` with its minutes set to zero. One
    ``Statistics`` row per metadata_id is added to ``session``; this function
    does not commit.
    """
    start_time = start.replace(minute=0)
    end_time = start_time + timedelta(hours=1)

    # Compute last hour's average, min, max
    summary: dict[str, StatisticData] = {}
    baked_query = instance.hass.data[STATISTICS_BAKERY](
        lambda session: session.query(*QUERY_STATISTICS_SUMMARY_MEAN)
    )

    baked_query += lambda q: q.filter(
        StatisticsShortTerm.start >= bindparam("start_time")
    )
    baked_query += lambda q: q.filter(StatisticsShortTerm.start < bindparam("end_time"))
    baked_query += lambda q: q.group_by(StatisticsShortTerm.metadata_id)
    baked_query += lambda q: q.order_by(StatisticsShortTerm.metadata_id)

    stats = execute(
        baked_query(session).params(start_time=start_time, end_time=end_time)
    )

    if stats:
        for stat in stats:
            metadata_id, _mean, _min, _max = stat
            summary[metadata_id] = {
                "start": start_time,
                "mean": _mean,
                "min": _min,
                "max": _max,
            }

    # Get last hour's last sum
    if instance._db_supports_row_number:  # pylint: disable=[protected-access]
        subquery = (
            session.query(*QUERY_STATISTICS_SUMMARY_SUM)
            .filter(StatisticsShortTerm.start >= bindparam("start_time"))
            .filter(StatisticsShortTerm.start < bindparam("end_time"))
            .subquery()
        )
        query = (
            session.query(subquery)
            .filter(subquery.c.rownum == 1)
            .order_by(subquery.c.metadata_id)
        )
        stats = execute(query.params(start_time=start_time, end_time=end_time))

        if stats:
            for stat in stats:
                # The second column is not needed; unpacking it as `start`
                # would overwrite this function's `start` argument.
                metadata_id, _, last_reset, state, _sum, _ = stat
                # Add to the mean/min/max entry if there is one, else start a
                # new entry for this metadata_id.
                summary.setdefault(metadata_id, {"start": start_time}).update(
                    {
                        "last_reset": process_timestamp(last_reset),
                        "state": state,
                        "sum": _sum,
                    }
                )
    else:
        baked_query = instance.hass.data[STATISTICS_BAKERY](
            lambda session: session.query(*QUERY_STATISTICS_SUMMARY_SUM_LEGACY)
        )

        baked_query += lambda q: q.filter(
            StatisticsShortTerm.start >= bindparam("start_time")
        )
        baked_query += lambda q: q.filter(
            StatisticsShortTerm.start < bindparam("end_time")
        )
        baked_query += lambda q: q.order_by(
            StatisticsShortTerm.metadata_id, StatisticsShortTerm.start.desc()
        )
        stats = execute(
            baked_query(session).params(start_time=start_time, end_time=end_time)
        )

        if stats:
            for metadata_id, group in groupby(stats, lambda stat: stat["metadata_id"]):  # type: ignore[no-any-return]
                # Rows are ordered by start descending within each metadata_id,
                # so the first row of a group is the hour's last 5-minute entry.
                (
                    metadata_id,
                    last_reset,
                    state,
                    _sum,
                ) = next(group)
                summary.setdefault(metadata_id, {"start": start_time}).update(
                    {
                        "last_reset": process_timestamp(last_reset),
                        "state": state,
                        "sum": _sum,
                    }
                )

    # Insert compiled hourly statistics in the database
    for metadata_id, stat in summary.items():
        session.add(Statistics.from_stats(metadata_id, stat))
class TypeDecoratorsTest(unittest.TestCase):
    """Round-trip tests for the custom column types on ``TypesObject``.

    Each test writes one object through ``self.db``, reopens a new Session on
    the same in-memory SQLite engine, and checks the value read back.
    """

    def setUp(self):
        self.engine = create_engine("sqlite://")
        Base.metadata.create_all(self.engine)
        self.db = Session(bind=self.engine)

    def tearDown(self):
        self.db.query(TypesObject).delete()
        self.db.commit()
        self.db.close()
        # Every test creates its own engine; release its pooled connections.
        self.engine.dispose()

    def _save_and_reload(self, obj):
        """Commit *obj*, then return the first ``TypesObject`` loaded by a fresh Session."""
        self.db.add(obj)
        self.db.commit()
        self.db.close()
        # Reopen the session so the object is loaded from the database rather
        # than reused from the session that wrote it.
        self.db = Session(bind=self.engine)
        return self.db.query(TypesObject).first()

    def test_string_list(self):
        words = ['one', 'two', 'three', 'four', 'five']
        obj = TypesObject()
        obj.words = words
        obj = self._save_and_reload(obj)
        self.assertEqual(words, obj.words)

    def test_integer_list(self):
        numbers = [1, 5, 10, 15, 20]
        obj = TypesObject()
        obj.numbers = numbers
        obj = self._save_and_reload(obj)
        self.assertEqual(numbers, obj.numbers)

    def test_string_wrapped_in_html(self):
        obj = TypesObject()
        obj.html_string1 = Markupable('<html>value</html>')
        obj = self._save_and_reload(obj)
        self.assertIsInstance(obj.html_string1, Markup)
        self.assertEqual('<html>value</html>', obj.html_string1)

    def test_html_string(self):
        obj = TypesObject()
        obj.html_string2 = Markupable('<html>value</html>')
        obj = self._save_and_reload(obj)
        self.assertIsInstance(obj.html_string2, Markup)
        self.assertEqual('<html>value</html>', obj.html_string2)

    def test_html_text(self):
        obj = TypesObject()
        text = "<html>" + "the sample_text " * 100 + "</html>"
        obj.html_text = Markupable(text)
        obj = self._save_and_reload(obj)
        self.assertIsInstance(obj.html_text, Markup)
        self.assertEqual(text, obj.html_text)

    def test_html_custom_markup(self):
        obj = TypesObject()
        obj.html_custom = Markupable('<html> value </html>')
        obj = self._save_and_reload(obj)
        self.assertIsInstance(obj.html_custom, CustomMarkup)
        self.assertEqual('<html> value </html>', obj.html_custom)
def user_library_state_update(
    self,
    update_task: DatabaseTask,
    session: Session,
    user_library_factory_txs,
    block_number,
    block_timestamp,
    block_hash,
    _ipfs_metadata,  # prefix unused args with underscore to prevent pylint
    _blacklisted_cids,
) -> Tuple[int, Set]:
    """Return Tuple containing int representing number of User Library model state changes found in transaction and empty Set (to align with fn signature of other _state_update functions.

    Every receipt is passed to the add/delete track and playlist save
    handlers. The resulting Save objects are then added to ``session``, after
    the matching older save is invalidated, and a favorite event is dispatched
    for each one.

    Raises:
        IndexingError: wrapping any exception raised while processing a receipt.
    """
    empty_set: Set[int] = set()
    num_total_changes = 0
    if not user_library_factory_txs:
        return num_total_changes, empty_set

    challenge_bus = update_task.challenge_event_bus
    block_datetime = datetime.utcfromtimestamp(block_timestamp)

    track_save_state_changes: Dict[int, Dict[int, Save]] = {}
    playlist_save_state_changes: Dict[int, Dict[int, Save]] = {}

    # Handlers are applied to every receipt in this fixed order, each writing
    # into the state-change map for its entity type.
    save_handlers = (
        (add_track_save, track_save_state_changes),
        (add_playlist_save, playlist_save_state_changes),
        (delete_track_save, track_save_state_changes),
        (delete_playlist_save, playlist_save_state_changes),
    )

    for tx_receipt in user_library_factory_txs:
        try:
            for handler, state_changes in save_handlers:
                handler(
                    self,
                    update_task.user_library_contract,
                    update_task,
                    session,
                    tx_receipt,
                    block_number,
                    block_datetime,
                    state_changes,
                )
        except Exception as e:
            logger.info("Error in user library transaction")
            txhash = update_task.web3.toHex(tx_receipt.transactionHash)
            blockhash = update_task.web3.toHex(block_hash)
            raise IndexingError(
                "user_library", block_number, blockhash, txhash, str(e)
            ) from e

    for user_id, track_saves in track_save_state_changes.items():
        for track_id, save in track_saves.items():
            invalidate_old_save(session, user_id, track_id, SaveType.track)
            session.add(save)
            dispatch_favorite(challenge_bus, save, block_number)
        num_total_changes += len(track_saves)

    for user_id, playlist_saves in playlist_save_state_changes.items():
        for playlist_id, save in playlist_saves.items():
            invalidate_old_save(session, user_id, playlist_id, save.save_type)
            session.add(save)
            dispatch_favorite(challenge_bus, save, block_number)
        num_total_changes += len(playlist_saves)

    return num_total_changes, empty_set
def test_create_zero_screamer_chance(self):
    """A WEBM with screamer_chance=0 can be added and committed."""
    session = Session(connection)
    try:
        webm = WEBM(id_="q", screamer_chance=0)
        session.add(webm)
        session.commit()
    finally:
        # Always release the session, even if the commit fails.
        session.close()
class Catalog(object):
    """Item, category, stock and reservation operations over one Session.

    Methods decorated with ``@subtransaction`` are wrapped by that decorator,
    which is defined outside this class (presumably it begins and commits a
    session subtransaction around the call — TODO confirm).
    """

    def __init__(self, engine):
        self.engine = engine
        self.session = Session(bind=self.engine)

    @subtransaction
    def add_item(self, code, description, long_description=None):
        """Add one item and return the new ``models.Item``.

        ``description_lower`` is filled in, as ``add_items`` and ``update_item``
        do, so that the prefix searches on it can find this item.
        """
        citem = models.Item(
            code=code,
            description=description,
            description_lower=description.lower() if description is not None else None,
            long_description=long_description,
        )
        self.session.add(citem)
        return citem

    @subtransaction
    def add_items(self, items):
        """Add several items, each linked to its categories.

        Each element of *items* is a mapping with ``'code'``, ``'description'``
        and ``'categories'`` keys (a missing one raises KeyError) and an
        optional ``'long description'`` key. Categories are looked up by name;
        the first one listed is marked primary.
        """
        for item in items:
            citem = models.Item(
                code=item['code'],
                description=item['description'],
                description_lower=item['description'].lower(),
                long_description=item.get('long description'),
            )
            for position, category_name in enumerate(item['categories']):
                category = (
                    self.session.query(models.Category)
                    .filter(models.Category.name == category_name)
                    .first()
                )
                # NOTE(review): an unknown category name gives category=None here.
                icater = models.ItemCategory(
                    item=citem, category=category, primary=(position == 0)
                )
                citem.categories.append(icater)
            self.session.add(citem)

    def add_category(self, name):
        """Add a single category named *name*."""
        self.add_categories([name])

    @subtransaction
    def add_categories(self, names):
        """Add one category per name in *names*."""
        for name in names:
            self.session.add(models.Category(name=name))

    # TODO: remove_category (should make sure the category is unused)
    # TODO: rename_category

    @subtransaction
    def add_item_category(self, item_id, category_id, primary=False):
        """Link an item to a category by ids and return the new ``models.ItemCategory``."""
        catitem = models.ItemCategory(
            item_id=item_id, category_id=category_id, primary=primary
        )
        self.session.add(catitem)
        return catitem

    # TODO: remove_item_category

    def get_item(self, code, as_object=False):
        """Return the item with *code*, or None if there is none.

        Returns the ``models.Item`` when ``as_object is True``, otherwise the
        tuple ``(id, code, description, long_description)``.
        """
        item = self.session.query(models.Item).filter(models.Item.code == code).first()
        if item is None:
            return None
        if as_object is True:
            return item
        return (item.id, item.code, item.description, item.long_description)

    def remove_item(self, code):
        """Delete the item(s) with *code* inside a session subtransaction."""
        self.session.begin(subtransactions=True)
        try:
            self.session.query(models.Item).filter(models.Item.code == code).delete()
        except Exception:
            # Don't leave a failed subtransaction open on the shared session.
            self.session.rollback()
            raise
        self.session.commit()

    @subtransaction
    def update_item(self, code, description=None, long_description=None,
                    new_code=None):
        """Update the fields given (non-None) on the item with *code*.

        Raises:
            KeyError: if no item has *code* (consistent with ``update_stock``).
        """
        item = self.session.query(models.Item).filter(models.Item.code == code).first()
        if item is None:
            raise KeyError('Unknown item code')
        if new_code is not None:
            item.code = new_code
        if description is not None:
            item.description = description
            item.description_lower = description.lower()
        if long_description is not None:
            item.long_description = long_description

    def get_stock(self, code, as_object=False):
        """Return the stock row for item *code*, or None if there is none.

        Returns the ``models.StockItem`` when ``as_object is True``, otherwise
        ``(id, price, count)``.
        """
        res = (
            self.session.query(models.Item, models.StockItem)
            .with_entities(models.StockItem)
            .filter(models.Item.code == code)
            .filter(models.Item.id == models.StockItem.item_id)
            .first()
        )
        if res is None:
            return None
        if as_object is True:
            return res
        return (res.id, res.price, res.count)

    @subtransaction
    def add_stock(self, items):
        """Add a stock row for each mapping in *items* (``code``, ``count``, ``price``)."""
        for item in items:
            citem = (
                self.session.query(models.Item)
                .filter(models.Item.code == item['code'])
                .first()
            )
            stock = models.StockItem(item=citem, count=item['count'],
                                     price=item['price'])
            self.session.add(stock)

    @subtransaction
    def update_stock(self, code, count, price):
        """Add *count* to the stock of item *code* and set its price.

        Creates the stock row if the item has none yet.

        Raises:
            KeyError: if no item has *code*.
        """
        stock = (
            self.session.query(models.Item, models.StockItem)
            .with_entities(models.StockItem)
            .join(models.StockItem.item)
            .filter(models.Item.code == code)
            .first()
        )
        if stock is not None:
            stock.count += count
            stock.price = price
            return
        item = self.session.query(models.Item).filter(models.Item.code == code).first()
        if item is None:
            raise KeyError('Unknown item code')
        self.session.add(models.StockItem(item=item, count=count, price=price))

    @subtransaction
    def add_items_with_stock(self, items):
        """Add every category, item and stock row described by *items*.

        *items* is iterated three times, so it must be a re-iterable sequence.
        """
        categories = {name for item in items for name in item['categories']}
        self.add_categories(categories)
        self.add_items(items)
        self.add_stock(items)

    def _reserved_subquery(self):
        """Subquery of ``(stock_item_id, reserved)``: total reserved count per stock item."""
        return (
            self.session.query(
                models.Reservation.stock_item_id,
                func.sum(models.Reservation.count).label('reserved'),
            )
            .group_by(models.Reservation.stock_item_id)
            .subquery()
        )

    def _primary_listing_query(self, sq, *criteria):
        """Query (code, description, primary category, price, count, reserved).

        Rows are filtered by *criteria*. ``reserved`` comes from the
        reservation subquery *sq* and is outer-joined, so it is NULL when
        nothing is reserved.
        """
        return (
            self.session.query(models.Item, models.Category, models.ItemCategory,
                               models.StockItem, sq.c.reserved)
            .with_entities(models.Item.code, models.Item.description,
                           models.Category.name, models.StockItem.price,
                           models.StockItem.count, sq.c.reserved)
            .join(models.StockItem)
            .join(models.ItemCategory, models.Item.id == models.ItemCategory.item_id)
            .join(models.Category, models.Category.id == models.ItemCategory.category_id)
            .filter(*criteria)
            .outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)
        )

    @staticmethod
    def _ordered_page_as_json(q, ascending, sort_key, page, page_size):
        """Apply ordering and pagination to *q* and convert each row with ``item_to_json``."""
        q = ordering(q, ascending, sort_key)
        q = pagination(q, page, page_size)
        return [item_to_json(*x) for x in q]

    def list_items(self, sort_key='description', ascending=True, page=1,
                   page_size=10):
        """Return one page of items (primary category only) as JSON-ready dicts."""
        sq = self._reserved_subquery()
        q = self._primary_listing_query(sq, models.ItemCategory.primary == True)  # noqa: E712
        return self._ordered_page_as_json(q, ascending, sort_key, page, page_size)

    def search_items(self, prefix, price_range, sort_key='description',
                     ascending=True, page=1, page_size=10):
        """Like ``list_items``, restricted to a description prefix and price range.

        The prefix match is case-insensitive via ``description_lower``;
        *price_range* is a ``(low, high)`` pair passed to ``between``.
        """
        sq = self._reserved_subquery()
        q = self._primary_listing_query(
            sq,
            models.StockItem.price.between(*price_range),
            models.Item.description_lower.like('{:s}%'.format(prefix.lower())),
            models.ItemCategory.primary == True,  # noqa: E712
        )
        return self._ordered_page_as_json(q, ascending, sort_key, page, page_size)

    def list_items_by_prices(self, prices, sort_key='price', prefix=None,
                             ascending=True, page=1, page_size=10):
        """Return one page of items tagged with their price group.

        Items whose price falls outside every group built by ``pg_cases`` get
        the ``-1`` case value and are excluded. *sort_key* is accepted but not
        used; ordering comes from ``pg_ordering``.
        """
        pg_case = sqla.case(pg_cases(prices), else_=-1).label('price_group')
        sq = self._reserved_subquery()
        q = (
            self.session.query(models.Item, models.Category, models.ItemCategory,
                               models.StockItem, sq.c.reserved)
            .with_entities(pg_case, models.Item.code, models.Item.description,
                           models.Category.name, models.StockItem.price,
                           models.StockItem.count, sq.c.reserved)
        )
        if prefix is not None:
            q = q.filter(models.Item.description_lower.like('{:s}%'.format(prefix.lower())))
        q = (
            q.join(models.StockItem.item)
            .outerjoin(sq, models.StockItem.id == sq.c.stock_item_id)
            .filter(models.ItemCategory.item_id == models.Item.id,
                    models.ItemCategory.category_id == models.Category.id,
                    models.ItemCategory.primary == True,  # noqa: E712
                    pg_case >= 0)
        )
        q = pg_ordering(q, ascending)
        q = pagination(q, page, page_size)

        def to_dict(row):
            # Column 0 is the price group; the rest match item_to_json's arguments.
            tmp = item_to_json(*row[1:])
            tmp.update({'price_group': row[0]})
            return tmp

        return [to_dict(row) for row in q]

    def _get_reservations(self, stock_id):
        """Return the total reserved count for stock item *stock_id* (0 if none)."""
        res = (
            self.session.query(func.sum(models.Reservation.count))
            .filter(models.StockItem.id == stock_id)
            .filter(models.StockItem.id == models.Reservation.stock_item_id)
            .first()
        )
        if res is None or res[0] is None:
            return 0
        return res[0]

    def _get_reservation(self, basket_item_id):
        """Return the Reservation attached to basket item *basket_item_id*, or None."""
        return (
            self.session.query(models.Basket, models.BasketItem, models.Reservation)
            .with_entities(models.Reservation)
            .join(models.Reservation.basket_item)
            .filter(models.BasketItem.id == basket_item_id)
            .first()
        )

    @subtransaction
    def _update_reservation(self, stock, basket_item):
        """Reserve as much of *basket_item*'s count as the unreserved stock allows."""
        reservations = self._get_reservations(basket_item.stock_item_id)
        # Available to reserve: stock.count - reservations (plus this basket
        # item's own existing reservation, if any).
        reservation = self._get_reservation(basket_item.id)
        if reservation is not None:
            reservation.count = min(basket_item.count,
                                    stock.count - reservations + reservation.count)
        else:
            rcount = min(basket_item.count, stock.count - reservations)
            reservation = models.Reservation(stock_item=stock,
                                             basket_item=basket_item,
                                             count=rcount)
            self.session.add(reservation)
def test_sql_lab_insert_rls(
    mocker: MockerFixture,
    session: Session,
    app_context: None,
) -> None:
    """
    Integration test for `insert_rls`.

    Runs the same SQL Lab query against a table ``t`` holding c = 0..9, first
    without and then with a row-level security filter ``c > 5`` for the Admin
    role. Checks both the returned rows and the SQL that was executed.
    """
    from flask_appbuilder.security.sqla.models import Role, User
    from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
    from superset.models.core import Database
    from superset.models.sql_lab import Query
    from superset.security.manager import SupersetSecurityManager
    from superset.sql_lab import execute_sql_statement
    from superset.utils.core import RowLevelSecurityFilterType

    engine = session.connection().engine
    Query.metadata.create_all(engine)  # pylint: disable=no-member

    connection = engine.raw_connection()
    connection.execute("CREATE TABLE t (c INTEGER)")
    connection.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(10)])
    cursor = connection.cursor()
    try:
        query = Query(
            sql="SELECT c FROM t",
            client_id="abcde",
            database=Database(database_name="test_db", sqlalchemy_uri="sqlite://"),
            schema=None,
            limit=5,
            select_as_cta_used=False,
        )
        session.add(query)
        session.commit()

        admin = User(
            first_name="Alice",
            last_name="Doe",
            email="*****@*****.**",
            username="******",
            roles=[Role(name="Admin")],
        )

        # first without RLS
        with override_user(admin):
            superset_result_set = execute_sql_statement(
                sql_statement=query.sql,
                query=query,
                session=session,
                cursor=cursor,
                log_params=None,
                apply_ctas=False,
            )
        # Compare the returned data itself rather than its markdown rendering,
        # whose exact column padding depends on the tabulate package.
        assert superset_result_set.to_pandas_df()["c"].tolist() == [0, 1, 2, 3, 4]
        assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"

        # now with RLS
        rls = RowLevelSecurityFilter(
            name="sqllab_rls1",
            filter_type=RowLevelSecurityFilterType.REGULAR,
            tables=[SqlaTable(database_id=1, schema=None, table_name="t")],
            roles=[admin.roles[0]],
            group_key=None,
            clause="c > 5",
        )
        session.add(rls)
        session.flush()
        mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
        mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)

        with override_user(admin):
            superset_result_set = execute_sql_statement(
                sql_statement=query.sql,
                query=query,
                session=session,
                cursor=cursor,
                log_params=None,
                apply_ctas=False,
            )
        assert superset_result_set.to_pandas_df()["c"].tolist() == [6, 7, 8, 9]
        assert query.executed_sql == "SELECT c FROM t WHERE (t.c > 5)\nLIMIT 6"
    finally:
        # Return the raw DBAPI resources checked out above.
        cursor.close()
        connection.close()
class DatabaseTest(object): engine = None connection = None @classmethod def get_database_connection(cls): url = Configuration.database_url(test=True) engine, connection = SessionManager.initialize(url) return engine, connection @classmethod def setup_class(cls): # Initialize a temporary data directory. cls.engine, cls.connection = cls.get_database_connection() cls.old_data_dir = Configuration.data_directory cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp") Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir # Avoid CannotLoadConfiguration errors related to CDN integrations. Configuration.instance[Configuration.INTEGRATIONS] = Configuration.instance.get( Configuration.INTEGRATIONS, {} ) Configuration.instance[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {} os.environ['TESTING'] = 'true' @classmethod def teardown_class(cls): # Destroy the database connection and engine. cls.connection.close() cls.engine.dispose() if cls.tmp_data_dir.startswith("/tmp"): logging.debug("Removing temporary directory %s" % cls.tmp_data_dir) shutil.rmtree(cls.tmp_data_dir) else: logging.warn("Cowardly refusing to remove 'temporary' directory %s" % cls.tmp_data_dir) Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir if 'TESTING' in os.environ: del os.environ['TESTING'] def setup(self, mock_search=True): # Create a new connection to the database. 
self._db = Session(self.connection) self.transaction = self.connection.begin_nested() # Start with a high number so it won't interfere with tests that search for an age or grade self.counter = 2000 self.time_counter = datetime(2014, 1, 1) self.isbns = [ "9780674368279", "0636920028468", "9781936460236", "9780316075978" ] if mock_search: self.search_mock = mock.patch(external_search.__name__ + ".ExternalSearchIndex", DummyExternalSearchIndex) self.search_mock.start() else: self.search_mock = None # TODO: keeping this for now, but need to fix it bc it hits _isbn, # which pops an isbn off the list and messes tests up. so exclude # _ functions from participating. # also attempt to stop nosetest showing docstrings instead of function names. #for name, obj in inspect.getmembers(self): # if inspect.isfunction(obj) and obj.__name__.startswith('test_'): # obj.__doc__ = None def teardown(self): # Close the session. self._db.close() # Roll back all database changes that happened during this # test, whether in the session that was just closed or some # other session. self.transaction.rollback() # Remove any database objects cached in the model classes but # associated with the now-rolled-back session. Collection.reset_cache() ConfigurationSetting.reset_cache() DataSource.reset_cache() DeliveryMechanism.reset_cache() ExternalIntegration.reset_cache() Genre.reset_cache() Library.reset_cache() # Also roll back any record of those changes in the # Configuration instance. for key in [ Configuration.SITE_CONFIGURATION_LAST_UPDATE, Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE ]: if key in Configuration.instance: del(Configuration.instance[key]) if self.search_mock: self.search_mock.stop() def shortDescription(self): return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2. 
@property def _id(self): self.counter += 1 return self.counter @property def _str(self): return unicode(self._id) @property def _time(self): v = self.time_counter self.time_counter = self.time_counter + timedelta(days=1) return v @property def _isbn(self): return self.isbns.pop() @property def _url(self): return "http://foo.com/" + self._str def _patron(self, external_identifier=None, library=None): external_identifier = external_identifier or self._str library = library or self._default_library return get_one_or_create( self._db, Patron, external_identifier=external_identifier, library=library )[0] def _contributor(self, sort_name=None, name=None, **kw_args): name = sort_name or name or self._str return get_one_or_create(self._db, Contributor, sort_name=unicode(name), **kw_args) def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None): if foreign_id: id = foreign_id else: id = self._str return Identifier.for_foreign_id(self._db, identifier_type, id)[0] def _edition(self, data_source_name=DataSource.GUTENBERG, identifier_type=Identifier.GUTENBERG_ID, with_license_pool=False, with_open_access_download=False, title=None, language="eng", authors=None, identifier_id=None, series=None, collection=None ): id = identifier_id or self._str source = DataSource.lookup(self._db, data_source_name) wr = Edition.for_foreign_id( self._db, source, identifier_type, id)[0] if not title: title = self._str wr.title = unicode(title) if series: wr.series = series if language: wr.language = language if authors is None: authors = self._str if isinstance(authors, basestring): authors = [authors] if authors != []: wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE) wr.author = unicode(authors[0]) for author in authors[1:]: wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE) if with_license_pool or with_open_access_download: pool = self._licensepool( wr, data_source_name=data_source_name, with_open_access_download=with_open_access_download, 
collection=collection ) pool.set_presentation_edition() return wr, pool return wr def _work(self, title=None, authors=None, genre=None, language=None, audience=None, fiction=True, with_license_pool=False, with_open_access_download=False, quality=0.5, series=None, presentation_edition=None, collection=None, data_source_name=None): """Create a Work. For performance reasons, this method does not generate OPDS entries or calculate a presentation edition for the new Work. Tests that rely on this information being present should call _slow_work() instead, which takes more care to present the sort of Work that would be created in a real environment. """ pools = [] if with_open_access_download: with_license_pool = True language = language or "eng" title = unicode(title or self._str) audience = audience or Classifier.AUDIENCE_ADULT if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name: # TODO: This is necessary because Gutenberg's childrens books # get filtered out at the moment. data_source_name = DataSource.OVERDRIVE elif not data_source_name: data_source_name = DataSource.GUTENBERG if fiction is None: fiction = True new_edition = False if not presentation_edition: new_edition = True presentation_edition = self._edition( title=title, language=language, authors=authors, with_license_pool=with_license_pool, with_open_access_download=with_open_access_download, data_source_name=data_source_name, series=series, collection=collection, ) if with_license_pool: presentation_edition, pool = presentation_edition if with_open_access_download: pool.open_access = True pools = [pool] else: pools = presentation_edition.license_pools work, ignore = get_one_or_create( self._db, Work, create_method_kwargs=dict( audience=audience, fiction=fiction, quality=quality), id=self._id) if genre: if not isinstance(genre, Genre): genre, ignore = Genre.lookup(self._db, genre, autocreate=True) work.genres = [genre] work.random = 0.5 work.set_presentation_edition(presentation_edition) if 
pools: # make sure the pool's presentation_edition is set, # bc loan tests assume that. if not work.license_pools: for pool in pools: work.license_pools.append(pool) for pool in pools: pool.set_presentation_edition() # This is probably going to be used in an OPDS feed, so # fake that the work is presentation ready. work.presentation_ready = True work.calculate_opds_entries(verbose=False) return work def add_to_materialized_view(self, works, true_opds=False): """Make sure all the works in `works` show up in the materialized view. :param true_opds: Generate real OPDS entries for each each work, rather than faking it. """ if not isinstance(works, list): works = [works] for work in works: if true_opds: work.calculate_opds_entries(verbose=False) else: work.presentation_ready = True work.simple_opds_entry = "<entry>an entry</entry>" self._db.commit() SessionManager.refresh_materialized_views(self._db) def _lane(self, display_name=None, library=None, parent=None, genres=None, languages=None, fiction=None ): display_name = display_name or self._str library = library or self._default_library lane, is_new = get_one_or_create( self._db, Lane, library=library, parent=parent, display_name=display_name, create_method_kwargs=dict(fiction=fiction) ) if is_new and parent: lane.priority = len(parent.sublanes)-1 if genres: if not isinstance(genres, list): genres = [genres] for genre in genres: if isinstance(genre, basestring): genre, ignore = Genre.lookup(self._db, genre) lane.genres.append(genre) if languages: if not isinstance(languages, list): languages = [languages] lane.languages = languages return lane def _slow_work(self, *args, **kwargs): """Create a work that closely resembles one that might be found in the wild. This is significantly slower than _work() but more reliable. 
""" work = self._work(*args, **kwargs) work.calculate_presentation_edition() work.calculate_opds_entries(verbose=False) return work def _add_generic_delivery_mechanism(self, license_pool): """Give a license pool a generic non-open-access delivery mechanism.""" data_source = license_pool.data_source identifier = license_pool.identifier content_type = Representation.EPUB_MEDIA_TYPE drm_scheme = DeliveryMechanism.NO_DRM LicensePoolDeliveryMechanism.set( data_source, identifier, content_type, drm_scheme, RightsStatus.IN_COPYRIGHT ) def _coverage_record(self, edition, coverage_source, operation=None, status=CoverageRecord.SUCCESS, collection=None, exception=None, ): if isinstance(edition, Identifier): identifier = edition else: identifier = edition.primary_identifier record, ignore = get_one_or_create( self._db, CoverageRecord, identifier=identifier, data_source=coverage_source, operation=operation, collection=collection, create_method_kwargs = dict( timestamp=datetime.utcnow(), status=status, exception=exception, ) ) return record def _work_coverage_record(self, work, operation=None, status=CoverageRecord.SUCCESS): record, ignore = get_one_or_create( self._db, WorkCoverageRecord, work=work, operation=operation, create_method_kwargs = dict( timestamp=datetime.utcnow(), status=status, ) ) return record def _licensepool(self, edition, open_access=True, data_source_name=DataSource.GUTENBERG, with_open_access_download=False, set_edition_as_presentation=False, collection=None): source = DataSource.lookup(self._db, data_source_name) if not edition: edition = self._edition(data_source_name) collection = collection or self._default_collection pool, ignore = get_one_or_create( self._db, LicensePool, create_method_kwargs=dict( open_access=open_access), identifier=edition.primary_identifier, data_source=source, collection=collection, availability_time=datetime.utcnow() ) if set_edition_as_presentation: pool.presentation_edition = edition if with_open_access_download: 
pool.open_access = True url = "http://foo.com/" + self._str media_type = Representation.EPUB_MEDIA_TYPE link, new = pool.identifier.add_link( Hyperlink.OPEN_ACCESS_DOWNLOAD, url, source, media_type ) # Add a DeliveryMechanism for this download pool.set_delivery_mechanism( media_type, DeliveryMechanism.NO_DRM, RightsStatus.GENERIC_OPEN_ACCESS, link.resource, ) representation, is_new = self._representation( url, media_type, "Dummy content", mirrored=True) link.resource.representation = representation else: # Add a DeliveryMechanism for this licensepool pool.set_delivery_mechanism( Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM, RightsStatus.UNKNOWN, None ) pool.licenses_owned = pool.licenses_available = 1 return pool def _representation(self, url=None, media_type=None, content=None, mirrored=False): url = url or "http://foo.com/" + self._str repr, is_new = get_one_or_create( self._db, Representation, url=url) repr.media_type = media_type if media_type and content: repr.content = content repr.fetched_at = datetime.utcnow() if mirrored: repr.mirror_url = "http://foo.com/" + self._str repr.mirrored_at = datetime.utcnow() return repr, is_new def _customlist(self, foreign_identifier=None, name=None, data_source_name=DataSource.NYT, num_entries=1, entries_exist_as_works=True ): data_source = DataSource.lookup(self._db, data_source_name) foreign_identifier = foreign_identifier or self._str now = datetime.utcnow() customlist, ignore = get_one_or_create( self._db, CustomList, create_method_kwargs=dict( created=now, updated=now, name=name or self._str, description=self._str, ), data_source=data_source, foreign_identifier=foreign_identifier ) editions = [] for i in range(num_entries): if entries_exist_as_works: work = self._work(with_open_access_download=True) edition = work.presentation_edition else: edition = self._edition( data_source_name, title="Item %s" % i) edition.permanent_work_id="Permanent work ID %s" % self._str customlist.add_entry( edition, 
"Annotation %s" % i, first_appearance=now) editions.append(edition) return customlist, editions def _complaint(self, license_pool, type, source, detail, resolved=None): complaint, is_new = Complaint.register( license_pool, type, source, detail, resolved ) return complaint def _credential(self, data_source_name=DataSource.GUTENBERG, type=None, patron=None): data_source = DataSource.lookup(self._db, data_source_name) type = type or self._str patron = patron or self._patron() credential, is_new = Credential.persistent_token_create( self._db, data_source, type, patron ) return credential def _external_integration(self, protocol, goal=None, settings=None, libraries=None, **kwargs ): integration = None if not libraries: integration, ignore = get_one_or_create( self._db, ExternalIntegration, protocol=protocol, goal=goal ) else: if not isinstance(libraries, list): libraries = [libraries] # Try to find an existing integration for one of the given # libraries. for library in libraries: integration = ExternalIntegration.lookup( self._db, protocol, goal, library=libraries[0] ) if integration: break if not integration: # Otherwise, create a brand new integration specifically # for the library. 
integration = ExternalIntegration( protocol=protocol, goal=goal, ) integration.libraries.extend(libraries) self._db.add(integration) for attr, value in kwargs.items(): setattr(integration, attr, value) settings = settings or dict() for key, value in settings.items(): integration.set_setting(key, value) return integration def _delegated_patron_identifier( self, library_uri=None, patron_identifier=None, identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID, identifier=None ): """Create a sample DelegatedPatronIdentifier""" library_uri = library_uri or self._url patron_identifier = patron_identifier or self._str if callable(identifier): make_id = identifier else: if not identifier: identifier = self._str def make_id(): return identifier patron, is_new = DelegatedPatronIdentifier.get_one_or_create( self._db, library_uri, patron_identifier, identifier_type, make_id ) return patron def _sample_ecosystem(self): """ Creates an ecosystem of some sample work, pool, edition, and author objects that all know each other. 
""" # make some authors [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob") bob.family_name, bob.display_name = bob.default_names() [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice") alice.family_name, alice.display_name = alice.default_names() edition_std_ebooks, pool_std_ebooks = self._edition(DataSource.STANDARD_EBOOKS, Identifier.URI, with_license_pool=True, with_open_access_download=True, authors=[]) edition_std_ebooks.title = u"The Standard Ebooks Title" edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle" edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE) edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID, with_license_pool=True, with_open_access_download=True, authors=[]) edition_git.title = u"The GItenberg Title" edition_git.subtitle = u"The GItenberg Subtitle" edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE) edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE) edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID, with_license_pool=True, with_open_access_download=True, authors=[]) edition_gut.title = u"The GUtenberg Title" edition_gut.subtitle = u"The GUtenberg Subtitle" edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE) work = self._work(presentation_edition=edition_git) for p in pool_gut, pool_std_ebooks: work.license_pools.append(p) work.calculate_presentation() return (work, pool_std_ebooks, pool_git, pool_gut, edition_std_ebooks, edition_git, edition_gut, alice, bob) def print_database_instance(self): """ Calls the class method that examines the current state of the database model (whether it's been committed or not). NOTE: If you set_trace, and hit "continue", you'll start seeing console output right away, without waiting for the whole test to run and the standard output section to display. You can also use nosetest --nocapture. I use: def test_name(self): [code...] 
set_trace() self.print_database_instance() # TODO: remove before prod [code...] """ if not 'TESTING' in os.environ: # we are on production, abort, abort! logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.") return DatabaseTest.print_database_class(self._db) return @classmethod def print_database_class(cls, db_connection): """ Prints to the console the entire contents of the database, as the unit test sees it. Exists because unit tests don't persist db information, they create a memory representation of the db state, and then roll the unit test-derived transactions back. So we cannot see what's going on by going into postgres and running selects. This is the in-test alternative to going into postgres. Can be called from model and metadata classes as well as tests. NOTE: The purpose of this method is for debugging. Be careful of leaving it in code and potentially outputting vast tracts of data into your output stream on production. Call like this: set_trace() from testing import ( DatabaseTest, ) _db = Session.object_session(self) DatabaseTest.print_database_class(_db) # TODO: remove before prod """ if not 'TESTING' in os.environ: # we are on production, abort, abort! 
logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.") return works = db_connection.query(Work).all() identifiers = db_connection.query(Identifier).all() license_pools = db_connection.query(LicensePool).all() editions = db_connection.query(Edition).all() data_sources = db_connection.query(DataSource).all() representations = db_connection.query(Representation).all() if (not works): print "NO Work found" for wCount, work in enumerate(works): # pipe character at end of line helps see whitespace issues print "Work[%s]=%s|" % (wCount, work) if (not work.license_pools): print " NO Work.LicensePool found" for lpCount, license_pool in enumerate(work.license_pools): print " Work.LicensePool[%s]=%s|" % (lpCount, license_pool) print " Work.presentation_edition=%s|" % work.presentation_edition print "__________________________________________________________________\n" if (not identifiers): print "NO Identifier found" for iCount, identifier in enumerate(identifiers): print "Identifier[%s]=%s|" % (iCount, identifier) print " Identifier.licensed_through=%s|" % identifier.licensed_through print "__________________________________________________________________\n" if (not license_pools): print "NO LicensePool found" for index, license_pool in enumerate(license_pools): print "LicensePool[%s]=%s|" % (index, license_pool) print " LicensePool.work_id=%s|" % license_pool.work_id print " LicensePool.data_source_id=%s|" % license_pool.data_source_id print " LicensePool.identifier_id=%s|" % license_pool.identifier_id print " LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id print " LicensePool.superceded=%s|" % license_pool.superceded print " LicensePool.suppressed=%s|" % license_pool.suppressed print "__________________________________________________________________\n" if (not editions): print "NO Edition found" for index, edition in enumerate(editions): # pipe character at end of line helps 
see whitespace issues print "Edition[%s]=%s|" % (index, edition) print " Edition.primary_identifier_id=%s|" % edition.primary_identifier_id print " Edition.permanent_work_id=%s|" % edition.permanent_work_id if (edition.data_source): print " Edition.data_source.id=%s|" % edition.data_source.id print " Edition.data_source.name=%s|" % edition.data_source.name else: print " No Edition.data_source." if (edition.license_pool): print " Edition.license_pool.id=%s|" % edition.license_pool.id else: print " No Edition.license_pool." print " Edition.title=%s|" % edition.title print " Edition.author=%s|" % edition.author if (not edition.author_contributors): print " NO Edition.author_contributor found" for acCount, author_contributor in enumerate(edition.author_contributors): print " Edition.author_contributor[%s]=%s|" % (acCount, author_contributor) print "__________________________________________________________________\n" if (not data_sources): print "NO DataSource found" for index, data_source in enumerate(data_sources): print "DataSource[%s]=%s|" % (index, data_source) print " DataSource.id=%s|" % data_source.id print " DataSource.name=%s|" % data_source.name print " DataSource.offers_licenses=%s|" % data_source.offers_licenses print " DataSource.editions=%s|" % data_source.editions print " DataSource.license_pools=%s|" % data_source.license_pools print " DataSource.links=%s|" % data_source.links print "__________________________________________________________________\n" if (not representations): print "NO Representation found" for index, representation in enumerate(representations): print "Representation[%s]=%s|" % (index, representation) print " Representation.id=%s|" % representation.id print " Representation.url=%s|" % representation.url print " Representation.mirror_url=%s|" % representation.mirror_url print " Representation.fetch_exception=%s|" % representation.fetch_exception print " Representation.mirror_exception=%s|" % representation.mirror_exception return def 
_library(self, name=None, short_name=None): name=name or self._str short_name = short_name or self._str library, ignore = get_one_or_create( self._db, Library, name=name, short_name=short_name, create_method_kwargs=dict(uuid=str(uuid.uuid4())), ) return library def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT, external_account_id=None, url=None, username=None, password=None, data_source_name=None): name = name or self._str collection, ignore = get_one_or_create( self._db, Collection, name=name ) collection.external_account_id = external_account_id integration = collection.create_external_integration(protocol) integration.goal = ExternalIntegration.LICENSE_GOAL integration.url = url integration.username = username integration.password = password if data_source_name: collection.data_source = data_source_name return collection @property def _default_library(self): """A Library that will only be created once throughout a given test. By default, the `_default_collection` will be associated with the default library. """ if not hasattr(self, '_default__library'): self._default__library = self.make_default_library(self._db) return self._default__library @property def _default_collection(self): """A Collection that will only be created once throughout a given test. For most tests there's no need to create a different Collection for every LicensePool. Using self._default_collection instead of calling self.collection() saves time. """ if not hasattr(self, '_default__collection'): self._default__collection = self._default_library.collections[0] return self._default__collection @classmethod def make_default_library(cls, _db): """Ensure that the default library exists in the given database. This can be called by code intended for use in testing but not actually within a DatabaseTest subclass. 
""" library, ignore = get_one_or_create( _db, Library, create_method_kwargs=dict( uuid=unicode(uuid.uuid4()), name="default", ), short_name="default" ) collection, ignore = get_one_or_create( _db, Collection, name="Default Collection" ) integration = collection.create_external_integration( ExternalIntegration.OPDS_IMPORT ) integration.goal = ExternalIntegration.LICENSE_GOAL if collection not in library.collections: library.collections.append(collection) return library def _catalog(self, name=u"Faketown Public Library"): source, ignore = get_one_or_create(self._db, DataSource, name=name) def _integration_client(self, url=None, shared_secret=None): url = url or self._url secret = shared_secret or u"secret" return get_one_or_create( self._db, IntegrationClient, shared_secret=secret, create_method_kwargs=dict(url=url) )[0] def _subject(self, type, identifier): return get_one_or_create( self._db, Subject, type=type, identifier=identifier )[0] def _classification(self, identifier, subject, data_source, weight=1): return get_one_or_create( self._db, Classification, identifier=identifier, subject=subject, data_source=data_source, weight=weight )[0] def sample_cover_path(self, name): """The path to the sample cover with the given filename.""" base_path = os.path.split(__file__)[0] resource_path = os.path.join(base_path, "tests", "files", "covers") sample_cover_path = os.path.join(resource_path, name) return sample_cover_path def sample_cover_representation(self, name): """A Representation of the sample cover with the given filename.""" sample_cover_path = self.sample_cover_path(name) return self._representation( media_type="image/png", content=open(sample_cover_path).read())[0]
def test_import_column_extra_is_string(app_context: None, session: Session) -> None:
    """
    Test importing a dataset when the column extra is a string.

    The dataset, its single metric and its single column each carry
    ``extra`` as a serialized JSON string. After ``import_dataset`` every one
    of them must still expose that exact string.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    # Create the model tables on whatever engine the session fixture is bound to.
    SqlaTable.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()  # so that database.id is populated for the import below

    dataset_extra = '{"warning_markdown": "*WARNING*"}'
    metric_extra = '{"warning_markdown": null}'
    column_extra = '{"certified_by": "User"}'

    metric_config = {
        "metric_name": "cnt",
        "verbose_name": None,
        "metric_type": None,
        "expression": "COUNT(*)",
        "description": None,
        "d3format": None,
        "extra": metric_extra,
        "warning_text": None,
    }
    column_config = {
        "column_name": "profit",
        "verbose_name": None,
        "is_dttm": False,
        "is_active": True,
        "type": "INTEGER",
        "groupby": False,
        "filterable": False,
        "expression": "revenue-expenses",
        "description": None,
        "python_date_format": None,
        "extra": column_extra,
    }
    yaml_config: Dict[str, Any] = {
        "version": "1.0.0",
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {
            "answer": "42",
        },
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": dataset_extra,
        "uuid": uuid.uuid4(),
        "metrics": [metric_config],
        "columns": [column_config],
        "database_uuid": database.uuid,
    }

    loaded_config = ImportV1DatasetSchema().load(yaml_config)
    loaded_config["database_id"] = database.id
    imported_table = import_dataset(session, loaded_config)

    assert imported_table.metrics[0].extra == metric_extra
    assert imported_table.columns[0].extra == column_extra
    assert imported_table.extra == dataset_extra
def add(self, entity_class, data):  # different signature pylint: disable=W0222
    """Add *data* to this session.

    If *data* provides ``IEntity`` it is passed straight to the base
    ``SaSession.add``. Otherwise it is handed, together with
    *entity_class*, to the private traversal helper using the
    ``RELATION_OPERATIONS.ADD`` operation.
    """
    if IEntity.providedBy(data):  # pylint: disable=E1101
        SaSession.add(self, data)
        return
    self.__run_traversal(entity_class, data, None, RELATION_OPERATIONS.ADD)
def add_fruit(self, db: Session, order: Order, item: Fruit, quantity: int):
    """Append *quantity* units of *item* to *order* and persist the order.

    A new ``OrderItem`` is appended to ``order.order_items``. The order is
    then added to *db*, committed, refreshed, and returned.
    """
    line_item = OrderItem(item=item, quantity=quantity)
    order.order_items.append(line_item)

    db.add(order)
    db.commit()
    # Reload the order's attributes from the database after the commit.
    db.refresh(order)
    return order
def test_create_none_screamer_chance_default(self):
    """A WEBM built without passing a screamer chance can be committed.

    The session is always closed afterwards so it does not leak resources
    on the shared test connection, even if the commit fails.
    """
    session = Session(connection)
    try:
        webm = WEBM(id_="q")  # TODO: add 32 length constraint
        session.add(webm)
        session.commit()
    finally:
        session.close()
def set_total_price(self, db: Session, order: Order, total_price: float):
    """Store *total_price* on *order*, persist it and return the order.

    The order is added to *db*, committed, and refreshed before it is
    returned.
    """
    order.total_price = total_price

    db.add(order)
    db.commit()
    # Reload the order's attributes from the database after the commit.
    db.refresh(order)
    return order
def test_create_normal_foreign_key(self):
    """A DirtyWEBM with a ``webm_id`` value can be committed.

    The test name suggests ``webm_id`` is a foreign key to WEBM; that is
    not visible here. The session is always closed afterwards so it does
    not leak resources on the shared test connection.
    """
    session = Session(connection)
    try:
        dirty_webm = DirtyWEBM(md5="a4", webm_id="q")
        session.add(dirty_webm)
        session.commit()
    finally:
        session.close()
session = Session()

# Load the schema files into the schema loader / object factory.
factory = SchemaLoader()
for schema_file in ('schema/Base-schema.xml', 'schema/Person-schema.xml'):
    factory.loadSchema(schema_file)

# Publish every generated class as a module-level name. This presumably
# provides Employee and Person used below. Maybe there is a nicer way.
for class_name, generated_class in factory.getClasses().items():
    globals()[class_name] = generated_class

# Create and persist a new object called father.
father = Employee('cn=Father of Horst, ...')
father.givenName = 'The father!'
father.age = 1234
session.add(father.getObject())
session.commit()

# Create another object which references the father object.
tim = Person('cn=Horst Hackpeter, ist voll toll')
tim.givenName = 'Horst'
tim.sn = 'Hackepeter'
tim.age = 2
tim.parent = father.getObject()
tim.notes = ['tester', '44', 55, father.getObject()]
session.add(tim.getObject())
session.commit()

# Walk through the stored objects that have a property matching 'Horst'.
horst_matches = session.query(GOsaDBObject).filter(
    GOsaDBObject.properties.any(GOsaDBProperty.value == u'Horst'))
for entry in horst_matches.all():
    # NOTE(review): the converted object is not used any further.
    entry = factory.toObject(entry)
def add_profile_pic(self, db: Session, db_obj: User, image: Image):
    """Set *image* as the profile picture of *db_obj* and persist it.

    The user is added to *db*, committed and refreshed. The refreshed
    ``profile_pic`` attribute is returned.
    """
    db_obj.profile_pic = image

    db.add(db_obj)
    db.commit()
    db.refresh(db_obj)

    refreshed_pic = db_obj.profile_pic
    return refreshed_pic
#!/usr/bin/python3
"""
This script changes the name of the State object with id = 2 to
"New Mexico" in the database hbtn_0e_6_usa.

Usage: <script> <mysql username> <mysql password> <database name>
"""
from sqlalchemy import (create_engine)
from sqlalchemy.orm.session import Session
from model_state import Base, State
import sys


if __name__ == '__main__':
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]), pool_pre_ping=True)
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    try:
        # Look up the row whose primary key is 2. get() returns None when no
        # such row exists, so guard before touching its attributes.
        state = session.query(State).get(2)
        if state is not None:
            # Update the name of the row with id = 2
            state.name = 'New Mexico'
            session.add(state)
            session.commit()
    finally:
        # Release the connection even if the query or commit fails.
        session.close()
def create(cls, model: str, obj: dict, *args, session: Session, base: Type[DBT], **kwargs) -> DBRes:
    """Insert a new row built from *obj* and return it wrapped in ``DBRes``.

    *base* is instantiated with the items of *obj* as keyword arguments.
    The instance is added to *session* and committed, and its ``as_dict()``
    form is returned as ``DBRes.data``. *model*, ``*args`` and ``**kwargs``
    are accepted but not used here.
    """
    instance = base(**obj)
    session.add(instance)
    session.commit()
    payload = instance.as_dict()
    return DBRes(data=payload)
# Connect using the project's configured database engine string.
engine_string = core.get_database_engine_string()
# NOTE(review): this logs the full connection string, which may contain
# credentials -- confirm that is acceptable for the log destination.
logging.info("Using connection string '%s'", engine_string)
engine = create_engine(engine_string, encoding='utf-8', isolation_level="READ COMMITTED")
logging.info("Binding session...")
session = Session(bind=engine, autocommit=False)

# Select the old raw_results, joined with their status and the id of the
# row each one was converted into.
sql = "SELECT crawl_id, date_crawled, url, content_type, raw_article_results.status, raw_article_conversions.inserted_id FROM raw_articles JOIN raw_article_results ON raw_article_results.raw_article_id = raw_articles.id JOIN raw_article_conversions ON raw_article_conversions.raw_article_id = raw_articles.id"
it = session.execute(sql)
for crawl_id, date_crawled, url, content_type, status, inserted_id in it:
    # Decide if any of these have been committed: skip rows that already
    # have a matching RawArticle. first() returns None when there is no
    # match, which avoids depending on a "no result" exception type. The
    # previously caught NoResultException is not a SQLAlchemy exception.
    already_processed = session.query(RawArticle).filter_by(
        crawl_id=crawl_id, url=url, content_type=content_type,
        date_crawled=date_crawled).first()
    if already_processed is not None:
        logging.info("RawArticle %s has already been processed.", already_processed)
        continue

    rbase = RawArticle((crawl_id, (None, None, url, date_crawled, content_type)))
    rstat = RawArticleResult(None, status)
    rstat.parent = rbase
    rrslt = RawArticleResultLink(None, inserted_id)
    rrslt.parent = rbase
    session.add(rbase)
    session.add(rstat)
    session.add(rrslt)

session.commit()
def upgrade(): session = Session(bind=op.get_bind(), expire_on_commit=False) # All IRONMAN orgs need a 3 digit version IRONMAN_system = 'http://pcctc.org/' ironman_org_ids = [(id.id, id._value) for id in Identifier.query.filter( Identifier.system == IRONMAN_system).with_entities( Identifier.id, Identifier._value)] existing_values = [id[1] for id in ironman_org_ids] replacements = {} for io_id, io_value in ironman_org_ids: found = org_pattern.match(io_value) if found: # avoid probs if run again - don't add if already present needed = '146-0{}'.format(found.group(1)) replacements[found.group(1)] = '0{}'.format(found.group(1)) if needed not in existing_values: needed_i = Identifier( use='secondary', system=IRONMAN_system, _value=needed) else: needed_i = Identifier.query.filter( Identifier.system == IRONMAN_system).filter( Identifier._value == needed).one() # add a 3 digit identifier and link with same org oi = OrganizationIdentifier.query.filter( OrganizationIdentifier.identifier_id == io_id).one() needed_oi = OrganizationIdentifier.query.filter( OrganizationIdentifier.organization_id == oi.organization_id).filter( OrganizationIdentifier.identifier == needed_i).first() if not needed_oi: needed_i = session.merge(needed_i) needed_oi = OrganizationIdentifier( organization_id=oi.organization_id, identifier=needed_i) session.add(needed_oi) # All IRONMAN users with a 2 digit ID referencing one of the replaced # values needs a 3 digit version ironman_study_ids = Identifier.query.filter( Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter( Identifier._value.like('170-%')).with_entities( Identifier.id, Identifier._value) for iid, ival in ironman_study_ids: found = study_pattern.match(ival) if found: org_segment = found.group(1) patient_segment = found.group(2) # only add if also one of the new org ids if org_segment not in replacements: continue needed = '170-{}-{}'.format( replacements[org_segment], patient_segment) # add a 3 digit identifier and link with same 
user(s), # if not already present uis = UserIdentifier.query.filter( UserIdentifier.identifier_id == iid) needed_i = Identifier.query.filter( Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter( Identifier._value == needed).first() if not needed_i: needed_i = Identifier( use='secondary', system=TRUENTH_EXTERNAL_STUDY_SYSTEM, _value=needed) for ui in uis: needed_ui = UserIdentifier.query.filter( UserIdentifier.user_id == ui.user_id).filter( UserIdentifier.identifier == needed_i).first() if not needed_ui: needed_ui = UserIdentifier( user_id=ui.user_id, identifier=needed_i) session.add(needed_ui) session.commit()
def update_replies(
    remote_replies: List[SDKReply], local_replies: List[Reply], session: Session, data_dir: str
) -> None:
    """
    Synchronise the local replies table with the replies reported by the server.

    * A local reply whose uuid is also reported remotely is updated in place
      (journalist, size, filename).
    * A remote reply with no local counterpart is inserted, provided its
      source exists locally. If a DraftReply with the same uuid exists,
      ``update_draft_replies`` is called for it (commit=False) and the draft
      is deleted.
    * Every entry of *local_replies* not reported remotely is deleted from
      disk and from the database. The original contract says pending or
      failed replies are kept; that filtering presumably happens in the
      caller, because this function deletes every unmatched entry it is
      given.

    Journalists are resolved (or created) once per journalist uuid via
    ``find_or_create_user``. All changes are committed together at the end.
    """
    unmatched_local = {local.uuid: local for local in local_replies}
    journalists: Dict[str, User] = {}
    sources = SourceCache(session)

    for remote in remote_replies:
        author = journalists.get(remote.journalist_uuid)
        if not author:
            author = find_or_create_user(remote.journalist_uuid, remote.journalist_username, session)
            journalists[remote.journalist_uuid] = author

        existing = unmatched_local.get(remote.uuid)
        if existing:
            lazy_setattr(existing, "journalist_id", author.id)
            lazy_setattr(existing, "size", remote.size)
            lazy_setattr(existing, "filename", remote.filename)
            del unmatched_local[remote.uuid]
            logger.debug("Updated reply {}".format(remote.uuid))
            continue

        # A new reply to be added to the database.
        source = sources.get(remote.source_uuid)
        if not source:
            logger.error(f"No source found for reply {remote.uuid}")
            continue

        new_reply = Reply(
            uuid=remote.uuid,
            journalist_id=author.id,
            source_id=source.id,
            filename=remote.filename,
            size=remote.size,
        )
        session.add(new_reply)

        # Replies fetched from the server were sent successfully, so any
        # matching local draft is no longer needed.
        try:
            draft = session.query(DraftReply).filter_by(uuid=remote.uuid).one()
            update_draft_replies(
                session,
                draft.source.id,
                draft.timestamp,
                draft.file_counter,
                new_reply.file_counter,
                commit=False,
            )
            session.delete(draft)
        except NoResultFound:
            pass  # No draft locally stored corresponding to this reply.

        logger.debug("Added new reply {}".format(remote.uuid))

    # Whatever is left was not reported by the server: remove it.
    for stale in unmatched_local.values():
        delete_single_submission_or_reply_on_disk(stale, data_dir)
        session.delete(stale)
        logger.debug("Deleted reply {}".format(stale.uuid))

    session.commit()
#!/usr/bin/python3
"""
This script adds the State object "Louisiana" to the database
hbtn_0e_6_usa and prints the id it was assigned.

Usage: <script> <mysql username> <mysql password> <database name>
"""
from sqlalchemy import (create_engine)
from sqlalchemy.orm.session import Session
from model_state import Base, State
import sys


if __name__ == '__main__':
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]), pool_pre_ping=True)
    Base.metadata.create_all(engine)
    session = Session(bind=engine)
    try:
        louisiana = State(name='Louisiana')
        session.add(louisiana)
        # Commit the INSERT. The instance is expired on commit (default
        # expire_on_commit=True), so reading `id` reloads it, which must
        # happen before the session is closed.
        session.commit()
        print(louisiana.id)
    finally:
        # Release the connection even if the insert fails.
        session.close()
def _fill_test_database(db: Session) -> None:
    """Create dummy users and channels to allow further testing in dev mode.

    * Four users (alice, bob, carol, dave) are created, each with a
      'dummy' Identity and a Profile.
    * Three public channels are created, each with 5-10 packages, and every
      package gets a randomly chosen test user as owner. Channel 0 also gets
      an 'xtensor' package.
    * For every channel an ApiKey is created for a dedicated, otherwise empty
      ``key_user`` and owned by the last chosen test user. The key is printed,
      and ``key_user`` is made maintainer of the channel's last package.
    * The last chosen test user becomes owner of the channel.

    The session is committed at the end and always closed.
    """
    test_users = []
    try:
        for index, username in enumerate(['alice', 'bob', 'carol', 'dave']):
            user = User(id=uuid.uuid4().bytes, username=username)

            identity = Identity(
                provider='dummy',
                identity_id=str(index),
                username=username,
            )

            profile = Profile(name=username.capitalize(), avatar_url='/avatar.jpg')

            user.identities.append(identity)  # type: ignore
            user.profile = profile
            db.add(user)
            test_users.append(user)

        for channel_index in range(3):
            channel = Channel(
                name=f'channel{channel_index}',
                description=f'Description of channel{channel_index}',
                private=False,
            )

            for package_index in range(random.randint(5, 10)):
                package = Package(
                    name=f'package{package_index}',
                    summary=f'package {package_index} summary text',
                    description=f'Description of package{package_index}',
                )
                channel.packages.append(package)  # type: ignore
                test_user = random.choice(test_users)
                package_member = PackageMember(
                    package=package, channel=channel, user=test_user, role='owner'
                )
                db.add(package_member)

            if channel_index == 0:
                package = Package(name='xtensor', description='Description of xtensor')
                channel.packages.append(package)  # type: ignore
                test_user = random.choice(test_users)
                package_member = PackageMember(
                    package=package, channel=channel, user=test_user, role='owner'
                )
                db.add(package_member)

            # create API key. The key acts as the dedicated key_user (which
            # gets maintainer rights below) and is owned by test_user.
            # Binding it to test_user instead would leave key_user unreachable.
            key = uuid.uuid4().hex
            key_user = User(id=uuid.uuid4().bytes)
            api_key = ApiKey(
                key=key, description='test API key', user=key_user, owner=test_user
            )
            db.add(api_key)
            print(f'Test API key created for user "{test_user.username}": {key}')

            # `package` is the last package created for this channel.
            key_package_member = PackageMember(
                user=key_user,
                channel_name=channel.name,
                package_name=package.name,
                role='maintainer',
            )
            db.add(key_package_member)

            db.add(channel)

            channel_member = ChannelMember(
                channel=channel,
                user=test_user,
                role='owner',
            )
            db.add(channel_member)

        db.commit()
    finally:
        db.close()
class Address(Base):
    __tablename__ = 'address'

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey('user.id'))
    email_address = Column(String)

    user = relationship(User, backref='addresses')


################################################################################
# identity set
# The session itself acts somewhat like a set-like collection. All items
# present may be accessed using the iterator interface.
################################################################################

user = User()
session.add(user)

for obj in session:
    print(obj)

assert user in session

# inspect the current state of an object.
# add() alone only makes the instance *pending*. It becomes *persistent*
# once the session is flushed, and nothing above triggers a flush.
state = inspect(user)
assert state.pending

################################################################################
# expunge
# Expunging removes an object from the session.
# Pending instances are sent to the transient state.
# Persistent instances from flush() are sent to detached state. Qualify
def session_with_data(session: Session) -> Iterator[Session]:
    """Populate *session* with a small graph of related models and yield it.

    Creates the metadata tables on the session's bind. It then builds one
    ``Database`` and a ``SqlaTable``, ``Query``, ``SavedQuery``, ``Table`` and
    ``Dataset`` that reference it, adds them all, flushes, and yields the
    session.
    """
    from superset.columns.models import Column
    from superset.connectors.sqla.models import SqlaTable, TableColumn
    from superset.datasets.models import Dataset
    from superset.models.core import Database
    from superset.models.sql_lab import Query, SavedQuery
    from superset.tables.models import Table

    SqlaTable.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")

    sqla_table = SqlaTable(
        table_name="my_sqla_table",
        columns=[TableColumn(column_name="a", type="INTEGER")],
        metrics=[],
        database=db,
    )

    query_obj = Query(
        client_id="foo",
        database=db,
        tab_name="test_tab",
        sql_editor_id="test_editor_id",
        sql="select * from bar",
        select_sql="select * from bar",
        executed_sql="select * from bar",
        limit=100,
        select_as_cta=False,
        rows=100,
        error_message="none",
        results_key="abc",
    )

    saved_query = SavedQuery(database=db, sql="select * from foo")

    table = Table(
        name="my_table",
        schema="my_schema",
        catalog="my_catalog",
        database=db,
        columns=[],
    )

    position_column = Column(
        name="position",
        expression="array_agg(array[longitude,latitude])",
    )
    dataset = Dataset(
        database=table.database,
        name="positions",
        expression=""" SELECT array_agg(array[longitude,latitude]) AS position FROM my_catalog.my_schema.my_table """,
        tables=[table],
        columns=[position_column],
    )

    # Same insertion order as adding each object one by one.
    session.add_all([dataset, table, saved_query, query_obj, db, sqla_table])
    session.flush()
    yield session
"publication_date": datetime.datetime.strptime("2010-05-05", "%Y-%m-%d"), "pages_count": 240, } engine.connect().execute(insert_stmt, data) session = Session(bind=engine) q = session.query(Book).filter(Book.title == "Essential SQLAlchemy") print q book = q.one() print (book.id, book.title) author = Author(name="Rick Copeland") author.books.append(book) session.add(book) session.flush() #### # select CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END is_novel, count(*) # from BOOK # group by CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END # order by CASE WHEN (BOOK.pages_count > 200) THEN 1 ELSE 0 END # is_novel_column = case([(Book.pages_count > 200, 1)], else_=0) novel_query = ( session.query(is_novel_column.label("is_alias"), count()).group_by(is_novel_column).order_by(is_novel_column) ) print novel_query print novel_query.all()
def upgrade():
    """Insert the exchange named ``BxInThExchange.name`` as active, weight 15."""
    session = Session(bind=op.get_bind())
    exchange = Exchange(name=BxInThExchange.name, is_active=True, weight=15)
    session.add(exchange)
    # Only flush: no commit is issued here. The commit is presumably handled
    # by the migration's surrounding transaction.
    session.flush()
def setup_module():
    """Create the test schema and fixture rows inside one outer transaction.

    * Opens a module-level ``engine``/``connection`` and begins ``transaction``
      (presumably rolled back by a module teardown -- not visible here).
    * Creates all tables and inserts two accounts, an unlocked and a locked
      user, a batch, a journal and two balancing postings through one
      Session bound to that connection.
    * Points ``dbapi._session`` at that single session so every query shares
      the transaction.
    * Installs a listener that re-opens a SAVEPOINT whenever the outermost
      nested transaction ends.
    """
    global transaction, connection, engine
    engine = create_engine('postgresql:///recordsheet_test')
    connection = engine.connect()
    transaction = connection.begin()
    Base.metadata.create_all(connection)

    # insert some data
    # NOTE(review): inner_tr is never referenced again in this function.
    inner_tr = connection.begin_nested()
    ses = Session(connection)
    # ses.begin_nested()
    ses.add(dbmodel.Account(name='TEST01', desc='test account 01'))
    ses.add(dbmodel.Account(name='TEST02', desc='test account 02'))
    user = dbmodel.User(username='******', name='Test T. User',
                        password=dbapi.new_pw_hash('passtestword'),
                        locked=False)
    ses.add(user)
    lockeduser = dbmodel.User(username='******', name='Test T. User',
                              password=dbapi.new_pw_hash('passtestword'),
                              locked=True)
    ses.add(lockeduser)
    batch = dbmodel.Batch(user=user)
    ses.add(batch)
    jrnl = dbmodel.Journal(memo='test', batch=batch,
                           datetime='2016-06-05 14:09:00-05')
    ses.add(jrnl)
    # Two postings of +100 / -100 on accounts 1 and 2 for the same journal.
    ses.add(dbmodel.Posting(memo="test", amount=100, account_id=1, journal=jrnl))
    ses.add(dbmodel.Posting(memo="test", amount=-100, account_id=2, journal=jrnl))
    ses.commit()

    # mock a sessionmaker so all queries are in this transaction
    dbapi._session = lambda: ses

    ses.begin_nested()

    @event.listens_for(ses, "after_transaction_end")
    def restart_savepoint(session, transaction):
        # When a nested transaction whose parent is not nested ends (e.g.
        # after code under test commits or rolls back), start a new
        # SAVEPOINT so later operations stay inside the outer transaction.
        # NOTE: relies on the private `_parent` attribute of SQLAlchemy's
        # SessionTransaction.
        if transaction.nested and not transaction._parent.nested:
            ses.begin_nested()
def verify_integrity(self, session: Session = None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    Existing task instances:

    * Each is passed through ``task_instance_mutation_hook`` and merged back
      into *session*.
    * If its task no longer exists in the DAG, it is marked REMOVED, but only
      when this run is not RUNNING and the DAG is not partial.
    * A REMOVED instance whose task exists again is reset to State.NONE.

    Missing task instances:

    * For every DAG task without an instance, a new TI is created and added,
      unless the task's start_date is after this run's execution_date and
      the run is not a backfill.

    Finally the session is flushed. On IntegrityError the error is logged
    and the whole session is rolled back.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    # NOTE(review): session defaults to None but is used unconditionally;
    # presumably a @provide_session-style decorator supplies it -- confirm.
    from airflow.settings import task_instance_mutation_hook

    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            elif self.state != State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                ti.state = State.REMOVED

        # The task exists in the DAG again but its instance was REMOVED.
        should_restore_task = (task is not None) and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info(
                "Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
            Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in dag.task_dict.values():
        # Tasks that start after this run are skipped unless backfilling.
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
            ti = TI(task, self.execution_date)
            task_instance_mutation_hook(ti)
            session.add(ti)

    try:
        session.flush()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info('Hit IntegrityError while creating the TIs for '
                      f'{dag.dag_id} - {self.execution_date}.')
        self.log.info('Doing session rollback.')
        # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
        session.rollback()
class DatabaseTest(object):
    """Base class for tests that need a database Session plus model factories.

    ``setup_class`` obtains a class-wide engine/connection from
    ``SessionManager.initialize`` and a temporary data directory.

    ``setup`` builds a fresh ``Session`` on that connection, opens a
    SAVEPOINT (``begin_nested``) and, by default, patches
    ``ExternalSearchIndex`` with ``MockExternalSearchIndex``.

    ``teardown`` rolls the SAVEPOINT back and resets model-level caches, so
    one test's changes do not leak into the next.

    The ``_foo`` helpers create model objects with unique default values
    drawn from ``self._id`` / ``self._str``.
    """

    engine = None
    connection = None

    @classmethod
    def get_database_connection(cls):
        """Return ``(engine, connection)`` for the configured database URL."""
        url = Configuration.database_url()
        engine, connection = SessionManager.initialize(url)

        return engine, connection

    @classmethod
    def setup_class(cls):
        # Connect, and initialize a temporary data directory.
        cls.engine, cls.connection = cls.get_database_connection()
        # NOTE(review): if Configuration.data_directory is a method rather
        # than a value, this saves the method object, not the directory —
        # confirm.
        cls.old_data_dir = Configuration.data_directory
        cls.tmp_data_dir = tempfile.mkdtemp(dir="/tmp")
        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.tmp_data_dir

        # Avoid CannotLoadConfiguration errors related to CDN integrations.
        Configuration.instance[Configuration.INTEGRATIONS] = Configuration.instance.get(
            Configuration.INTEGRATIONS, {}
        )
        Configuration.instance[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {}

    @classmethod
    def teardown_class(cls):
        # Destroy the database connection and engine.
        cls.connection.close()
        cls.engine.dispose()

        # Only delete the data directory if it really lives under /tmp.
        if cls.tmp_data_dir.startswith("/tmp"):
            logging.debug("Removing temporary directory %s", cls.tmp_data_dir)
            shutil.rmtree(cls.tmp_data_dir)
        else:
            logging.warning("Cowardly refusing to remove 'temporary' directory %s", cls.tmp_data_dir)

        Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir

    def setup(self, mock_search=True):
        # Create a new Session on the shared class-level connection. Then open
        # a SAVEPOINT that teardown() rolls back.
        self._db = Session(self.connection)
        self.transaction = self.connection.begin_nested()

        # Start with a high number so it won't interfere with tests that search for an age or grade
        self.counter = 2000

        self.time_counter = datetime(2014, 1, 1)
        self.isbns = [
            "9780674368279", "0636920028468", "9781936460236", "9780316075978"
        ]
        if mock_search:
            self.search_mock = mock.patch(external_search.__name__ + ".ExternalSearchIndex", MockExternalSearchIndex)
            self.search_mock.start()
        else:
            self.search_mock = None

        # TODO: stripping test docstrings via inspect.getmembers(self) was
        # disabled because it touches the _isbn property, which pops an ISBN
        # off self.isbns and breaks tests. Exclude _-prefixed members before
        # re-enabling it. (Goal: stop nosetest showing docstrings instead of
        # function names.)

    def teardown(self):
        # Close the session.
        self._db.close()

        # Roll back all database changes that happened during this
        # test, whether in the session that was just closed or some
        # other session.
        self.transaction.rollback()

        # Remove any database objects cached in the model classes but
        # associated with the now-rolled-back session.
        for model in (Collection, ConfigurationSetting, DataSource,
                      DeliveryMechanism, ExternalIntegration, Genre, Library):
            model.reset_cache()

        # Also roll back any record of those changes in the
        # Configuration instance.
        for key in [
                Configuration.SITE_CONFIGURATION_LAST_UPDATE,
                Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE
        ]:
            if key in Configuration.instance:
                del Configuration.instance[key]

        if self.search_mock:
            self.search_mock.stop()

    def time_eq(self, a, b):
        "Assert that two times are *approximately* the same -- within 2 seconds."
        total_seconds = abs(a - b).total_seconds()
        assert (total_seconds < 2), ("Delta was too large: %.2f seconds." % total_seconds)

    def shortDescription(self):
        return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2.

    @property
    def _id(self):
        """Next value of the per-test counter (starts above 2000)."""
        self.counter += 1
        return self.counter

    @property
    def _str(self):
        """Unique unicode string derived from the counter."""
        return unicode(self._id)

    @property
    def _time(self):
        """Return the current fake time, then advance it by one day."""
        v = self.time_counter
        self.time_counter = self.time_counter + timedelta(days=1)
        return v

    @property
    def _isbn(self):
        """Pop (and consume) one of the sample ISBNs."""
        return self.isbns.pop()

    @property
    def _url(self):
        return "http://foo.com/" + self._str

    def _patron(self, external_identifier=None, library=None):
        external_identifier = external_identifier or self._str
        library = library or self._default_library
        return get_one_or_create(
            self._db, Patron, external_identifier=external_identifier,
            library=library
        )[0]

    def _contributor(self, sort_name=None, name=None, **kw_args):
        """Get or create a Contributor; returns the get_one_or_create tuple."""
        name = sort_name or name or self._str
        return get_one_or_create(self._db, Contributor, sort_name=unicode(name), **kw_args)

    def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None):
        identifier_value = foreign_id or self._str
        return Identifier.for_foreign_id(self._db, identifier_type, identifier_value)[0]

    def _edition(self, data_source_name=DataSource.GUTENBERG,
                 identifier_type=Identifier.GUTENBERG_ID,
                 with_license_pool=False, with_open_access_download=False,
                 title=None, language="eng", authors=None, identifier_id=None,
                 series=None, collection=None, publicationDate=None
    ):
        """Create an Edition.

        Returns ``(edition, pool)`` when a license pool is requested
        (``with_license_pool`` or ``with_open_access_download``); otherwise
        returns just the edition. ``authors`` may be a single string or a
        sequence. The first author is the primary author; an empty sequence
        means no authors.
        """
        foreign_id = identifier_id or self._str
        source = DataSource.lookup(self._db, data_source_name)
        wr = Edition.for_foreign_id(
            self._db, source, identifier_type, foreign_id)[0]
        if not title:
            title = self._str
        wr.title = unicode(title)
        wr.medium = Edition.BOOK_MEDIUM
        if series:
            wr.series = series
        if language:
            wr.language = language
        if authors is None:
            authors = self._str
        if isinstance(authors, basestring):
            authors = [authors]
        # Truthiness (not `!= []`) so an empty tuple is also treated as "no authors".
        if authors:
            wr.add_contributor(unicode(authors[0]), Contributor.PRIMARY_AUTHOR_ROLE)
            wr.author = unicode(authors[0])
            for author in authors[1:]:
                wr.add_contributor(unicode(author), Contributor.AUTHOR_ROLE)
        if publicationDate:
            wr.published = publicationDate

        if with_license_pool or with_open_access_download:
            pool = self._licensepool(
                wr, data_source_name=data_source_name,
                with_open_access_download=with_open_access_download,
                collection=collection
            )

            pool.set_presentation_edition()
            return wr, pool
        return wr

    def _work(self, title=None, authors=None, genre=None, language=None,
              audience=None, fiction=True, with_license_pool=False,
              with_open_access_download=False, quality=0.5, series=None,
              presentation_edition=None, collection=None, data_source_name=None):
        """Create a Work.

        For performance reasons, this method does not calculate a
        presentation edition for the new Work. It uses the given (or a
        freshly created) edition directly. When the Work has license pools,
        it is marked presentation-ready and its OPDS entries are calculated.

        Tests that rely on a Work that looks like one created in a real
        environment should call _slow_work() instead.
        """
        pools = []
        if with_open_access_download:
            with_license_pool = True
        language = language or "eng"
        title = unicode(title or self._str)
        audience = audience or Classifier.AUDIENCE_ADULT
        if audience == Classifier.AUDIENCE_CHILDREN and not data_source_name:
            # TODO: This is necessary because Gutenberg's childrens books
            # get filtered out at the moment.
            data_source_name = DataSource.OVERDRIVE
        elif not data_source_name:
            data_source_name = DataSource.GUTENBERG
        if fiction is None:
            fiction = True
        if not presentation_edition:
            presentation_edition = self._edition(
                title=title, language=language,
                authors=authors,
                with_license_pool=with_license_pool,
                with_open_access_download=with_open_access_download,
                data_source_name=data_source_name,
                series=series,
                collection=collection,
            )
            if with_license_pool:
                # _edition returned an (edition, pool) tuple.
                presentation_edition, pool = presentation_edition
                if with_open_access_download:
                    pool.open_access = True
                pools = [pool]
        else:
            pools = presentation_edition.license_pools

        work, ignore = get_one_or_create(
            self._db, Work,
            create_method_kwargs=dict(
                audience=audience,
                fiction=fiction,
                quality=quality),
            id=self._id)
        if genre:
            if not isinstance(genre, Genre):
                genre, ignore = Genre.lookup(self._db, genre, autocreate=True)
            work.genres = [genre]
        work.random = 0.5
        work.set_presentation_edition(presentation_edition)

        if pools:
            # make sure the pool's presentation_edition is set,
            # bc loan tests assume that.
            if not work.license_pools:
                for pool in pools:
                    work.license_pools.append(pool)

            for pool in pools:
                pool.set_presentation_edition()

            # This is probably going to be used in an OPDS feed, so
            # fake that the work is presentation ready.
            work.presentation_ready = True
            work.calculate_opds_entries(verbose=False)

        return work

    def add_to_materialized_view(self, works, true_opds=False):
        """Make sure all the works in `works` show up in the materialized view.

        Commits the session and refreshes the materialized views.

        :param true_opds: Generate real OPDS entries for each work,
            rather than faking it.
        """
        if not isinstance(works, list):
            works = [works]
        for work in works:
            if true_opds:
                work.calculate_opds_entries(verbose=False)
            else:
                work.presentation_ready = True
                work.simple_opds_entry = "<entry>an entry</entry>"
        self._db.commit()
        SessionManager.refresh_materialized_views(self._db)

    def _lane(self, display_name=None, library=None,
              parent=None, genres=None, languages=None,
              fiction=None
    ):
        display_name = display_name or self._str
        library = library or self._default_library
        lane, is_new = create(
            self._db, Lane,
            library=library,
            parent=parent, display_name=display_name,
            fiction=fiction
        )
        if is_new and parent:
            lane.priority = len(parent.sublanes) - 1
        if genres:
            if not isinstance(genres, list):
                genres = [genres]
            for genre in genres:
                if isinstance(genre, basestring):
                    genre, ignore = Genre.lookup(self._db, genre)
                lane.genres.append(genre)
        if languages:
            if not isinstance(languages, list):
                languages = [languages]
            lane.languages = languages
        return lane

    def _slow_work(self, *args, **kwargs):
        """Create a work that closely resembles one that might be found in the
        wild.

        This is significantly slower than _work() but more reliable.
        """
        work = self._work(*args, **kwargs)
        work.calculate_presentation_edition()
        work.calculate_opds_entries(verbose=False)
        return work

    def _add_generic_delivery_mechanism(self, license_pool):
        """Give a license pool a generic non-open-access delivery mechanism."""
        data_source = license_pool.data_source
        identifier = license_pool.identifier
        content_type = Representation.EPUB_MEDIA_TYPE
        drm_scheme = DeliveryMechanism.NO_DRM
        LicensePoolDeliveryMechanism.set(
            data_source, identifier, content_type, drm_scheme,
            RightsStatus.IN_COPYRIGHT
        )

    def _coverage_record(self, edition, coverage_source, operation=None,
        status=CoverageRecord.SUCCESS, collection=None, exception=None,
    ):
        """Get or create a CoverageRecord for an Edition's primary identifier
        (or for `edition` itself when it is already an Identifier)."""
        if isinstance(edition, Identifier):
            identifier = edition
        else:
            identifier = edition.primary_identifier
        record, ignore = get_one_or_create(
            self._db, CoverageRecord,
            identifier=identifier,
            data_source=coverage_source,
            operation=operation,
            collection=collection,
            create_method_kwargs = dict(
                timestamp=datetime.utcnow(),
                status=status,
                exception=exception,
            )
        )
        return record

    def _work_coverage_record(self, work, operation=None,
                              status=CoverageRecord.SUCCESS):
        record, ignore = get_one_or_create(
            self._db, WorkCoverageRecord,
            work=work,
            operation=operation,
            create_method_kwargs = dict(
                timestamp=datetime.utcnow(),
                status=status,
            )
        )
        return record

    def _licensepool(self, edition, open_access=True,
                     data_source_name=DataSource.GUTENBERG,
                     with_open_access_download=False,
                     set_edition_as_presentation=False,
                     collection=None):
        """Get or create a LicensePool for `edition`.

        If `with_open_access_download` is set, the pool gets an open-access
        download link, a NO_DRM delivery mechanism and a mirrored dummy
        Representation. Otherwise it gets an ADOBE_DRM delivery mechanism and
        one owned/available license.
        """
        source = DataSource.lookup(self._db, data_source_name)
        if not edition:
            edition = self._edition(data_source_name)
        collection = collection or self._default_collection
        pool, ignore = get_one_or_create(
            self._db, LicensePool,
            create_method_kwargs=dict(
                open_access=open_access),
            identifier=edition.primary_identifier,
            data_source=source,
            collection=collection,
            availability_time=datetime.utcnow()
        )

        if set_edition_as_presentation:
            pool.presentation_edition = edition

        if with_open_access_download:
            pool.open_access = True
            url = "http://foo.com/" + self._str
            media_type = MediaTypes.EPUB_MEDIA_TYPE
            link, _ = pool.identifier.add_link(
                Hyperlink.OPEN_ACCESS_DOWNLOAD, url,
                source, media_type
            )

            # Add a DeliveryMechanism for this download
            pool.set_delivery_mechanism(
                media_type,
                DeliveryMechanism.NO_DRM,
                RightsStatus.GENERIC_OPEN_ACCESS,
                link.resource,
            )

            representation, _ = self._representation(
                url, media_type, "Dummy content", mirrored=True)
            link.resource.representation = representation
        else:
            # Add a DeliveryMechanism for this licensepool
            pool.set_delivery_mechanism(
                MediaTypes.EPUB_MEDIA_TYPE,
                DeliveryMechanism.ADOBE_DRM,
                RightsStatus.UNKNOWN,
                None
            )
            pool.licenses_owned = pool.licenses_available = 1

        return pool

    def _license(self, pool, identifier=None, checkout_url=None, status_url=None,
                 expires=None, remaining_checkouts=None, concurrent_checkouts=None):
        identifier = identifier or self._str
        checkout_url = checkout_url or self._str
        status_url = status_url or self._str
        license_obj, ignore = get_one_or_create(
            self._db, License, identifier=identifier, license_pool=pool,
            checkout_url=checkout_url,
            status_url=status_url, expires=expires,
            remaining_checkouts=remaining_checkouts,
            concurrent_checkouts=concurrent_checkouts,
        )
        return license_obj

    def _representation(self, url=None, media_type=None, content=None,
                        mirrored=False):
        """Get or create a Representation; returns ``(representation, is_new)``."""
        url = url or "http://foo.com/" + self._str
        representation, is_new = get_one_or_create(
            self._db, Representation, url=url)
        representation.media_type = media_type
        if media_type and content:
            representation.content = content
            representation.fetched_at = datetime.utcnow()
            if mirrored:
                representation.mirror_url = "http://foo.com/" + self._str
                representation.mirrored_at = datetime.utcnow()
        return representation, is_new

    def _customlist(self, foreign_identifier=None,
                    name=None,
                    data_source_name=DataSource.NYT, num_entries=1,
                    entries_exist_as_works=True
    ):
        """Create a CustomList with `num_entries` entries.

        Returns ``(customlist, editions)``.
        """
        data_source = DataSource.lookup(self._db, data_source_name)
        foreign_identifier = foreign_identifier or self._str
        now = datetime.utcnow()
        customlist, ignore = get_one_or_create(
            self._db, CustomList,
            create_method_kwargs=dict(
                created=now,
                updated=now,
                name=name or self._str,
                description=self._str,
            ),
            data_source=data_source,
            foreign_identifier=foreign_identifier
        )

        editions = []
        for i in range(num_entries):
            if entries_exist_as_works:
                work = self._work(with_open_access_download=True)
                edition = work.presentation_edition
            else:
                edition = self._edition(
                    data_source_name, title="Item %s" % i)
                edition.permanent_work_id = "Permanent work ID %s" % self._str
            customlist.add_entry(
                edition, "Annotation %s" % i, first_appearance=now)
            editions.append(edition)
        return customlist, editions

    def _complaint(self, license_pool, type, source, detail, resolved=None):
        complaint, is_new = Complaint.register(
            license_pool,
            type,
            source,
            detail,
            resolved
        )
        return complaint

    def _credential(self, data_source_name=DataSource.GUTENBERG,
                    type=None, patron=None):
        data_source = DataSource.lookup(self._db, data_source_name)
        type = type or self._str
        patron = patron or self._patron()
        credential, is_new = Credential.persistent_token_create(
            self._db, data_source, type, patron
        )
        return credential

    def _external_integration(self, protocol, goal=None, settings=None,
                              libraries=None, **kwargs
    ):
        """Find or create an ExternalIntegration.

        Without `libraries`, get-or-create by (protocol, goal). With
        `libraries`, reuse the first existing integration found for any of
        them; otherwise create a new one attached to all of them. Extra
        kwargs are set as attributes and `settings` are stored via
        set_setting().
        """
        integration = None
        if not libraries:
            integration, ignore = get_one_or_create(
                self._db, ExternalIntegration, protocol=protocol, goal=goal
            )
        else:
            if not isinstance(libraries, list):
                libraries = [libraries]

            # Try to find an existing integration for one of the given
            # libraries. (Previously this always looked up libraries[0].)
            for library in libraries:
                integration = ExternalIntegration.lookup(
                    self._db, protocol, goal, library=library
                )
                if integration:
                    break

            if not integration:
                # Otherwise, create a brand new integration specifically
                # for the library.
                integration = ExternalIntegration(
                    protocol=protocol, goal=goal,
                )
                integration.libraries.extend(libraries)
                self._db.add(integration)

        for attr, value in kwargs.items():
            setattr(integration, attr, value)

        settings = settings or dict()
        for key, value in settings.items():
            integration.set_setting(key, value)

        return integration

    def _delegated_patron_identifier(
            self, library_uri=None, patron_identifier=None,
            identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
            identifier=None
    ):
        """Create a sample DelegatedPatronIdentifier.

        `identifier` may be a value or a zero-argument callable that
        produces one.
        """
        library_uri = library_uri or self._url
        patron_identifier = patron_identifier or self._str
        if callable(identifier):
            make_id = identifier
        else:
            if not identifier:
                identifier = self._str
            def make_id():
                return identifier
        patron, is_new = DelegatedPatronIdentifier.get_one_or_create(
            self._db, library_uri, patron_identifier, identifier_type,
            make_id
        )
        return patron

    def _sample_ecosystem(self):
        """ Creates an ecosystem of some sample work, pool, edition, and author
        objects that all know each other.

        Returns (work, pool_std_ebooks, pool_git, pool_gut,
        edition_std_ebooks, edition_git, edition_gut, alice, bob).
        """
        # make some authors
        [bob], ignore = Contributor.lookup(self._db, u"Bitshifter, Bob")
        bob.family_name, bob.display_name = bob.default_names()
        [alice], ignore = Contributor.lookup(self._db, u"Adder, Alice")
        alice.family_name, alice.display_name = alice.default_names()

        edition_std_ebooks, pool_std_ebooks = self._edition(DataSource.STANDARD_EBOOKS, Identifier.URI,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_std_ebooks.title = u"The Standard Ebooks Title"
        edition_std_ebooks.subtitle = u"The Standard Ebooks Subtitle"
        edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_git.title = u"The GItenberg Title"
        edition_git.subtitle = u"The GItenberg Subtitle"
        edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE)
        edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE)

        edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID,
            with_license_pool=True, with_open_access_download=True, authors=[])
        edition_gut.title = u"The GUtenberg Title"
        edition_gut.subtitle = u"The GUtenberg Subtitle"
        edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE)

        work = self._work(presentation_edition=edition_git)

        for p in pool_gut, pool_std_ebooks:
            work.license_pools.append(p)

        work.calculate_presentation()

        return (work, pool_std_ebooks, pool_git, pool_gut,
            edition_std_ebooks, edition_git, edition_gut, alice, bob)

    def print_database_instance(self):
        """
        Calls the class method that examines the current state of the database model
        (whether it's been committed or not).

        NOTE: If you set_trace and hit "continue", console output appears right
        away, without waiting for the whole test to run and the standard output
        section to display. You can also use nosetest --nocapture.
        I use:
        def test_name(self):
            [code...]
            set_trace()
            self.print_database_instance()  # TODO: remove before prod
            [code...]

        Does nothing (except log a warning) unless TESTING is in the environment.
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warning("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.")
            return

        DatabaseTest.print_database_class(self._db)
        return

    @classmethod
    def print_database_class(cls, db_connection):
        """
        Prints to the console the entire contents of the database, as the unit test sees it.

        Exists because unit tests don't persist db information: they build up a
        database state and then roll the test's transactions back. So we cannot
        see what's going on by going into postgres and running selects.
        This is the in-test alternative to going into postgres.

        Can be called from model and metadata classes as well as tests.

        NOTE: The purpose of this method is for debugging.
        Be careful of leaving it in code and potentially outputting
        vast tracts of data into your output stream on production.

        Call like this:
        set_trace()
        from testing import (
            DatabaseTest,
        )
        _db = Session.object_session(self)
        DatabaseTest.print_database_class(_db)  # TODO: remove before prod

        Does nothing (except log a warning) unless TESTING is in the environment.
        """
        if 'TESTING' not in os.environ:
            # we are on production, abort, abort!
            logging.warning("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.")
            return

        works = db_connection.query(Work).all()
        identifiers = db_connection.query(Identifier).all()
        license_pools = db_connection.query(LicensePool).all()
        editions = db_connection.query(Edition).all()
        data_sources = db_connection.query(DataSource).all()
        representations = db_connection.query(Representation).all()

        # Single-argument print(...) produces the same output on Python 2 and 3.
        if not works:
            print("NO Work found")
        for wCount, work in enumerate(works):
            # pipe character at end of line helps see whitespace issues
            print("Work[%s]=%s|" % (wCount, work))

            if not work.license_pools:
                print(" NO Work.LicensePool found")
            for lpCount, license_pool in enumerate(work.license_pools):
                print(" Work.LicensePool[%s]=%s|" % (lpCount, license_pool))

            print(" Work.presentation_edition=%s|" % work.presentation_edition)

        print("__________________________________________________________________\n")
        if not identifiers:
            print("NO Identifier found")
        for iCount, identifier in enumerate(identifiers):
            print("Identifier[%s]=%s|" % (iCount, identifier))
            print(" Identifier.licensed_through=%s|" % identifier.licensed_through)

        print("__________________________________________________________________\n")
        if not license_pools:
            print("NO LicensePool found")
        for index, license_pool in enumerate(license_pools):
            print("LicensePool[%s]=%s|" % (index, license_pool))
            print(" LicensePool.work_id=%s|" % license_pool.work_id)
            print(" LicensePool.data_source_id=%s|" % license_pool.data_source_id)
            print(" LicensePool.identifier_id=%s|" % license_pool.identifier_id)
            print(" LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id)
            print(" LicensePool.superceded=%s|" % license_pool.superceded)
            print(" LicensePool.suppressed=%s|" % license_pool.suppressed)

        print("__________________________________________________________________\n")
        if not editions:
            print("NO Edition found")
        for index, edition in enumerate(editions):
            # pipe character at end of line helps see whitespace issues
            print("Edition[%s]=%s|" % (index, edition))
            print(" Edition.primary_identifier_id=%s|" % edition.primary_identifier_id)
            print(" Edition.permanent_work_id=%s|" % edition.permanent_work_id)
            if edition.data_source:
                print(" Edition.data_source.id=%s|" % edition.data_source.id)
                print(" Edition.data_source.name=%s|" % edition.data_source.name)
            else:
                print(" No Edition.data_source.")
            if edition.license_pool:
                print(" Edition.license_pool.id=%s|" % edition.license_pool.id)
            else:
                print(" No Edition.license_pool.")

            print(" Edition.title=%s|" % edition.title)
            print(" Edition.author=%s|" % edition.author)
            if not edition.author_contributors:
                print(" NO Edition.author_contributor found")
            for acCount, author_contributor in enumerate(edition.author_contributors):
                print(" Edition.author_contributor[%s]=%s|" % (acCount, author_contributor))

        print("__________________________________________________________________\n")
        if not data_sources:
            print("NO DataSource found")
        for index, data_source in enumerate(data_sources):
            print("DataSource[%s]=%s|" % (index, data_source))
            print(" DataSource.id=%s|" % data_source.id)
            print(" DataSource.name=%s|" % data_source.name)
            print(" DataSource.offers_licenses=%s|" % data_source.offers_licenses)
            print(" DataSource.editions=%s|" % data_source.editions)
            print(" DataSource.license_pools=%s|" % data_source.license_pools)
            print(" DataSource.links=%s|" % data_source.links)

        print("__________________________________________________________________\n")
        if not representations:
            print("NO Representation found")
        for index, representation in enumerate(representations):
            print("Representation[%s]=%s|" % (index, representation))
            print(" Representation.id=%s|" % representation.id)
            print(" Representation.url=%s|" % representation.url)
            print(" Representation.mirror_url=%s|" % representation.mirror_url)
            print(" Representation.fetch_exception=%s|" % representation.fetch_exception)
            print(" Representation.mirror_exception=%s|" % representation.mirror_exception)

        return

    def _library(self, name=None, short_name=None):
        name = name or self._str
        short_name = short_name or self._str
        library, ignore = get_one_or_create(
            self._db, Library, name=name, short_name=short_name,
            create_method_kwargs=dict(uuid=str(uuid.uuid4())),
        )
        return library

    def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT,
                    external_account_id=None, url=None, username=None,
                    password=None, data_source_name=None):
        """Get or create a Collection and give it a fresh LICENSE_GOAL
        ExternalIntegration with the given protocol and credentials."""
        name = name or self._str
        collection, ignore = get_one_or_create(
            self._db, Collection, name=name
        )
        collection.external_account_id = external_account_id
        integration = collection.create_external_integration(protocol)
        integration.goal = ExternalIntegration.LICENSE_GOAL
        integration.url = url
        integration.username = username
        integration.password = password

        if data_source_name:
            collection.data_source = data_source_name
        return collection

    @property
    def _default_library(self):
        """A Library that will only be created once throughout a given test.

        By default, the `_default_collection` will be associated with
        the default library.
        """
        if not hasattr(self, '_default__library'):
            self._default__library = self.make_default_library(self._db)
        return self._default__library

    @property
    def _default_collection(self):
        """A Collection that will only be created once throughout
        a given test.

        For most tests there's no need to create a different
        Collection for every LicensePool. Using
        self._default_collection instead of calling self._collection()
        saves time.
        """
        if not hasattr(self, '_default__collection'):
            self._default__collection = self._default_library.collections[0]
        return self._default__collection

    @classmethod
    def make_default_library(cls, _db):
        """Ensure that the default library exists in the given database.

        This can be called by code intended for use in testing but not actually
        within a DatabaseTest subclass.
        """
        library, ignore = get_one_or_create(
            _db, Library, create_method_kwargs=dict(
                uuid=unicode(uuid.uuid4()),
                name="default",
            ), short_name="default"
        )
        collection, ignore = get_one_or_create(
            _db, Collection, name="Default Collection"
        )
        integration = collection.create_external_integration(
            ExternalIntegration.OPDS_IMPORT
        )
        integration.goal = ExternalIntegration.LICENSE_GOAL
        if collection not in library.collections:
            library.collections.append(collection)
        return library

    def _catalog(self, name=u"Faketown Public Library"):
        """Get or create a DataSource with the given name and return it."""
        source, ignore = get_one_or_create(self._db, DataSource, name=name)
        return source

    def _integration_client(self, url=None, shared_secret=None):
        url = url or self._url
        secret = shared_secret or u"secret"
        return get_one_or_create(
            self._db, IntegrationClient, shared_secret=secret,
            create_method_kwargs=dict(url=url)
        )[0]

    def _subject(self, type, identifier):
        return get_one_or_create(
            self._db, Subject, type=type, identifier=identifier
        )[0]

    def _classification(self, identifier, subject, data_source, weight=1):
        return get_one_or_create(
            self._db, Classification, identifier=identifier, subject=subject,
            data_source=data_source, weight=weight
        )[0]

    def sample_cover_path(self, name):
        """The path to the sample cover with the given filename."""
        base_path = os.path.dirname(__file__)
        resource_path = os.path.join(base_path, "tests", "files", "covers")
        sample_cover_path = os.path.join(resource_path, name)
        return sample_cover_path

    def sample_cover_representation(self, name):
        """A Representation of the sample cover with the given filename."""
        sample_cover_path = self.sample_cover_path(name)
        # Binary mode (the covers are PNGs), and close the file deterministically.
        with open(sample_cover_path, 'rb') as cover_file:
            content = cover_file.read()
        return self._representation(
            media_type="image/png", content=content)[0]
def merge_conn(conn, session: Session = NEW_SESSION):
    """Add new Connection.

    Stores ``conn`` and commits, but only when no Connection with the same
    ``conn_id`` exists yet. Otherwise nothing is added and nothing is
    committed.
    """
    already_stored = (
        session.query(Connection)
        .filter(Connection.conn_id == conn.conn_id)
        .first()
    )
    if already_stored:
        return
    session.add(conn)
    session.commit()
class testTradingCenter(unittest.TestCase):
    """Tests for TradingCenter's bookkeeping of open and cancelled orders.

    Each test runs inside a transaction begun on the module-level
    ``connection``. tearDown closes the Session and then rolls that
    transaction back, so commits made by a test do not persist.
    """

    def setUp(self):
        self.trans = connection.begin()
        self.session = Session(connection)

        currency = Currency(name='Pesos', code='ARG')
        self.exchange = Exchange(name='Merval', code='MERV', currency=currency)
        self.owner = Owner(name='poor owner')
        self.broker = Broker(name='broker1')
        self.account = Account(owner=self.owner, broker=self.broker)
        self.account.deposit(Money(amount=10000, currency=currency))

    def tearDown(self):
        # Close the Session first, then roll back the external transaction.
        # Rolling back first leaves the Session holding a transaction on a
        # connection whose outer transaction is already gone.
        self.session.close()
        self.trans.rollback()

    def _add_two_buy_orders(self):
        """Persist one stock and two buy orders for it; return (order1, order2)."""
        stock = Stock(symbol='symbol', description='a stock',
                      ISIN='US123456789', exchange=self.exchange)
        order1 = BuyOrder(account=self.account, security=stock, price=13.2, share=10)
        order2 = BuyOrder(account=self.account, security=stock, price=13.25, share=10)
        self.session.add(order1)
        self.session.add(order2)
        self.session.commit()
        return order1, order2

    def test_open_orders_by_order_id(self):
        order1, _ = self._add_two_buy_orders()
        tc = TradingCenter(self.session)

        self.assertEqual(order1, tc.open_order_by_id(order1.id))
        # Unknown ids yield None.
        self.assertIsNone(tc.open_order_by_id(100))

    def testGetOpenOrdersBySymbol(self):
        order1, order2 = self._add_two_buy_orders()
        tc = TradingCenter(self.session)

        orders = tc.open_orders_by_symbol('symbol')
        self.assertEqual([order1, order2], list(orders))

    def testCancelOrder(self):
        order1, order2 = self._add_two_buy_orders()
        tc = TradingCenter(self.session)

        order1.cancel()
        self.assertEqual([order2], tc.open_orders)
        self.assertEqual([order1], tc.cancel_orders)
        self.assertEqual(CancelOrderStage, type(order1.current_stage))

        order2.cancel()
        self.assertEqual([], tc.open_orders)
        self.assertEqual([order1, order2], tc.cancel_orders)

    def testCancelAllOpenOrders(self):
        self._add_two_buy_orders()
        tc = TradingCenter(self.session)

        tc.cancel_all_open_orders()
        self.assertEqual([], tc.open_orders)

    # TODO: placeholder tests; they currently assert nothing.
    def testConsume(self):
        pass

    def testPostConsume(self):
        pass

    def testCreateAccountWithMetrix(self):
        pass
def add(self, entity):
    """Add ``entity`` to this session and commit it in its own transaction.

    Opens a transaction with ``begin()``, adds ``entity`` via the base
    Session's ``add`` and commits. If adding or committing raises, the
    transaction is rolled back before the exception propagates. The session
    is therefore not left holding the transaction opened here; in
    SQLAlchemy's autocommit mode, a later ``begin()`` would otherwise find a
    transaction already in progress.
    """
    self.begin()
    try:
        SaSession.add(self, entity)
        self.commit()
    except Exception:
        self.rollback()
        raise
# Database connection engine = core.get_database_engine_string() logging.info("Using connection string '%s'" % (engine,)) engine = create_engine(engine, encoding='utf-8', isolation_level = 'READ COMMITTED') logging.info("Binding session...") session = Session(bind=engine, autocommit = False) # # Initial query object it = session.query(UserQuery).filter_by(text = query_text) try: q = it.one() except NoResultFound: q = UserQuery(query_text) session.add(q) session.commit() assert q.id is not None logging.info("Resolving query id %d with text '%s'", q.id, q.text) # # Query keyword resolution keywords = set([]) using_keywords = False raw_keyword_count = 0 for k in q.get_keywords(q.text): logging.debug("keyword: %s", k) using_keywords = True raw_keyword_count += 1 it = session.query(Keyword).filter(Keyword.word.like("%s" % (k,))) keywords.update(it)
def draw(*,
         group_id: uuid.UUID,
         filtration_method: str,
         nullify_radial_velocity: bool,
         with_luminosity_function: bool,
         with_velocities_vs_magnitude: bool,
         with_velocity_clouds: bool,
         lepine_criterion: bool,
         heatmaps_axes: str,
         with_toomre_diagram: bool,
         with_ugriz_diagrams: bool,
         desired_stars_count: int,
         session: Session) -> None:
    """Load a group's stars, record elimination statistics and draw plots.

    The columns to fetch come from ``star_query_entities`` given the plotting
    flags; ``ValueError`` is raised when it selects nothing. A non-zero
    ``desired_stars_count`` limits the load to that many stars in random order.
    An eliminations counter for the chosen filtration method is added and
    committed through ``session`` before the stars are filtered and plotted.
    """
    selected_columns = star_query_entities(
        filtration_method=filtration_method,
        nullify_radial_velocity=nullify_radial_velocity,
        lepine_criterion=lepine_criterion,
        with_luminosity_function=with_luminosity_function,
        with_velocities_vs_magnitude=with_velocities_vs_magnitude,
        with_velocity_clouds=with_velocity_clouds,
        heatmaps_axes=heatmaps_axes,
        with_toomre_diagram=with_toomre_diagram,
        with_ugriz_diagrams=with_ugriz_diagrams)
    if not selected_columns:
        raise ValueError('No plotting options were chosen')

    group_query = session.query(Star).filter(Star.group_id == group_id)
    if desired_stars_count:
        # Random sample: order by random() in the database, keep the first N.
        group_query = (group_query
                       .order_by(func.random())
                       .limit(desired_stars_count))
    sql_statement = group_query.with_entities(*selected_columns).statement
    stars = pd.read_sql_query(sql=sql_statement,
                              con=session.get_bind(),
                              index_col='id')

    filters = stars_filtration_functions(method=filtration_method)
    counter = stars_eliminations_counter(stars,
                                         filtration_functions=filters,
                                         group_id=group_id)
    session.add(counter)
    session.commit()

    stars = filtered_stars(stars, filtration_functions=filters)
    if nullify_radial_velocity:
        set_radial_velocity_to_zero(stars)

    if with_luminosity_function:
        luminosity_function.plot(stars=stars)
    if with_velocities_vs_magnitude:
        plot_velocities = (velocities_vs_magnitude.plot_lepine_case
                           if lepine_criterion
                           else velocities_vs_magnitude.plot)
        plot_velocities(stars=stars)
    if with_velocity_clouds:
        plot_clouds = (velocity_clouds.plot_lepine_case
                       if lepine_criterion
                       else velocity_clouds.plot)
        plot_clouds(stars=stars)
    if heatmaps_axes:
        heatmaps.plot(stars=stars, axes=heatmaps_axes)
    if with_toomre_diagram:
        toomre_diagram.plot(stars=stars)
    if with_ugriz_diagrams:
        ugriz_diagrams.plot(stars=stars)
class InboxSession(object):
    """ Inbox custom ORM (with SQLAlchemy compatible API).

    Parameters
    ----------
    engine : <sqlalchemy.engine.Engine>
        A configured database engine to use for this session
    versioned : bool
        Do you want to enable the transaction log?
    ignore_soft_deletes : bool
        Whether or not to ignore soft-deleted objects in query results.
    namespace_id : int
        Namespace to limit query results with. Currently accepted but not
        used (see the TODO in ``__init__``).
    """
    def __init__(self, engine, versioned=True, ignore_soft_deletes=True,
                 namespace_id=None):
        # TODO: support limiting on namespaces
        assert engine, "Must set the database engine"

        args = dict(bind=engine, autoflush=True, autocommit=False)
        self.ignore_soft_deletes = ignore_soft_deletes
        if ignore_soft_deletes:
            args['query_cls'] = InboxQuery

        self._session = Session(**args)

        if versioned:
            from inbox.models.transaction import create_revisions

            @event.listens_for(self._session, 'after_flush')
            def after_flush(session, flush_context):
                """
                Hook to log revision snapshots. Must be post-flush in order to
                grab object IDs on new objects.
                """
                create_revisions(session)

    def query(self, *args, **kwargs):
        """Return a query on the wrapped session.

        When ``ignore_soft_deletes`` is set, ``IgnoreSoftDeletesOption`` is
        applied to the query.
        """
        q = self._session.query(*args, **kwargs)
        if self.ignore_soft_deletes:
            return q.options(IgnoreSoftDeletesOption())
        else:
            return q

    def add(self, instance):
        """Add *instance*; refuse deleted instances when ignoring soft deletes."""
        if not self.ignore_soft_deletes or not instance.is_deleted:
            self._session.add(instance)
        else:
            raise Exception("Why are you adding a deleted object?")

    def add_all(self, instances):
        """Add every instance; refuse the batch if any is deleted (see ``add``)."""
        # Materialize first: the deletion check iterates ``instances``, and a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted
        # before it reached the underlying session.
        instances = list(instances)
        # Short-circuit like ``add``: only inspect is_deleted when it matters,
        # and use the same truthiness test as ``add``.
        if not self.ignore_soft_deletes or \
                not any(i.is_deleted for i in instances):
            self._session.add_all(instances)
        else:
            raise Exception("Why are you adding a deleted object?")

    def delete(self, instance):
        """Soft-delete *instance* when ignoring soft deletes, else hard-delete it."""
        if self.ignore_soft_deletes:
            instance.mark_deleted()
            # Re-add so the instance is attached to this session and its
            # changed state is included in the next flush.
            self._session.add(instance)
        else:
            self._session.delete(instance)

    def begin(self):
        self._session.begin()

    def commit(self):
        self._session.commit()

    def rollback(self):
        self._session.rollback()

    def flush(self):
        self._session.flush()

    def close(self):
        self._session.close()

    def expunge(self, obj):
        self._session.expunge(obj)

    def merge(self, obj):
        return self._session.merge(obj)

    @property
    def no_autoflush(self):
        return self._session.no_autoflush
def call_api(self, api_client: API, session: Session) -> str:
    '''
    Override ApiJob.

    Encrypt the reply and send it to the server. If the call is successful, add it to the local
    database and return the reply uuid string.

    Raises SendReplyJobTimeoutError when the request times out or the server connection fails,
    and SendReplyJobError for any other failure (after calling self._set_status_to_failed).
    Both carry self.reply_uuid and chain the underlying exception as __cause__.
    '''
    try:
        # If the reply has already made it to the server but we didn't get a 201 response back,
        # then a reply with self.reply_uuid will exist in the replies table.
        reply_db_object = session.query(Reply).filter_by(
            uuid=self.reply_uuid).one_or_none()
        if reply_db_object:
            logger.debug('Reply %s has already been sent successfully', self.reply_uuid)
            return reply_db_object.uuid

        # If the draft does not exist because it was deleted locally then do not send the
        # message to the source.
        draft_reply_db_object = session.query(DraftReply).filter_by(
            uuid=self.reply_uuid).one_or_none()
        if not draft_reply_db_object:
            raise Exception('Draft reply {} does not exist'.format(
                self.reply_uuid))

        # If the source was deleted locally then do not send the message and delete the draft.
        source = session.query(Source).filter_by(
            uuid=self.source_uuid).one_or_none()
        if not source:
            session.delete(draft_reply_db_object)
            session.commit()
            raise Exception('Source {} does not exists'.format(
                self.source_uuid))

        # Send the draft reply to the source
        encrypted_reply = self.gpg.encrypt_to_source(
            self.source_uuid, self.message)
        sdk_reply = self._make_call(encrypted_reply, api_client)

        # Create a new reply object with an updated filename and file counter
        interaction_count = source.interaction_count + 1
        filename = '{}-{}-reply.gpg'.format(interaction_count,
                                            source.journalist_designation)
        reply_db_object = Reply(
            uuid=self.reply_uuid,
            source_id=source.id,
            filename=filename,
            journalist_id=api_client.token_journalist_uuid,
            content=self.message,
            is_downloaded=True,
            is_decrypted=True)
        # The locally computed filename is overridden by the server's filename; the
        # file counter is parsed from the part of that filename before the first '-'.
        new_file_counter = int(sdk_reply.filename.split('-')[0])
        reply_db_object.file_counter = new_file_counter
        reply_db_object.filename = sdk_reply.filename

        # Update following draft replies for the same source to reflect the new reply count
        draft_file_counter = draft_reply_db_object.file_counter
        draft_timestamp = draft_reply_db_object.timestamp
        update_draft_replies(session, source.id, draft_timestamp, draft_file_counter,
                             new_file_counter)

        # Add reply to replies table and increase the source interaction count by 1 and delete
        # the draft reply.
        session.add(reply_db_object)
        source.interaction_count += 1
        session.add(source)
        session.delete(draft_reply_db_object)
        session.commit()

        return reply_db_object.uuid
    except (RequestTimeoutError, ServerConnectionError) as e:
        message = "Failed to send reply for source {id} due to Exception: {error}".format(
            id=self.source_uuid, error=e)
        raise SendReplyJobTimeoutError(message, self.reply_uuid) from e
    except Exception as e:
        # Continue to store the draft reply
        message = ''' Failed to send reply {uuid} for source {id} due to Exception: {error} '''.format(uuid=self.reply_uuid, id=self.source_uuid, error=e)
        # NOTE(review): if the failure came from a flush/commit, the session may need a
        # rollback before _set_status_to_failed can use it — confirm that helper handles it.
        self._set_status_to_failed(session)
        raise SendReplyJobError(message, self.reply_uuid) from e
def __add(self, entity, path):  # pylint: disable=W0613
    """Forward *entity* to ``SaSession.add`` only when *path* has length zero.

    For a non-empty *path* this is a no-op; either way it returns None.
    """
    if len(path) != 0:
        return
    SaSession.add(self, entity)
def add_video(session: Session, args):
    """Create or update the ``Video`` titled ``args.title`` and commit it.

    Returns 0 on success and -1 when the request is refused (unless
    ``args.force`` is set):
      * an existing video with that title already has a youtube, twitch or
        processing status, or
      * the segment range ``args.end - args.start`` is negative or longer than
        24 * 60 * 60.
    A refused request returns before any Video is added to or modified in
    ``session``, so a later commit by the caller cannot persist a
    half-applied video.
    """
    vid = session.query(Video).filter(Video.title == args.title).first()
    if vid:
        print("Video with same title already exists, checking status...")
        if not args.force and (vid.youtube_status or vid.twitch_status or vid.processing_status):
            print("Video processing has already started, aborting update.")
            return -1
        print("Updating existing video")

    # Parse and validate the segment range before touching the session: doing
    # it after the mutations below would leave a pending/dirty Video behind on
    # abort (or on a non-integer start/end).
    start_segment = int(args.start)
    end_segment = int(args.end)
    diff = end_segment - start_segment
    if not args.force:
        if diff < 0:
            print("Video end segment is before its start segment, aborting!")
            return -1
        if diff > 24 * 60 * 60:
            print("Video would be over 24 hours long, aborting!")
            return -1

    if vid is None:
        vid = Video()
        session.add(vid)

    vid.date_added = datetime.datetime.now(datetime.timezone.utc)
    vid.title = args.title
    vid.description = args.desc
    vid.start_segment = start_segment
    vid.end_segment = end_segment

    kws = str(get_setting(session, "common_keywords")).split(",")
    kws.extend(str(args.keywords).split(","))
    vid.keywords = ",".join(kws)

    vid.do_twitch = bool(get_setting(session, "upload_to_twitch"))
    vid.do_youtube = bool(get_setting(session, "upload_to_youtube"))
    vid.youtube_public = bool(get_setting(session, "youtube_public"))

    # TODO: Potentially kill any running upload tasks
    vid.youtube_status = ""
    vid.twitch_status = ""
    vid.processing_status = str(get_setting(session, "init_proc_status"))
    vid.processing_done = False
    vid.done_twitch = False
    vid.done_youtube = False

    vid.game = args.game if args.game else None

    # Schedule a YouTube publish date only for videos that do not have one yet.
    if vid.do_youtube and not vid.youtube_pubdate and get_setting(session, "youtube_schedule"):
        last_pubdate = get_last_publish_date(session)
        offset = get_setting(session, "schedule_offset_hours")
        vid.youtube_pubdate = last_pubdate + datetime.timedelta(hours=offset)

    session.commit()
    print("Added as Video %s" % vid.id)
    return 0