def test_multiple_key_replaced_by_update(self):
    # After deleting u1/u2 and re-using their primary-key values for u3
    # across several flushes, a rollback must restore all three objects
    # and their original identity-map entries.
    users, User = self.tables.users, self.classes.User
    mapper(User, users)
    u1 = User(name='u1')
    u2 = User(name='u2')
    u3 = User(name='u3')
    s = Session()
    s.add_all([u1, u2, u3])
    s.commit()
    s.delete(u1)
    s.delete(u2)
    s.flush()
    u3.name = 'u1'
    s.flush()
    u3.name = 'u2'
    s.flush()
    s.rollback()
    assert u1 in s
    assert u2 in s
    assert u3 in s
    assert s.identity_map[(User, ('u1',))] is u1
    assert s.identity_map[(User, ('u2',))] is u2
    assert s.identity_map[(User, ('u3',))] is u3
def test_set_composite_attrs_via_selectable(self):
    # Composite attributes assigned on objects mapped to a selectable
    # must be written through to the underlying tables on commit.
    Values, CustomValues, values, Descriptions, descriptions = (
        self.classes.Values,
        self.classes.CustomValues,
        self.tables.values,
        self.classes.Descriptions,
        self.tables.descriptions)
    session = Session()
    d = Descriptions(
        custom_descriptions=CustomValues('Color', 'Number'),
        values=[
            Values(custom_values=CustomValues('Red', '5')),
            Values(custom_values=CustomValues('Blue', '1'))
        ]
    )
    session.add(d)
    session.commit()
    # verify the raw rows, not the ORM view
    eq_(
        testing.db.execute(descriptions.select()).fetchall(),
        [(1, 'Color', 'Number')]
    )
    eq_(
        testing.db.execute(values.select()).fetchall(),
        [(1, 1, 'Red', '5'), (2, 1, 'Blue', '1')]
    )
def test_is_modified_passive_on(self):
    # is_modified(passive=True) must answer without emitting SQL, both
    # for the clean and for the dirty case.
    User, Address = self._default_mapping_fixture()
    s = Session()
    u = User(name='fred', addresses=[Address(email_address='foo')])
    s.add(u)
    s.commit()
    u.id  # touch the pk so the statement counts below start clean

    def go():
        assert not s.is_modified(u, passive=True)
    self.assert_sql_count(testing.db, go, 0)

    u.name = 'newname'

    def go():
        assert s.is_modified(u, passive=True)
    self.assert_sql_count(testing.db, go, 0)
def test_continue_flushing_on_commit(self):
    """test that post-flush actions get flushed also if we're in commit()"""
    users, User = self.tables.users, self.classes.User
    mapper(User, users)
    sess = Session()
    to_flush = [User(name='ed'), User(name='jack'), User(name='wendy')]

    @event.listens_for(sess, "after_flush_postexec")
    def add_another_user(session, ctx):
        # keep adding pending users during flush; commit() must continue
        # flushing until nothing new is pending
        if to_flush:
            session.add(to_flush.pop(0))

    x = [1]

    @event.listens_for(sess, "after_commit")  # noqa
    def add_another_user(session):
        # deliberately shadows the handler above (hence the noqa);
        # fires exactly once after the commit completes
        x[0] += 1

    sess.add(to_flush.pop())
    sess.commit()
    eq_(x, [2])
    # all three users made it to the database within the single commit
    eq_(
        sess.scalar(select([func.count(users.c.id)])),
        3
    )
def test_explicit_expunge_deleted(self):
    # expunge() of an object already flushed as deleted detaches it for
    # good; a subsequent rollback must not re-attach it.
    users, User = self.tables.users, self.classes.User
    mapper(User, users)
    sess = Session()
    sess.add(User(name='x'))
    sess.commit()
    u1 = sess.query(User).first()
    sess.delete(u1)
    sess.flush()
    assert was_deleted(u1)
    assert u1 not in sess
    assert object_session(u1) is sess
    sess.expunge(u1)
    assert was_deleted(u1)
    assert u1 not in sess
    assert object_session(u1) is None
    sess.rollback()
    assert was_deleted(u1)
    assert u1 not in sess
    assert object_session(u1) is None
def test_auto_detach_on_gc_session(self):
    # an object bound to a garbage-collected Session becomes attachable
    # to a new Session.
    users, User = self.tables.users, self.classes.User
    mapper(User, users)
    sess = Session()
    u1 = User(name='u1')
    sess.add(u1)
    sess.commit()
    # can't add u1 to a new Session,
    # it already belongs to sess
    s2 = Session()
    assert_raises_message(
        sa.exc.InvalidRequestError,
        r".*is already attached to session",
        s2.add, u1
    )
    # garbage collect sess
    del sess
    gc_collect()
    # s2 lets it in now despite u1 having
    # session_key
    s2.add(u1)
    assert u1 in s2
