Ejemplo n.º 1
0
    def test_multiple_key_replaced_by_update(self):
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        u1 = User(name='u1')
        u2 = User(name='u2')
        u3 = User(name='u3')

        s = Session()
        s.add_all([u1, u2, u3])
        s.commit()

        s.delete(u1)
        s.delete(u2)
        s.flush()

        u3.name = 'u1'
        s.flush()

        u3.name = 'u2'
        s.flush()

        s.rollback()

        assert u1 in s
        assert u2 in s
        assert u3 in s

        assert s.identity_map[(User, ('u1',))] is u1
        assert s.identity_map[(User, ('u2',))] is u2
        assert s.identity_map[(User, ('u3',))] is u3
Ejemplo n.º 2
0
    def test_set_composite_attrs_via_selectable(self):
        Values, CustomValues, values, Descriptions, descriptions = (self.classes.Values,
                                self.classes.CustomValues,
                                self.tables.values,
                                self.classes.Descriptions,
                                self.tables.descriptions)

        session = Session()
        d = Descriptions(
            custom_descriptions = CustomValues('Color', 'Number'),
            values =[
                Values(custom_values = CustomValues('Red', '5')),
                Values(custom_values=CustomValues('Blue', '1'))
            ]
        )

        session.add(d)
        session.commit()
        eq_(
            testing.db.execute(descriptions.select()).fetchall(),
            [(1, 'Color', 'Number')]
        )
        eq_(
            testing.db.execute(values.select()).fetchall(),
            [(1, 1, 'Red', '5'), (2, 1, 'Blue', '1')]
        )
Ejemplo n.º 3
0
    def test_is_modified_passive_on(self):
        User, Address = self._default_mapping_fixture()

        s = Session()
        u = User(name='fred', addresses=[Address(email_address='foo')])
        s.add(u)
        s.commit()

        u.id
        def go():
            assert not s.is_modified(u, passive=True)
        self.assert_sql_count(
            testing.db,
            go,
            0
        )

        u.name = 'newname'
        def go():
            assert s.is_modified(u, passive=True)
        self.assert_sql_count(
            testing.db,
            go,
            0
        )
Ejemplo n.º 4
0
    def test_continue_flushing_on_commit(self):
        """test that post-flush actions get flushed also if
        we're in commit()"""
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()

        to_flush = [User(name='ed'), User(name='jack'), User(name='wendy')]

        @event.listens_for(sess, "after_flush_postexec")
        def add_another_user(session, ctx):
            if to_flush:
                session.add(to_flush.pop(0))

        x = [1]

        @event.listens_for(sess, "after_commit")  # noqa
        def add_another_user(session):
            x[0] += 1

        sess.add(to_flush.pop())
        sess.commit()
        eq_(x, [2])
        eq_(
            sess.scalar(select([func.count(users.c.id)])), 3
        )
Ejemplo n.º 5
0
    def test_explicit_expunge_deleted(self):
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        sess.add(User(name='x'))
        sess.commit()

        u1 = sess.query(User).first()
        sess.delete(u1)

        sess.flush()

        assert was_deleted(u1)
        assert u1 not in sess
        assert object_session(u1) is sess

        sess.expunge(u1)
        assert was_deleted(u1)
        assert u1 not in sess
        assert object_session(u1) is None

        sess.rollback()
        assert was_deleted(u1)
        assert u1 not in sess
        assert object_session(u1) is None
Ejemplo n.º 6
0
    def test_auto_detach_on_gc_session(self):
        users, User = self.tables.users, self.classes.User

        mapper(User, users)

        sess = Session()

        u1 = User(name='u1')
        sess.add(u1)
        sess.commit()

        # can't add u1 to Session,
        # already belongs to u2
        s2 = Session()
        assert_raises_message(
            sa.exc.InvalidRequestError,
            r".*is already attached to session",
            s2.add, u1
        )

        # garbage collect sess
        del sess
        gc_collect()

        # s2 lets it in now despite u1 having
        # session_key
        s2.add(u1)
        assert u1 in s2
Ejemplo n.º 7
0
    def test_is_modified_passive_off(self):
        """as of 0.8 no SQL is emitted for is_modified()
        regardless of the passive flag"""

        User, Address = self._default_mapping_fixture()

        s = Session()
        u = User(name='fred', addresses=[
                    Address(email_address='foo')])
        s.add(u)
        s.commit()

        u.id
        def go():
            assert not s.is_modified(u)
        self.assert_sql_count(
            testing.db,
            go,
            0
        )

        s.expire_all()
        u.name = 'newname'

        # can't predict result here
        # deterministically, depending on if
        # 'name' or 'addresses' is tested first
        mod = s.is_modified(u)
        addresses_loaded = 'addresses' in u.__dict__
        assert mod is not addresses_loaded
def upgrade():
    session = Session(bind=op.get_bind())

    for i, item in enumerate(session.query(Salad).order_by(Salad.position)):
        item.position = i

    session.commit()
Ejemplo n.º 9
0
def create_scheduled_text(text_data):
	engine = create_engine(settings['ENGINE_STRING'])
	session = Session(engine)

	try:
		st = ScheduledText()

		validation_code = randint(100000,999999)
		
		st.phone_number = text_data["phone_number"]
		st.route_number = text_data["route_number"]
		st.stop_id =	  text_data["stop_id"]
		st.hour =		  int(text_data["hour"])
		st.minute =	   	  int(text_data["minute"])
		if text_data["ampm"] == "pm":
			st.hour += 12

		st.max_minutes =  text_data["max_minutes"]
		st.validation_code = validation_code

		session.add(st)
		session.commit()

	except:
		session.rollback()
		raise

	return st
Ejemplo n.º 10
0
 def _two_obj_fixture(self):
     e1 = Engineer(name='wally')
     e2 = Engineer(name='dilbert', reports_to=e1)
     sess = Session()
     sess.add_all([e1, e2])
     sess.commit()
     return sess
Ejemplo n.º 11
0
    def test_map_to_select(self):
        Base, Child = self.classes.Base, self.classes.Child
        base, child = self.tables.base, self.tables.child

        base_select = select([base]).alias()
        mapper(
            Base,
            base_select,
            polymorphic_on=base_select.c.type,
            polymorphic_identity="base",
        )
        mapper(Child, child, inherits=Base, polymorphic_identity="child")

        sess = Session()

        # 2. use an id other than "1" here so can't rely on
        # the two inserts having the same id
        c1 = Child(id=12, name="c1")
        sess.add(c1)

        sess.commit()
        sess.close()

        c1 = sess.query(Child).one()
        eq_(c1.name, "c1")
Ejemplo n.º 12
0
def update_song_id3(engine, song_name, new_name='', new_artist='',
                    new_album=''):
    session = Session(bind=engine)
    song = session.query(Song).filter(Song.name == song_name)
    try:
        song_path = song.one().path
        pass
    except Exception:
        song_path = './song.mp3'

    updated = {}

    if new_name:
        updated['name'] = new_name
    if new_artist:
        updated['artist'] = new_artist
    if new_album:
        updated['album'] = new_album
    song.update(updated)
    session.commit()

    original_song = EasyID3(song_path)
    for key, value in updated.items():
        if key == 'name':
            key = 'title'
        original_song[key] = value
    original_song.save()
Ejemplo n.º 13
0
    def test_cast_type(self):
        Json = self.classes.Json
        s = Session(testing.db)

        j = Json(json={'field': 10})
        s.add(j)
        s.commit()

        jq = s.query(Json).filter(Json.int_field == 10).one()
        eq_(j.id, jq.id)

        jq = s.query(Json).filter(Json.text_field == '10').one()
        eq_(j.id, jq.id)

        jq = s.query(Json).filter(Json.json_field.astext == '10').one()
        eq_(j.id, jq.id)

        jq = s.query(Json).filter(Json.text_field == 'wrong').first()
        is_(jq, None)

        j.json = {'field': True}
        s.commit()

        jq = s.query(Json).filter(Json.text_field == 'true').one()
        eq_(j.id, jq.id)
Ejemplo n.º 14
0
    def test_one_to_many_on_m2o(self):
        Node, nodes = self.classes.Node, self.tables.nodes

        mapper(Node, nodes, properties={
            'children': relationship(Node,
                                 backref=sa.orm.backref('parentnode',
                                            remote_side=nodes.c.name,
                                            passive_updates=False),
                                 )})

        sess = Session()
        n1 = Node(name='n1')
        sess.add(n1)
        n2 = Node(name='n11', parentnode=n1)
        n3 = Node(name='n12', parentnode=n1)
        n4 = Node(name='n13', parentnode=n1)
        sess.add_all([n2, n3, n4])
        sess.commit()

        n1.name = 'new n1'
        sess.commit()
        eq_(['new n1', 'new n1', 'new n1'],
            [n.parent
             for n in sess.query(Node).filter(
                 Node.name.in_(['n11', 'n12', 'n13']))])