def test_is_modified_passive_off(self):
    """as of 0.8 no SQL is emitted for is_modified()
    regardless of the passive flag"""
    User, Address = self._default_mapping_fixture()
    s = Session()
    u = User(name='fred', addresses=[Address(email_address='foo')])
    s.add(u)
    s.commit()
    u.id  # touch the pk so the statement count below starts clean

    def go():
        assert not s.is_modified(u)
    self.assert_sql_count(testing.db, go, 0)

    s.expire_all()
    u.name = 'newname'
    # can't predict result here
    # deterministically, depending on if
    # 'name' or 'addresses' is tested first
    mod = s.is_modified(u)
    addresses_loaded = 'addresses' in u.__dict__
    assert mod is not addresses_loaded
def upgrade():
    """Renumber all Salad rows so ``position`` becomes their 0-based
    order (ordered by the current position values)."""
    session = Session(bind=op.get_bind())
    ordered_salads = session.query(Salad).order_by(Salad.position)
    for index, salad in enumerate(ordered_salads):
        salad.position = index
    session.commit()
def create_scheduled_text(text_data):
    """Persist a new ScheduledText built from *text_data* and return it.

    Expected keys: phone_number, route_number, stop_id, hour, minute,
    ampm ('am'/'pm'), max_minutes.  A random 6-digit validation code is
    generated and stored on the record.

    Raises: re-raises any error after rolling back the session.
    """
    engine = create_engine(settings['ENGINE_STRING'])
    session = Session(engine)
    try:
        st = ScheduledText()
        validation_code = randint(100000, 999999)
        st.phone_number = text_data["phone_number"]
        st.route_number = text_data["route_number"]
        st.stop_id = text_data["stop_id"]
        st.hour = int(text_data["hour"])
        st.minute = int(text_data["minute"])
        # convert 12-hour input to 24-hour storage; the original
        # unconditionally added 12 for "pm", turning 12pm into the
        # invalid hour 24 and mapping 12am to 12
        if text_data["ampm"] == "pm":
            if st.hour != 12:
                st.hour += 12
        elif st.hour == 12:
            st.hour = 0
        st.max_minutes = text_data["max_minutes"]
        st.validation_code = validation_code
        session.add(st)
        session.commit()
    except Exception:
        # was a bare except: — never swallow SystemExit/KeyboardInterrupt
        session.rollback()
        raise
    finally:
        session.close()  # release the connection in all cases
    return st
def _two_obj_fixture(self):
    """Return a Session holding two committed Engineers, with
    'dilbert' reporting to 'wally'."""
    session = Session()
    wally = Engineer(name='wally')
    dilbert = Engineer(name='dilbert', reports_to=wally)
    session.add_all([wally, dilbert])
    session.commit()
    return session
def test_map_to_select(self):
    # a polymorphic base mapped to a SELECT alias must still support
    # inserting and reloading a joined-inheritance child.
    Base, Child = self.classes.Base, self.classes.Child
    base, child = self.tables.base, self.tables.child
    base_select = select([base]).alias()
    mapper(
        Base,
        base_select,
        polymorphic_on=base_select.c.type,
        polymorphic_identity="base",
    )
    mapper(Child, child, inherits=Base, polymorphic_identity="child")
    sess = Session()
    # 2. use an id other than "1" here so can't rely on
    # the two inserts having the same id
    c1 = Child(id=12, name="c1")
    sess.add(c1)
    sess.commit()
    sess.close()
    c1 = sess.query(Child).one()
    eq_(c1.name, "c1")
def update_song_id3(engine, song_name, new_name='', new_artist='', new_album=''):
    """Update a song's database row and its file's ID3 tags.

    Only non-empty new_* values are applied.  The song row is located by
    name; if the lookup fails, tag editing falls back to './song.mp3'
    (best-effort behavior retained from the original).
    """
    session = Session(bind=engine)
    song = session.query(Song).filter(Song.name == song_name)
    try:
        song_path = song.one().path
    except Exception:
        # row missing or ambiguous — keep the original best-effort fallback
        song_path = './song.mp3'
    updated = {}
    if new_name:
        updated['name'] = new_name
    if new_artist:
        updated['artist'] = new_artist
    if new_album:
        updated['album'] = new_album
    if not updated:
        # nothing to change; avoid issuing an UPDATE with no SET clause
        return
    song.update(updated)
    session.commit()
    original_song = EasyID3(song_path)
    for key, value in updated.items():
        # the DB column 'name' corresponds to the ID3 'title' frame
        if key == 'name':
            key = 'title'
        original_song[key] = value
    original_song.save()
def test_cast_type(self):
    # typed views over a JSON column: int, text and .astext comparisons
    # must all locate the same row, and re-assignment round-trips.
    Json = self.classes.Json
    s = Session(testing.db)
    j = Json(json={'field': 10})
    s.add(j)
    s.commit()
    jq = s.query(Json).filter(Json.int_field == 10).one()
    eq_(j.id, jq.id)
    jq = s.query(Json).filter(Json.text_field == '10').one()
    eq_(j.id, jq.id)
    jq = s.query(Json).filter(Json.json_field.astext == '10').one()
    eq_(j.id, jq.id)
    jq = s.query(Json).filter(Json.text_field == 'wrong').first()
    is_(jq, None)
    j.json = {'field': True}
    s.commit()
    # JSON booleans render as lowercase 'true' in the text view
    jq = s.query(Json).filter(Json.text_field == 'true').one()
    eq_(j.id, jq.id)