Ejemplo n.º 15
0
    def test_instance_lazy_relation_loaders(self):
        users, addresses = (self.tables.users,
                                self.tables.addresses)

        mapper(User, users, properties={
            'addresses': relationship(Address, lazy='noload')
        })
        mapper(Address, addresses)

        sess = Session()
        u1 = User(name='ed', addresses=[
                        Address(
                            email_address='*****@*****.**',
                        )
                ])

        sess.add(u1)
        sess.commit()
        sess.close()

        u1 = sess.query(User).options(
                                lazyload(User.addresses)
                            ).first()
        u2 = pickle.loads(pickle.dumps(u1))

        sess = Session()
        sess.add(u2)
        assert u2.addresses
Ejemplo n.º 16
0
    def test_collection(self):
        users, addresses, Address = (self.tables.users,
                                self.tables.addresses,
                                self.classes.Address)

        canary = Mock()
        class User(fixtures.ComparableEntity):
            @validates('addresses')
            def validate_address(self, key, ad):
                canary(key, ad)
                assert '@' in ad.email_address
                return ad

        mapper(User, users, properties={
                'addresses': relationship(Address)}
        )
        mapper(Address, addresses)
        sess = Session()
        u1 = User(name='edward')
        a0 = Address(email_address='noemail')
        assert_raises(AssertionError, u1.addresses.append, a0)
        a1 = Address(id=15, email_address='*****@*****.**')
        u1.addresses.append(a1)
        eq_(canary.mock_calls, [call('addresses', a0), call('addresses', a1)])
        sess.add(u1)
        sess.commit()

        eq_(
            sess.query(User).filter_by(name='edward').one(),
            User(name='edward', addresses=[Address(email_address='*****@*****.**')])
        )
Ejemplo n.º 17
0
    def test_09_pickle(self):
        users = self.tables.users
        mapper(User, users)
        sess = Session()
        sess.add(User(id=1, name='ed'))
        sess.commit()
        sess.close()

        inst = User(id=1, name='ed')
        del inst._sa_instance_state

        state = sa_state.InstanceState.__new__(sa_state.InstanceState)
        state_09 = {
            'class_': User,
            'modified': False,
            'committed_state': {},
            'instance': inst,
            'callables': {'name': state, 'id': state},
            'key': (User, (1,)),
            'expired': True}
        manager = instrumentation._SerializeManager.__new__(
            instrumentation._SerializeManager)
        manager.class_ = User
        state_09['manager'] = manager
        state.__setstate__(state_09)
        eq_(state.expired_attributes, {'name', 'id'})

        sess = Session()
        sess.add(inst)
        eq_(inst.name, 'ed')
        # test identity_token expansion
        eq_(sa.inspect(inst).key, (User, (1, ), None))
Ejemplo n.º 18
0
    def test_11_pickle(self):
        users = self.tables.users
        mapper(User, users)
        sess = Session()
        u1 = User(id=1, name='ed')
        sess.add(u1)
        sess.commit()

        sess.close()

        manager = instrumentation._SerializeManager.__new__(
            instrumentation._SerializeManager)
        manager.class_ = User

        state_11 = {

            'class_': User,
            'modified': False,
            'committed_state': {},
            'instance': u1,
            'manager': manager,
            'key': (User, (1,)),
            'expired_attributes': set(),
            'expired': True}

        state = sa_state.InstanceState.__new__(sa_state.InstanceState)
        state.__setstate__(state_11)

        eq_(state.identity_token, None)
        eq_(state.identity_key, (User, (1,), None))