def test_one_to_many_on_m2o(self):
    # passive_updates=False on the m2o backref: renaming the parent
    # (string key via remote_side) must be propagated to children by
    # ORM-issued UPDATEs rather than database cascades.
    Node, nodes = self.classes.Node, self.tables.nodes
    mapper(Node, nodes, properties={
        'children': relationship(
            Node,
            backref=sa.orm.backref(
                'parentnode',
                remote_side=nodes.c.name,
                passive_updates=False),
        )})
    sess = Session()
    n1 = Node(name='n1')
    sess.add(n1)
    n2 = Node(name='n11', parentnode=n1)
    n3 = Node(name='n12', parentnode=n1)
    n4 = Node(name='n13', parentnode=n1)
    sess.add_all([n2, n3, n4])
    sess.commit()
    n1.name = 'new n1'
    sess.commit()
    eq_(['new n1', 'new n1', 'new n1'],
        [n.parent for n in sess.query(Node).filter(
            Node.name.in_(['n11', 'n12', 'n13']))])
def test_instance_lazy_relation_loaders(self):
    # loader options (lazyload) applied on a query must survive pickling
    # of the instance and take effect when it joins a new Session.
    users, addresses = (self.tables.users, self.tables.addresses)
    mapper(User, users, properties={
        'addresses': relationship(Address, lazy='noload')
    })
    mapper(Address, addresses)
    sess = Session()
    u1 = User(name='ed', addresses=[
        Address(
            email_address='*****@*****.**',
        )
    ])
    sess.add(u1)
    sess.commit()
    sess.close()
    u1 = sess.query(User).options(
        lazyload(User.addresses)
    ).first()
    u2 = pickle.loads(pickle.dumps(u1))
    sess = Session()
    sess.add(u2)
    # the lazyload option overrides the mapper-level 'noload'
    assert u2.addresses
def test_collection(self):
    # @validates on a collection attribute fires once per append and may
    # veto the append by raising.
    users, addresses, Address = (self.tables.users,
                                 self.tables.addresses,
                                 self.classes.Address)
    canary = Mock()

    class User(fixtures.ComparableEntity):
        @validates('addresses')
        def validate_address(self, key, ad):
            canary(key, ad)
            assert '@' in ad.email_address
            return ad

    mapper(User, users, properties={
        'addresses': relationship(Address)}
    )
    mapper(Address, addresses)
    sess = Session()
    u1 = User(name='edward')
    a0 = Address(email_address='noemail')
    assert_raises(AssertionError, u1.addresses.append, a0)
    a1 = Address(id=15, email_address='*****@*****.**')
    u1.addresses.append(a1)
    # the validator was invoked for the rejected address as well
    eq_(canary.mock_calls, [call('addresses', a0), call('addresses', a1)])
    sess.add(u1)
    sess.commit()
    eq_(
        sess.query(User).filter_by(name='edward').one(),
        User(name='edward',
             addresses=[Address(email_address='*****@*****.**')])
    )
def test_09_pickle(self):
    # rebuild an InstanceState from a 0.9-era pickle dict (which used a
    # 'callables' entry instead of 'expired_attributes') and verify the
    # restored instance works in a live session.
    users = self.tables.users
    mapper(User, users)
    sess = Session()
    sess.add(User(id=1, name='ed'))
    sess.commit()
    sess.close()
    inst = User(id=1, name='ed')
    del inst._sa_instance_state
    state = sa_state.InstanceState.__new__(sa_state.InstanceState)
    state_09 = {
        'class_': User,
        'modified': False,
        'committed_state': {},
        'instance': inst,
        'callables': {'name': state, 'id': state},
        'key': (User, (1,)),
        'expired': True}
    manager = instrumentation._SerializeManager.__new__(
        instrumentation._SerializeManager)
    manager.class_ = User
    state_09['manager'] = manager
    state.__setstate__(state_09)
    # the legacy 'callables' entries become expired attributes
    eq_(state.expired_attributes, {'name', 'id'})
    sess = Session()
    sess.add(inst)
    eq_(inst.name, 'ed')
    # test identity_token expansion
    eq_(sa.inspect(inst).key, (User, (1, ), None))
def test_11_pickle(self):
    # a 1.1-era pickled InstanceState has no identity_token; restoring
    # it must default the token to None in the identity key.
    users = self.tables.users
    mapper(User, users)
    sess = Session()
    u1 = User(id=1, name='ed')
    sess.add(u1)
    sess.commit()
    sess.close()
    manager = instrumentation._SerializeManager.__new__(
        instrumentation._SerializeManager)
    manager.class_ = User
    state_11 = {
        'class_': User,
        'modified': False,
        'committed_state': {},
        'instance': u1,
        'manager': manager,
        'key': (User, (1,)),
        'expired_attributes': set(),
        'expired': True}
    state = sa_state.InstanceState.__new__(sa_state.InstanceState)
    state.__setstate__(state_11)
    eq_(state.identity_token, None)
    eq_(state.identity_key, (User, (1,), None))