Ejemplo n.º 19
0
 def _persist(self, json_str):
     logger = logging.getLogger(__name__)
     if json_str is None:
         return
     
     engine = create_engine(DB_CONN)
     Session = sessionmaker(bind=engine)
     session = Session()
     
     json_obj = json.loads(json_str)
     
     # dummy id's if none provided. TODO FIX AFTER PROTOTYPE
     # getting an element from json object returns a 'list', not a single element
     acct_id = json_obj['accountId']
     if acct_id is None or not self.represents_int(acct_id[0]):
         logger.debug("PixelEventLogger._persist: account Id provided was invalid. Seeting to default value 0")
         account_id = 0
     else:
         account_id = acct_id[0]
     cust_id = json_obj['customerId']
     if cust_id is None or not self.represents_int(cust_id[0]):
         logger.debug("PixelEventLogger._persist: customer Id provided was invalid. Seeting to default value 0")            
         customer_id = 0
     else:
         customer_id = cust_id[0]
 
     logger.debug("PixelEventLogger._persist: Saving pixel event..." + json_str)
     pixevent = PixelEvent(account_id=account_id, customer_id=customer_id, doc=json_obj)
     
     session.add(pixevent)  
     session.commit()
Ejemplo n.º 20
0
    def test_warning_on_using_inactive_session_rollback_evt(self):
        users, User = self.tables.users, self.classes.User

        mapper(User, users)
        sess = Session()
        u1 = User(id=1, name='u1')
        sess.add(u1)
        sess.commit()

        u3 = User(name='u3')

        @event.listens_for(sess, "after_rollback")
        def evt(s):
            sess.add(u3)

        sess.add(User(id=1, name='u2'))

        def go():
            assert_raises(
                orm_exc.FlushError, sess.flush
            )

        assert_warnings(go,
                        ["Session's state has been changed on a "
                         "non-active transaction - this state "
                         "will be discarded."],
                        )
        assert u3 not in sess
Ejemplo n.º 21
0
    def test_dirty_state_transferred_deep_nesting(self):
        User, users = self.classes.User, self.tables.users

        mapper(User, users)

        s = Session(testing.db)
        u1 = User(name='u1')
        s.add(u1)
        s.commit()

        nt1 = s.begin_nested()
        nt2 = s.begin_nested()
        u1.name = 'u2'
        assert attributes.instance_state(u1) not in nt2._dirty
        assert attributes.instance_state(u1) not in nt1._dirty
        s.flush()
        assert attributes.instance_state(u1) in nt2._dirty
        assert attributes.instance_state(u1) not in nt1._dirty

        s.commit()
        assert attributes.instance_state(u1) in nt2._dirty
        assert attributes.instance_state(u1) in nt1._dirty

        s.rollback()
        assert attributes.instance_state(u1).expired
        eq_(u1.name, 'u1')
def test_unlock_scenario_structure(fresh_database_config, scenario_manager):
    """Test that unlock method removes records from the table for locks.

    1. Create 3 scenarios with scenario manager.
    2. Create records in the ScenarioStructureLock table for the first and third scenarios.
    3. Unlock the first scenario.
    4. Check that the ScenarioStructureLock table contains only one record for the third scenario.
    """
    scenarios = scenario_manager.create_scenarios([
        NewScenarioSpec(name=u'First'),
        NewScenarioSpec(name=u'Second'),
        NewScenarioSpec(name=u'Third'),
        ])

    session = Session(bind=create_engine(fresh_database_config.get_connection_string()))
    session.add_all([
        ScenarioStructureLockRecord(scenario_id=scenarios[0].scenario_id),
        ScenarioStructureLockRecord(scenario_id=scenarios[2].scenario_id),
        ])
    session.commit()

    scenarios[0]._unlock_structure()

    session = Session(bind=create_engine(fresh_database_config.get_connection_string()))
    lock_record = session.query(ScenarioStructureLockRecord).one()
    assert lock_record.scenario_id == scenarios[2].scenario_id, "Wrong scenario has been unlocked"
    def setUpClass(cls) -> None:
        """
        Add a service to the DB
        """
        AcceptanceTestCase.setUpClass()
        cls.service_name = 'Testing Service'
        cls.description = 'Description for the Testing Service'

        cls.job_registration_schema = JSONSchema(
            title='Job Registration Schema',
            description='Must be fulfilled for an experiment'
        ).dump(cls.JobRegistrationSchema())

        cls.job_result_schema = JSONSchema(
            title='Job Result Schema',
            description='Must be fulfilled to post results'
        ).dump(cls.JobRegistrationSchema())

        session = Session(bind=APP_FACTORY.engine, expire_on_commit=False)

        service_list = ServiceList(session)
        cls.service = service_list.new(
            cls.service_name, cls.description, cls.job_registration_schema,
            cls.job_result_schema
        )
        session.commit()