def _persist(self, json_str):
    """Parse *json_str* and save a PixelEvent row.

    accountId / customerId arrive as one-element lists in the JSON
    payload; a missing or non-integer value falls back to 0.
    TODO FIX AFTER PROTOTYPE: dummy id defaults.
    """
    logger = logging.getLogger(__name__)
    if json_str is None:
        return
    engine = create_engine(DB_CONN)
    Session = sessionmaker(bind=engine)
    session = Session()
    json_obj = json.loads(json_str)

    def coerce_id(raw, label):
        # getting an element from the json object returns a 'list',
        # not a single element
        if raw is None or not self.represents_int(raw[0]):
            logger.debug(
                "PixelEventLogger._persist: %s Id provided was invalid. "
                "Setting to default value 0", label)
            return 0
        return raw[0]

    # .get() instead of [] so absent keys behave like null keys
    account_id = coerce_id(json_obj.get('accountId'), 'account')
    customer_id = coerce_id(json_obj.get('customerId'), 'customer')

    logger.debug("PixelEventLogger._persist: Saving pixel event..." + json_str)
    pixevent = PixelEvent(account_id=account_id,
                          customer_id=customer_id,
                          doc=json_obj)
    session.add(pixevent)
    try:
        session.commit()
    finally:
        session.close()  # original leaked the connection
def test_warning_on_using_inactive_session_rollback_evt(self):
    # state added inside an after_rollback handler targets a non-active
    # transaction: it must be discarded, with a warning.
    users, User = self.tables.users, self.classes.User
    mapper(User, users)
    sess = Session()
    u1 = User(id=1, name='u1')
    sess.add(u1)
    sess.commit()
    u3 = User(name='u3')

    @event.listens_for(sess, "after_rollback")
    def evt(s):
        sess.add(u3)

    sess.add(User(id=1, name='u2'))  # duplicate pk forces a FlushError

    def go():
        assert_raises(
            orm_exc.FlushError, sess.flush
        )
    assert_warnings(go,
                    ["Session's state has been changed on a "
                     "non-active transaction - this state "
                     "will be discarded."],
                    )
    assert u3 not in sess
def test_dirty_state_transferred_deep_nesting(self):
    # with two nested savepoints, dirty state is recorded on the
    # innermost savepoint at flush, transferred outward on commit, and
    # restored on rollback of the outer transaction.
    User, users = self.classes.User, self.tables.users
    mapper(User, users)
    s = Session(testing.db)
    u1 = User(name='u1')
    s.add(u1)
    s.commit()
    nt1 = s.begin_nested()
    nt2 = s.begin_nested()
    u1.name = 'u2'
    assert attributes.instance_state(u1) not in nt2._dirty
    assert attributes.instance_state(u1) not in nt1._dirty
    s.flush()
    # flush records dirty state on the innermost savepoint only
    assert attributes.instance_state(u1) in nt2._dirty
    assert attributes.instance_state(u1) not in nt1._dirty
    s.commit()
    # committing nt2 transfers its dirty set to nt1
    assert attributes.instance_state(u1) in nt2._dirty
    assert attributes.instance_state(u1) in nt1._dirty
    s.rollback()
    assert attributes.instance_state(u1).expired
    eq_(u1.name, 'u1')
def test_unlock_scenario_structure(fresh_database_config, scenario_manager):
    """Test that unlock method removes records from the table for locks.

    1. Create 3 scenarios with scenario manager.
    2. Create records in the ScenarioStructureLock table for the first and third scenarios.
    3. Unlock the first scenario.
    4. Check that the ScenarioStructureLock table contains only one record for the third scenario.
    """
    scenarios = scenario_manager.create_scenarios([
        NewScenarioSpec(name=u'First'),
        NewScenarioSpec(name=u'Second'),
        NewScenarioSpec(name=u'Third'),
    ])
    session = Session(bind=create_engine(fresh_database_config.get_connection_string()))
    session.add_all([
        ScenarioStructureLockRecord(scenario_id=scenarios[0].scenario_id),
        ScenarioStructureLockRecord(scenario_id=scenarios[2].scenario_id),
    ])
    session.commit()
    scenarios[0]._unlock_structure()
    # fresh engine/session so the check reads the database, not the
    # identity map of the session that created the locks
    session = Session(bind=create_engine(fresh_database_config.get_connection_string()))
    lock_record = session.query(ScenarioStructureLockRecord).one()
    assert lock_record.scenario_id == scenarios[2].scenario_id, "Wrong scenario has been unlocked"
def setUpClass(cls) -> None:
    """
    Add a service to the DB
    """
    AcceptanceTestCase.setUpClass()
    cls.service_name = 'Testing Service'
    cls.description = 'Description for the Testing Service'
    cls.job_registration_schema = JSONSchema(
        title='Job Registration Schema',
        description='Must be fulfilled for an experiment'
    ).dump(cls.JobRegistrationSchema())
    # NOTE(review): the result schema is also built from
    # JobRegistrationSchema — looks like a copy/paste; confirm whether a
    # JobResultSchema class was intended here.
    cls.job_result_schema = JSONSchema(
        title='Job Result Schema',
        description='Must be fulfilled to post results'
    ).dump(cls.JobRegistrationSchema())
    # expire_on_commit=False keeps cls.service usable after commit
    session = Session(bind=APP_FACTORY.engine, expire_on_commit=False)
    service_list = ServiceList(session)
    cls.service = service_list.new(
        cls.service_name,
        cls.description,
        cls.job_registration_schema,
        cls.job_result_schema
    )
    session.commit()
def test_noload_append(self):
    # test that a load of User.addresses is not emitted
    # when flushing an append
    User, Address = self._user_address_fixture()
    sess = Session()
    u1 = User(name="jack", addresses=[Address(email_address="a1")])
    sess.add(u1)
    sess.commit()
    u1_id = u1.id
    sess.expire_all()
    u1.addresses.append(Address(email_address='a2'))
    # only the expired User row plus the INSERT are expected; no SELECT
    # of the addresses collection
    self.assert_sql_execution(
        testing.db,
        sess.flush,
        CompiledSQL(
            "SELECT users.id AS users_id, users.name AS users_name "
            "FROM users WHERE users.id = :param_1",
            lambda ctx: [{"param_1": u1_id}]),
        CompiledSQL(
            "INSERT INTO addresses (user_id, email_address) "
            "VALUES (:user_id, :email_address)",
            lambda ctx: [{'email_address': 'a2', 'user_id': u1_id}]
        )
    )
def test_11_pickle(self):
    # a 1.1-era pickled InstanceState lacks an identity_token; restoring
    # it must default the token slot to None.
    users = self.tables.users
    mapper(User, users)
    sess = Session()
    u1 = User(id=1, name="ed")
    sess.add(u1)
    sess.commit()
    sess.close()
    manager = instrumentation._SerializeManager.__new__(
        instrumentation._SerializeManager
    )
    manager.class_ = User
    state_11 = {
        "class_": User,
        "modified": False,
        "committed_state": {},
        "instance": u1,
        "manager": manager,
        "key": (User, (1,)),
        "expired_attributes": set(),
        "expired": True,
    }
    state = sa_state.InstanceState.__new__(sa_state.InstanceState)
    state.__setstate__(state_11)
    eq_(state.identity_token, None)
    eq_(state.identity_key, (User, (1,), None))
def test_09_pickle(self):
    # rebuild an InstanceState from a 0.9-era pickle dict (which used a
    # 'callables' entry instead of 'expired_attributes') and verify the
    # restored instance works in a live session.
    users = self.tables.users
    mapper(User, users)
    sess = Session()
    sess.add(User(id=1, name="ed"))
    sess.commit()
    sess.close()
    inst = User(id=1, name="ed")
    del inst._sa_instance_state
    state = sa_state.InstanceState.__new__(sa_state.InstanceState)
    state_09 = {
        "class_": User,
        "modified": False,
        "committed_state": {},
        "instance": inst,
        "callables": {"name": state, "id": state},
        "key": (User, (1,)),
        "expired": True,
    }
    manager = instrumentation._SerializeManager.__new__(
        instrumentation._SerializeManager
    )
    manager.class_ = User
    state_09["manager"] = manager
    state.__setstate__(state_09)
    # the legacy 'callables' entries become expired attributes
    eq_(state.expired_attributes, {"name", "id"})
    sess = Session()
    sess.add(inst)
    eq_(inst.name, "ed")
    # test identity_token expansion
    eq_(sa.inspect(inst).key, (User, (1,), None))
def test_state_creation_by_scenario(engine_with_scenario):
    """Check that scenario view state could be created by sqlalchemy.

    1. Create ScenarioViewState table by sqlalchemy means.
    2. Add record.
    3. Check that new record could be addressed.
    """
    session = Session(bind=engine_with_scenario)
    now = datetime.utcnow()
    scenario = Scenario(created=now)
    name = u"New name"
    description = u"Some description"
    changed = datetime.utcnow()
    state = ScenarioViewState(
        scenario=scenario, name=name, description=description, changed=changed)
    session.add_all((scenario, state))
    session.commit()
    assert state.state_id is not None, "New ScenarioViewState does not obtain ID"
    assert scenario.view_states.one().state_id == state.state_id, (
        "Wrong ID of scenario view state"
    )
    # re-read via a separate engine/session to prove the data was
    # persisted, not just cached in the writing session
    check_engine = create_engine(engine_with_scenario.url)
    check_session = Session(bind=check_engine)
    obtained_record = check_session.query(ScenarioViewState).first()
    assert obtained_record.state_id == state.state_id, "Wrong state ID"
    assert obtained_record.scenario_id == scenario.scenario_id, "Wrong scenario ID"
    assert obtained_record.name == name, "Wrong name value"
    assert obtained_record.description == description, "Wrong description value"
    assert obtained_record.changed == changed, "wrong changed value"