Ejemplo n.º 24
0
    def test_noload_append(self):
        # test that a load of User.addresses is not emitted
        # when flushing an append
        User, Address = self._user_address_fixture()

        sess = Session()
        u1 = User(name="jack", addresses=[Address(email_address="a1")])
        sess.add(u1)
        sess.commit()

        u1_id = u1.id
        sess.expire_all()

        u1.addresses.append(Address(email_address='a2'))

        self.assert_sql_execution(
            testing.db,
            sess.flush,
            CompiledSQL(
                "SELECT users.id AS users_id, users.name AS users_name "
                "FROM users WHERE users.id = :param_1",
                lambda ctx: [{"param_1": u1_id}]),
            CompiledSQL(
                "INSERT INTO addresses (user_id, email_address) "
                "VALUES (:user_id, :email_address)",
                lambda ctx: [{'email_address': 'a2', 'user_id': u1_id}]
            )
        )
Ejemplo n.º 25
0
    def test_11_pickle(self):
        users = self.tables.users
        mapper(User, users)
        sess = Session()
        u1 = User(id=1, name="ed")
        sess.add(u1)
        sess.commit()

        sess.close()

        manager = instrumentation._SerializeManager.__new__(
            instrumentation._SerializeManager
        )
        manager.class_ = User

        state_11 = {
            "class_": User,
            "modified": False,
            "committed_state": {},
            "instance": u1,
            "manager": manager,
            "key": (User, (1,)),
            "expired_attributes": set(),
            "expired": True,
        }

        state = sa_state.InstanceState.__new__(sa_state.InstanceState)
        state.__setstate__(state_11)

        eq_(state.identity_token, None)
        eq_(state.identity_key, (User, (1,), None))
Ejemplo n.º 26
0
    def test_09_pickle(self):
        users = self.tables.users
        mapper(User, users)
        sess = Session()
        sess.add(User(id=1, name="ed"))
        sess.commit()
        sess.close()

        inst = User(id=1, name="ed")
        del inst._sa_instance_state

        state = sa_state.InstanceState.__new__(sa_state.InstanceState)
        state_09 = {
            "class_": User,
            "modified": False,
            "committed_state": {},
            "instance": inst,
            "callables": {"name": state, "id": state},
            "key": (User, (1,)),
            "expired": True,
        }
        manager = instrumentation._SerializeManager.__new__(
            instrumentation._SerializeManager
        )
        manager.class_ = User
        state_09["manager"] = manager
        state.__setstate__(state_09)
        eq_(state.expired_attributes, {"name", "id"})

        sess = Session()
        sess.add(inst)
        eq_(inst.name, "ed")
        # test identity_token expansion
        eq_(sa.inspect(inst).key, (User, (1,), None))
def test_state_creation_by_scenario(engine_with_scenario):
    """Check that scenario view state could be created by sqlalchemy.

    1. Create ScenarioViewState table by sqlalchemy means.
    2. Add record.
    3. Check that new record could be addressed.
    """
    session = Session(bind=engine_with_scenario)

    now = datetime.utcnow()
    scenario = Scenario(created=now)

    name = u"New name"
    description = u"Some description"
    changed = datetime.utcnow()

    state = ScenarioViewState(
        scenario=scenario, name=name, description=description, changed=changed)

    session.add_all((scenario, state))
    session.commit()

    assert state.state_id is not None, "New ScenarioViewState does not obtain ID"
    assert scenario.view_states.one().state_id == state.state_id, (
        "Wrong ID of scenario view state"
        )

    check_engine = create_engine(engine_with_scenario.url)
    check_session = Session(bind=check_engine)
    obtained_record = check_session.query(ScenarioViewState).first()
    assert obtained_record.state_id == state.state_id, "Wrong state ID"
    assert obtained_record.scenario_id == scenario.scenario_id, "Wrong scenario ID"
    assert obtained_record.name == name, "Wrong name value"
    assert obtained_record.description == description, "Wrong description value"
    assert obtained_record.changed == changed, "wrong changed value"
Ejemplo n.º 28
0
    def test_scalar(self):
        users = self.tables.users
        canary = Mock()

        class User(fixtures.ComparableEntity):
            @validates('name')
            def validate_name(self, key, name):
                canary(key, name)
                ne_(name, 'fred')
                return name + ' modified'

        mapper(User, users)
        sess = Session()
        u1 = User(name='ed')
        eq_(u1.name, 'ed modified')
        assert_raises(AssertionError, setattr, u1, "name", "fred")
        eq_(u1.name, 'ed modified')
        eq_(canary.mock_calls, [call('name', 'ed'), call('name', 'fred')])

        sess.add(u1)
        sess.commit()

        eq_(
            sess.query(User).filter_by(name='ed modified').one(),
            User(name='ed')
        )
Ejemplo n.º 29
0
    def test_one_to_many_on_o2m(self):
        Node, nodes = self.classes.Node, self.tables.nodes

        mapper(Node, nodes, properties={
            'children': relationship(Node,
                                 backref=sa.orm.backref('parentnode',
                                            remote_side=nodes.c.name),
                                passive_updates=False
                                 )})

        sess = Session()
        n1 = Node(name='n1')
        n1.children.append(Node(name='n11'))
        n1.children.append(Node(name='n12'))
        n1.children.append(Node(name='n13'))
        sess.add(n1)
        sess.commit()

        n1.name = 'new n1'
        sess.commit()
        eq_(n1.children[1].parent, 'new n1')
        eq_(['new n1', 'new n1', 'new n1'],
            [n.parent
             for n in sess.query(Node).filter(
                 Node.name.in_(['n11', 'n12', 'n13']))])
Ejemplo n.º 30
0
    def test_concurrent_mod_err_noexpire_on_commit(self):
        sess = self._fixture(expire_on_commit=False)

        f1 = self.classes.Foo(value='f1')
        sess.add(f1)
        sess.commit()

        # here, we're not expired overall, so no load occurs and we
        # stay without a version id, unless we've emitted
        # a SELECT for it within the flush.
        f1.value

        s2 = Session(expire_on_commit=False)
        f2 = s2.query(self.classes.Foo).first()
        f2.value = 'f2'
        s2.commit()

        f1.value = 'f3'

        assert_raises_message(
            orm.exc.StaleDataError,
            r"UPDATE statement on table 'version_table' expected to "
            r"update 1 row\(s\); 0 were matched.",
            sess.commit
        )
Ejemplo n.º 31
0
    def handle_cursor(  # pylint: disable=too-many-locals
        cls, cursor: Any, query: Query, session: Session
    ) -> None:
        """Updates progress information"""
        from pyhive import hive  # pylint: disable=no-name-in-module

        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        query_id = query.id
        while polled.operationState in unfinished_states:
            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ""
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logger.info(
                    "Query %s: Progress total: %s", str(query_id), str(progress)
                )
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split("/")[-2]
                        logger.info(
                            "Query %s: Found the tracking url: %s",
                            str(query_id),
                            tracking_url,
                        )
                        tracking_url = tracking_url_trans(tracking_url)
                        logger.info(
                            "Query %s: Transformation applied: %s",
                            str(query_id),
                            tracking_url,
                        )
                        query.tracking_url = tracking_url
                        logger.info("Query %s: Job id: %s", str(query_id), str(job_id))
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for job id before logging things out
                    # this allows for prefixing all log lines and becoming
                    # searchable in something like Kibana
                    for l in log_lines[last_log_line:]:
                        logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l)
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(hive_poll_interval)
            polled = cursor.poll()
Ejemplo n.º 32
0
def persist(session: orm.Session,
            distributions: List[BaseDistribution]) -> None:
    if len(distributions) == 0:
        return
    session.bulk_save_objects(distributions)
    session.commit()