def test_scalar(self):
    # @validates on a scalar attribute may transform the value or veto
    # the set by raising; a vetoed set leaves the prior value intact.
    users = self.tables.users
    canary = Mock()

    class User(fixtures.ComparableEntity):
        @validates('name')
        def validate_name(self, key, name):
            canary(key, name)
            ne_(name, 'fred')
            return name + ' modified'

    mapper(User, users)
    sess = Session()
    u1 = User(name='ed')
    eq_(u1.name, 'ed modified')
    assert_raises(AssertionError, setattr, u1, "name", "fred")
    eq_(u1.name, 'ed modified')
    # the validator fired for the rejected value as well
    eq_(canary.mock_calls, [call('name', 'ed'), call('name', 'fred')])
    sess.add(u1)
    sess.commit()
    eq_(
        sess.query(User).filter_by(name='ed modified').one(),
        User(name='ed')
    )
def test_one_to_many_on_o2m(self):
    # passive_updates=False on the o2m relationship: renaming the parent
    # (string key via remote_side) must propagate to children via
    # ORM-issued UPDATEs rather than database cascades.
    Node, nodes = self.classes.Node, self.tables.nodes
    mapper(Node, nodes, properties={
        'children': relationship(
            Node,
            backref=sa.orm.backref('parentnode',
                                   remote_side=nodes.c.name),
            passive_updates=False
        )})
    sess = Session()
    n1 = Node(name='n1')
    n1.children.append(Node(name='n11'))
    n1.children.append(Node(name='n12'))
    n1.children.append(Node(name='n13'))
    sess.add(n1)
    sess.commit()
    n1.name = 'new n1'
    sess.commit()
    eq_(n1.children[1].parent, 'new n1')
    eq_(['new n1', 'new n1', 'new n1'],
        [n.parent for n in sess.query(Node).filter(
            Node.name.in_(['n11', 'n12', 'n13']))])
def test_concurrent_mod_err_noexpire_on_commit(self):
    # with expire_on_commit=False, a concurrent UPDATE from a second
    # session must surface as StaleDataError on the next commit of the
    # first session (version counter mismatch).
    sess = self._fixture(expire_on_commit=False)
    f1 = self.classes.Foo(value='f1')
    sess.add(f1)
    sess.commit()
    # here, we're not expired overall, so no load occurs and we
    # stay without a version id, unless we've emitted
    # a SELECT for it within the flush.
    f1.value
    s2 = Session(expire_on_commit=False)
    f2 = s2.query(self.classes.Foo).first()
    f2.value = 'f2'
    s2.commit()
    f1.value = 'f3'
    assert_raises_message(
        orm.exc.StaleDataError,
        r"UPDATE statement on table 'version_table' expected to "
        r"update 1 row\(s\); 0 were matched.",
        sess.commit
    )
def handle_cursor(  # pylint: disable=too-many-locals
    cls, cursor: Any, query: Query, session: Session
) -> None:
    """Updates progress information"""
    from pyhive import hive  # pylint: disable=no-name-in-module

    # operation states in which the Hive query is still executing
    unfinished_states = (
        hive.ttypes.TOperationState.INITIALIZED_STATE,
        hive.ttypes.TOperationState.RUNNING_STATE,
    )
    polled = cursor.poll()
    last_log_line = 0
    tracking_url = None
    job_id = None
    query_id = query.id
    while polled.operationState in unfinished_states:
        # re-read the query row each pass to observe external state
        # changes (e.g. a user-initiated stop)
        query = session.query(type(query)).filter_by(id=query_id).one()
        if query.status == QueryStatus.STOPPED:
            cursor.cancel()
            break

        log = cursor.fetch_logs() or ""
        if log:
            log_lines = log.splitlines()
            progress = cls.progress(log_lines)
            logger.info(
                "Query %s: Progress total: %s", str(query_id), str(progress)
            )
            needs_commit = False
            if progress > query.progress:
                query.progress = progress
                needs_commit = True
            if not tracking_url:
                # the tracking URL appears once in the logs; extract it
                # a single time and derive the job id from its path
                tracking_url = cls.get_tracking_url(log_lines)
                if tracking_url:
                    job_id = tracking_url.split("/")[-2]
                    logger.info(
                        "Query %s: Found the tracking url: %s",
                        str(query_id),
                        tracking_url,
                    )
                    tracking_url = tracking_url_trans(tracking_url)
                    logger.info(
                        "Query %s: Transformation applied: %s",
                        str(query_id),
                        tracking_url,
                    )
                    query.tracking_url = tracking_url
                    logger.info("Query %s: Job id: %s", str(query_id), str(job_id))
                    needs_commit = True
            if job_id and len(log_lines) > last_log_line:
                # Wait for job id before logging things out
                # this allows for prefixing all log lines and becoming
                # searchable in something like Kibana
                for l in log_lines[last_log_line:]:
                    logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l)
                last_log_line = len(log_lines)
            if needs_commit:
                session.commit()
        time.sleep(hive_poll_interval)
        polled = cursor.poll()
def persist(session: "orm.Session", distributions: "List[BaseDistribution]") -> None:
    """Bulk-save *distributions* on *session* and commit.

    No-op (no SQL, no commit) when the list is empty.  Annotations are
    quoted (PEP 563 style) so the module imports without the ORM names
    being resolvable at definition time.
    """
    if not distributions:  # idiomatic emptiness check, was len(...) == 0
        return
    session.bulk_save_objects(distributions)
    session.commit()