Ejemplo n.º 33
0
def migrate_watch_assoc():
    """
    Migrates watch targets to watch assocs
    :return:
    """
    # Data migration - online mode only
    if context.is_offline_mode():
        logger.warning('Data migration skipped in the offline mode')
        return

    def strip(x):
        if x is None:
            return None
        return x.strip()

    def target_key(t):
        scheme, host, port = t.scan_scheme, t.scan_host, t.scan_port
        if scheme is None:
            scheme = 'https'
        if port is not None:
            port = int(port)
        if port is None:
            port = 443
        if scheme == 'http':
            scheme = 'https'
        if scheme == 'htttp':
            scheme = 'https'
        if port == 80 or port <= 10 or port >= 65535:
            port = 443
        host = strip(host)
        if host is not None:
            if host.startswith('*.'):
                host = host[2:]
            if host.startswith('%.'):
                host = host[2:]
        return scheme, host, port

    target_db = {}
    already_assoc = set()
    duplicates = []

    bind = op.get_bind()
    sess = BaseSession(bind=bind)
    it = sess.query(DbWatchTarget).yield_per(1000)
    for rec in it:
        ck = target_key(rec)
        rec_assoc = rec

        if ck in target_db:
            rec_assoc = target_db[ck]
            duplicates.append(rec.id)
        else:
            target_db[ck] = rec
            rec.scan_scheme = ck[0]
            rec.scan_host = ck[1]
            rec.scan_port = ck[2]

        if rec.user_id is None:
            continue

        cur_assoc_key = rec_assoc.id, rec.user_id
        if cur_assoc_key in already_assoc:
            print('already assoc: %s' % (cur_assoc_key, ))
            continue
        already_assoc.add(cur_assoc_key)

        assoc = DbWatchAssoc()
        assoc.scan_type = 1
        assoc.created_at = rec_assoc.created_at
        assoc.updated_at = rec_assoc.updated_at
        assoc.scan_periodicity = rec_assoc.scan_periodicity
        assoc.user_id = rec.user_id  # actual record!
        assoc.watch_id = rec_assoc.id
        sess.add(assoc)
    sess.commit()

    # remove duplicates
    if len(duplicates) > 0:
        sess.query(DbWatchTarget).filter(DbWatchTarget.id.in_(list(duplicates))) \
            .delete(synchronize_session='fetch')
        sess.commit()
        print('Removed %s duplicates %s' % (len(duplicates), duplicates))
Ejemplo n.º 34
0
 def remove(self, db_session: Session, *, id: int) -> ModelType:
     obj = db_session.query(self.model).get(id)
     db_session.delete(obj)
     db_session.commit()
     return obj
Ejemplo n.º 35
0
def peer_ip_address_set(sess: Session, peer: schemas.WGPeer) -> schemas.WGPeer:
    db_peer: models.WGPeer = peer_query_get_by_address(sess, peer.address, peer.server).one()
    db_peer.address = peer.address
    sess.add(db_peer)
    sess.commit()
    return peer.from_orm(db_peer)
Ejemplo n.º 36
0
def create_user_item(db: Session, item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.dict(), owner_id=user_id)
    db.add(db_item)
    db.commit()
    db.refresh(db_item)
    return db_item
Ejemplo n.º 37
0
            engineer_name="engineer1",
            primary_language="java",
        ),
        Person(name="joesmith"),
        Engineer(
            name="wally",
            status="CGG",
            engineer_name="engineer2",
            primary_language="python",
        ),
        Manager(name="jsmith", status="ABA", manager_name="manager2"),
    ],
)
session.add(c)

session.commit()

c = session.query(Company).get(1)
for e in c.employees:
    print(e, inspect(e).key, e.company)
assert set([e.name for e in c.employees]) == set(
    ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"])
print("\n")

dilbert = session.query(Person).filter_by(name="dilbert").one()
dilbert2 = session.query(Engineer).filter_by(name="dilbert").one()
assert dilbert is dilbert2

dilbert.engineer_name = "hes dilbert!"

session.commit()
Ejemplo n.º 38
0
def delete_player(db: Session, player_id):
    db_player = db.query(Player).filter(Player.id == player_id).first()
    db_player.name = f"rnd-{uuid.uuid4()}"
    db.add(db_player)
    db.commit()
Ejemplo n.º 39
0
def delete_product(db : Session,product : schemas.Product):
    db.delete(product)
    db.commit()
    return "Successfully deleted"
Ejemplo n.º 40
0
def create_travels(db: Session, travels: List[schemas.TravelCreate]):
    db_travels = [models.Travel(**travel.dict()) for travel in travels]
    db.add_all(db_travels)
    db.commit()
    return db_travels
Ejemplo n.º 41
0
class Librarian:
    """ Librarian Class (Databasing Utilities for Miners) """

    # TODO: Place connection string in environments file
    PSQL_CONN = "postgresql+psycopg2://test:test@localhost:5432/cryptkeeper_raw"

    def __init__(self, schema):
        self.ENGINE = create_engine(self.PSQL_CONN)
        self.SESSION = Session(bind=self.ENGINE)
        self.SCHEMA = schema

    # Public Methods
    def bulkInsert(self, items):
        """
    Inserts a list of items into the table of supplied schema, implementing
    `bulk_insert_mappings`. Unlike other methods, `bulkInsert` akes a list
    of dictionaries and performs the mapping to the schema as opposed to
    taking items that have already been mapped according to the schema
    thereunto pertaining.
    """
        try:
            self.SESSION.bulk_insert_mappings(self.SCHEMA, items)
            self.SESSION.commit()

        except SQLAlchemyError as err:
            self.SESSION.rollback()
            print("Error with bulk insert: %s" % (err))

    def bulkInsertDoNothingOnConflict(self, items, unique_fields):
        """
    Either inserts items into the table of supplied schema or does nothing.
    Requires items to be mapped prior to being passed to this method.
    """
        try:
            statement = postgresql \
              .insert(self.SCHEMA.__table__) \
              .values(items) \
              .on_conflict_do_nothing(index_elements = unique_fields)

            self.SESSION.execute(statement)
            self.SESSION.commit()

        except SQLAlchemyError as err:
            self.SESSION.rollback()
            print("Error with bulk insert: %s" % (err))

    def bulkUpsert(self, items, unique_fields):
        """
    Either inserts items into the table of supplied schema or updates
    conflicting field. Requires items to be mapped prior to being passed to
    this method.
    """
        try:
            for item in items:
                statement = postgresql \
                  .insert(self.SCHEMA.__table__) \
                  .values(item) \
                  .on_conflict_do_update(
                    index_elements = unique_fields,
                    set_ = item
                  )

                self.SESSION.execute(statement)

            self.SESSION.commit()

        except SQLAlchemyError as err:
            self.SESSION.rollback()
            print("Error with bulk upsert: %s" % (err))

    def getLastDaysEntries(self):
        """ Select all entries added in the last day """
        one_day_ago = datetime.now() - timedelta(days=1)

        try:
            return self.SESSION.query(self.SCHEMA).filter(
                and_(self.SCHEMA.created >= one_day_ago)).all()

        except SQLAlchemyError as err:
            print("Error with select query: %s" % (err))
Ejemplo n.º 42
0
    def update(self,
               db: Session,
               *,
               model_id: int = None,
               obj: ModelType = None,
               query: Dict[str, Any] = None,
               data: Union[UpdateSchemaType, Dict[str, Any]],
               is_return_obj: bool = False,
               is_transaction: bool = False) -> ModelType:
        """
        单个对象更新

        更新有两种方式:
        方式一:
                传入需要更新的模型和更新的数据。
        方式二:
                传入id和更新的数据
        注意:
            如果传入模型来进行更新,则'is_return_obj=False'失效,返回更新后的模型

        :param db:
        :param model_id:    模型ID
        :param obj:         模型对象
        :param query:       模型查询参数
        :param data:        需要更新的数据
        :param is_return_obj:   是否需要返回模型数据,默认为False,只返回更新成功的行数
        :param is_transaction:  是否开启事务功能

        :return: update_count or obj or None
        """
        if not any((model_id, obj, query)):
            raise ValueError('At least one of [model_id、query、obj] exists')

        if isinstance(data, dict):
            update_data = data
        else:
            update_data = data.dict(exclude_unset=True)

        if not is_return_obj and not obj:
            if model_id:
                update_count = db.query(self.model).filter(
                    self.model.id == model_id).update(update_data)
            else:
                update_count = db.query(
                    self.model).filter_by(**query).update(update_data)

            if not is_transaction:
                db.commit()
            return update_count
        else:
            if not obj:
                if model_id:
                    obj = self.get(db, model_id)
                else:
                    obj = self.get_one(db, **query)
            if obj:
                obj_data = jsonable_encoder(obj)

                for field in obj_data:
                    if field in update_data:
                        setattr(obj, field, update_data[field])
                db.add(obj)
                if not is_transaction:
                    db.commit()
                    db.refresh(obj)
                return obj
Ejemplo n.º 43
0
def create_my_story(db: Session, my_story: schemas.MyStoryCreate):
    db_my_story = models.MyStory(**my_story.dict())
    db.add(db_my_story)
    db.commit()

    return db_my_story
Ejemplo n.º 44
0
 def InsertWarehouseEmail(self, binaryBody):
     session = Session(self.Engine)
     entity = self.WarehouseEmailSource(Content=binaryBody)
     session.add(entity)
     session.commit()
     session.close()
Ejemplo n.º 45
0
def update_latest_my_story(db: Session, story: schemas.Story, my_story):
    story.latest_my_story = my_story
    db.add(story)
    db.commit()