def migrate_watch_assoc(): """ Migrates watch targets to watch assocs :return: """ # Data migration - online mode only if context.is_offline_mode(): logger.warning('Data migration skipped in the offline mode') return def strip(x): if x is None: return None return x.strip() def target_key(t): scheme, host, port = t.scan_scheme, t.scan_host, t.scan_port if scheme is None: scheme = 'https' if port is not None: port = int(port) if port is None: port = 443 if scheme == 'http': scheme = 'https' if scheme == 'htttp': scheme = 'https' if port == 80 or port <= 10 or port >= 65535: port = 443 host = strip(host) if host is not None: if host.startswith('*.'): host = host[2:] if host.startswith('%.'): host = host[2:] return scheme, host, port target_db = {} already_assoc = set() duplicates = [] bind = op.get_bind() sess = BaseSession(bind=bind) it = sess.query(DbWatchTarget).yield_per(1000) for rec in it: ck = target_key(rec) rec_assoc = rec if ck in target_db: rec_assoc = target_db[ck] duplicates.append(rec.id) else: target_db[ck] = rec rec.scan_scheme = ck[0] rec.scan_host = ck[1] rec.scan_port = ck[2] if rec.user_id is None: continue cur_assoc_key = rec_assoc.id, rec.user_id if cur_assoc_key in already_assoc: print('already assoc: %s' % (cur_assoc_key, )) continue already_assoc.add(cur_assoc_key) assoc = DbWatchAssoc() assoc.scan_type = 1 assoc.created_at = rec_assoc.created_at assoc.updated_at = rec_assoc.updated_at assoc.scan_periodicity = rec_assoc.scan_periodicity assoc.user_id = rec.user_id # actual record! assoc.watch_id = rec_assoc.id sess.add(assoc) sess.commit() # remove duplicates if len(duplicates) > 0: sess.query(DbWatchTarget).filter(DbWatchTarget.id.in_(list(duplicates))) \ .delete(synchronize_session='fetch') sess.commit() print('Removed %s duplicates %s' % (len(duplicates), duplicates))
def remove(self, db_session: Session, *, id: int) -> ModelType:
    """Delete the row with primary key *id* and return the removed object."""
    target = db_session.query(self.model).get(id)
    db_session.delete(target)
    db_session.commit()
    return target
def peer_ip_address_set(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer:
    """Locate the stored peer matching *peer* and persist its address,
    returning the refreshed schema object."""
    matching = peer_query_get_by_address(sess, peer.address, peer.server)
    db_peer: models.WGPeer = matching.one()
    db_peer.address = peer.address
    sess.add(db_peer)
    sess.commit()
    return peer.from_orm(db_peer)
def create_user_item(db: Session, item: schemas.ItemCreate, user_id: int):
    """Insert a new Item owned by *user_id* and return the refreshed row."""
    new_item = models.Item(owner_id=user_id, **item.dict())
    db.add(new_item)
    db.commit()
    db.refresh(new_item)
    return new_item
engineer_name="engineer1", primary_language="java", ), Person(name="joesmith"), Engineer( name="wally", status="CGG", engineer_name="engineer2", primary_language="python", ), Manager(name="jsmith", status="ABA", manager_name="manager2"), ], ) session.add(c) session.commit() c = session.query(Company).get(1) for e in c.employees: print(e, inspect(e).key, e.company) assert set([e.name for e in c.employees]) == set( ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"]) print("\n") dilbert = session.query(Person).filter_by(name="dilbert").one() dilbert2 = session.query(Engineer).filter_by(name="dilbert").one() assert dilbert is dilbert2 dilbert.engineer_name = "hes dilbert!" session.commit()
def delete_player(db: Session, player_id):
    # "Soft delete": the row is kept but anonymised by renaming it to a
    # random value rather than being removed from the table.
    # NOTE(review): despite the name, no db.delete() occurs — confirm the
    # rename-only behavior is intentional.  Also .first() may return None
    # for an unknown id, which would raise AttributeError below.
    db_player = db.query(Player).filter(Player.id == player_id).first()
    db_player.name = f"rnd-{uuid.uuid4()}"
    db.add(db_player)
    db.commit()
def delete_product(db: Session, product: schemas.Product):
    """Remove *product* from the database and confirm with a message."""
    db.delete(product)
    db.commit()
    return "Successfully deleted"
def create_travels(db: Session, travels: List[schemas.TravelCreate]):
    """Insert one Travel row per entry in *travels* and return the new models."""
    new_rows = []
    for travel in travels:
        new_rows.append(models.Travel(**travel.dict()))
    db.add_all(new_rows)
    db.commit()
    return new_rows
class Librarian:
    """ Librarian Class (Databasing Utilities for Miners) """

    # TODO: Place connection string in environments file
    PSQL_CONN = "postgresql+psycopg2://test:test@localhost:5432/cryptkeeper_raw"

    def __init__(self, schema):
        # one engine/session pair per Librarian; every method reuses it
        self.ENGINE = create_engine(self.PSQL_CONN)
        self.SESSION = Session(bind=self.ENGINE)
        self.SCHEMA = schema

    # Public Methods
    def bulkInsert(self, items):
        """ Inserts a list of items into the table of supplied schema,
        implementing `bulk_insert_mappings`. Unlike other methods,
        `bulkInsert` takes a list of dictionaries and performs the
        mapping to the schema as opposed to taking items that have
        already been mapped according to the schema thereunto
        pertaining. """
        try:
            self.SESSION.bulk_insert_mappings(self.SCHEMA, items)
            self.SESSION.commit()
        except SQLAlchemyError as err:
            self.SESSION.rollback()
            print("Error with bulk insert: %s" % (err))

    def bulkInsertDoNothingOnConflict(self, items, unique_fields):
        """ Either inserts items into the table of supplied schema or
        does nothing. Requires items to be mapped prior to being passed
        to this method. """
        try:
            statement = postgresql \
                .insert(self.SCHEMA.__table__) \
                .values(items) \
                .on_conflict_do_nothing(index_elements = unique_fields)
            self.SESSION.execute(statement)
            self.SESSION.commit()
        except SQLAlchemyError as err:
            self.SESSION.rollback()
            print("Error with bulk insert: %s" % (err))

    def bulkUpsert(self, items, unique_fields):
        """ Either inserts items into the table of supplied schema or
        updates conflicting field. Requires items to be mapped prior to
        being passed to this method. """
        try:
            # one statement per item so each row carries its own SET clause
            for item in items:
                statement = postgresql \
                    .insert(self.SCHEMA.__table__) \
                    .values(item) \
                    .on_conflict_do_update(
                        index_elements = unique_fields,
                        set_ = item
                    )
                self.SESSION.execute(statement)
            self.SESSION.commit()
        except SQLAlchemyError as err:
            self.SESSION.rollback()
            print("Error with bulk upsert: %s" % (err))

    def getLastDaysEntries(self):
        """ Select all entries added in the last day """
        one_day_ago = datetime.now() - timedelta(days=1)
        try:
            # NOTE: and_() with a single clause is redundant but harmless
            return self.SESSION.query(self.SCHEMA).filter(
                and_(self.SCHEMA.created >= one_day_ago)).all()
        except SQLAlchemyError as err:
            print("Error with select query: %s" % (err))
def update(self,
           db: Session, *,
           model_id: int = None,
           obj: ModelType = None,
           query: Dict[str, Any] = None,
           data: Union[UpdateSchemaType, Dict[str, Any]],
           is_return_obj: bool = False,
           is_transaction: bool = False) -> ModelType:
    """Update a single object.

    Two update modes are supported:
      1. pass the model object to update together with the new data;
      2. pass the model id (or a filter-by query) together with the new data.
    Note: if a model object is passed, ``is_return_obj=False`` is ignored
    and the updated object is returned.

    :param db: database session
    :param model_id: model primary-key id
    :param obj: model object
    :param query: filter-by keyword arguments locating the model
    :param data: the data to apply
    :param is_return_obj: whether to return the model object; defaults to
        False, in which case only the number of updated rows is returned
    :param is_transaction: when True the caller manages the transaction,
        so no commit is issued here
    :return: update_count or obj or None
    """
    if not any((model_id, obj, query)):
        raise ValueError('At least one of [model_id、query、obj] exists')
    if isinstance(data, dict):
        update_data = data
    else:
        # pydantic schema: apply only fields the caller explicitly set
        update_data = data.dict(exclude_unset=True)
    if not is_return_obj and not obj:
        # bulk-style UPDATE that returns the matched row count
        if model_id:
            update_count = db.query(self.model).filter(
                self.model.id == model_id).update(update_data)
        else:
            update_count = db.query(
                self.model).filter_by(**query).update(update_data)
        if not is_transaction:
            db.commit()
        return update_count
    else:
        if not obj:
            if model_id:
                obj = self.get(db, model_id)
            else:
                obj = self.get_one(db, **query)
        if obj:
            obj_data = jsonable_encoder(obj)
            # copy only the attributes that exist on the model
            for field in obj_data:
                if field in update_data:
                    setattr(obj, field, update_data[field])
            db.add(obj)
            if not is_transaction:
                db.commit()
                db.refresh(obj)
        return obj
def create_my_story(db: Session, my_story: schemas.MyStoryCreate):
    """Insert a new MyStory row built from *my_story* and return it."""
    new_story = models.MyStory(**my_story.dict())
    db.add(new_story)
    db.commit()
    return new_story
def InsertWarehouseEmail(self, binaryBody):
    """Store *binaryBody* as a new WarehouseEmailSource row.

    The session is now closed even if add/commit raises; the original
    leaked the connection on any failure.
    """
    session = Session(self.Engine)
    try:
        entity = self.WarehouseEmailSource(Content=binaryBody)
        session.add(entity)
        session.commit()
    finally:
        session.close()
def update_latest_my_story(db: Session, story: schemas.Story, my_story):
    """Point *story* at its newest MyStory entry and persist the change."""
    story.latest_my_story = my_story
    db.add(story)
    db.commit()