Example #1
    schema.Column('phoneticTranscription', types.Unicode(255)),
    # narrow phonetic transcription -- optional
    schema.Column('narrowPhoneticTranscription', types.Unicode(255)),
    schema.Column('morphemeBreak', types.Unicode(255)),
    schema.Column('morphemeGloss', types.Unicode(255)),
    schema.Column('comments', types.UnicodeText()),
    schema.Column('speakerComments', types.UnicodeText()),
    schema.Column('context',
                  types.UnicodeText()),  # describing context of utterance

    # Forced choice textual values
    schema.Column('grammaticality', types.Unicode(255)),

    # Temporal values: only dateElicited is consciously enterable by the user
    schema.Column('dateElicited', types.Date()),
    schema.Column('datetimeEntered', types.DateTime()),
    schema.Column('datetimeModified', types.DateTime(), default=now),

    # syntacticCategoryString: OLD-generated value
    schema.Column('syntacticCategoryString', types.Unicode(255)),

    # morphemeBreakIDs and morphemeGlossIDs: OLD-generated values
    schema.Column('morphemeBreakIDs', types.Unicode(1023)),
    schema.Column('morphemeGlossIDs', types.Unicode(1023)),

    # breakGlossCategory: OLD-generated value, e.g., 'chien|dog|N-s|PL|NUM'
    schema.Column('breakGlossCategory', types.Unicode(1023)),

    # A Form can have only one each of elicitor, enterer and verifier, but each
    # of these can have more than one form
    # Form-User = Many-to-One
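The columns above are a fragment of a larger table definition. A minimal sketch of how such columns typically sit inside a schema.Table; the table name 'form', the MetaData object and the now() helper are assumptions made for illustration, not part of the original snippet.

from datetime import datetime
from sqlalchemy import schema, types

def now():
    # assumed stand-in for the `now` callable used as the datetimeModified default above
    return datetime.utcnow()

metadata = schema.MetaData()
form_table = schema.Table(
    'form', metadata,  # hypothetical table name
    schema.Column('id', types.Integer, primary_key=True),
    schema.Column('dateElicited', types.Date()),
    schema.Column('datetimeEntered', types.DateTime()),
    schema.Column('datetimeModified', types.DateTime(), default=now),
)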
Example #2
    def reflecttable(self, connection, table):
        #TODO: map these better
        column_func = {
            14 : lambda r: sqltypes.String(r['FLEN']), # TEXT
            7  : lambda r: sqltypes.Integer(), # SHORT
            8  : lambda r: sqltypes.Integer(), # LONG
            9  : lambda r: sqltypes.Float(), # QUAD
            10 : lambda r: sqltypes.Float(), # FLOAT
            27 : lambda r: sqltypes.Float(), # DOUBLE
            35 : lambda r: sqltypes.DateTime(), # TIMESTAMP
            37 : lambda r: sqltypes.String(r['FLEN']), # VARYING
            261: lambda r: sqltypes.TEXT(), # BLOB
            40 : lambda r: sqltypes.Char(r['FLEN']), # CSTRING
            12 : lambda r: sqltypes.Date(), # DATE
            13 : lambda r: sqltypes.Time(), # TIME
            16 : lambda r: sqltypes.Numeric(precision=r['FPREC'], length=r['FSCALE'] * -1)  #INT64
            }
        tblqry = """
        SELECT DISTINCT R.RDB$FIELD_NAME AS FNAME,
                  R.RDB$NULL_FLAG AS NULL_FLAG,
                  R.RDB$FIELD_POSITION,
                  F.RDB$FIELD_TYPE AS FTYPE,
                  F.RDB$FIELD_SUB_TYPE AS STYPE,
                  F.RDB$FIELD_LENGTH AS FLEN,
                  F.RDB$FIELD_PRECISION AS FPREC,
                  F.RDB$FIELD_SCALE AS FSCALE
        FROM RDB$RELATION_FIELDS R
             JOIN RDB$FIELDS F ON R.RDB$FIELD_SOURCE=F.RDB$FIELD_NAME
        WHERE F.RDB$SYSTEM_FLAG=0 and R.RDB$RELATION_NAME=?
        ORDER BY R.RDB$FIELD_POSITION"""
        keyqry = """
        SELECT SE.RDB$FIELD_NAME SENAME
        FROM RDB$RELATION_CONSTRAINTS RC
             JOIN RDB$INDEX_SEGMENTS SE
               ON RC.RDB$INDEX_NAME=SE.RDB$INDEX_NAME
        WHERE RC.RDB$CONSTRAINT_TYPE=? AND RC.RDB$RELATION_NAME=?"""
        fkqry = """
        SELECT RC.RDB$CONSTRAINT_NAME CNAME,
               CSE.RDB$FIELD_NAME FNAME,
               IX2.RDB$RELATION_NAME RNAME,
               SE.RDB$FIELD_NAME SENAME
        FROM RDB$RELATION_CONSTRAINTS RC
             JOIN RDB$INDICES IX1
               ON IX1.RDB$INDEX_NAME=RC.RDB$INDEX_NAME
             JOIN RDB$INDICES IX2
               ON IX2.RDB$INDEX_NAME=IX1.RDB$FOREIGN_KEY
             JOIN RDB$INDEX_SEGMENTS CSE
               ON CSE.RDB$INDEX_NAME=IX1.RDB$INDEX_NAME
             JOIN RDB$INDEX_SEGMENTS SE
               ON SE.RDB$INDEX_NAME=IX2.RDB$INDEX_NAME AND SE.RDB$FIELD_POSITION=CSE.RDB$FIELD_POSITION
        WHERE RC.RDB$CONSTRAINT_TYPE=? AND RC.RDB$RELATION_NAME=?
        ORDER BY SE.RDB$INDEX_NAME, SE.RDB$FIELD_POSITION"""

        # get primary key fields
        c = connection.execute(keyqry, ["PRIMARY KEY", table.name.upper()])
        pkfields =[r['SENAME'] for r in c.fetchall()]

        # get all of the fields for this table

        def lower_if_possible(name):
            # Remove trailing spaces: FB uses a CHAR() type,
            # which is padded with spaces
            name = name.rstrip()
            # If it's composed only of upper-case chars, use
            # the lowered version, otherwise keep the original
            # (even if stripped...)
            lname = name.lower()
            if lname.upper() == name and not ' ' in name:
                return lname
            return name

        c = connection.execute(tblqry, [table.name.upper()])
        row = c.fetchone()
        if not row:
            raise exceptions.NoSuchTableError(table.name)

        while row:
            name = row['FNAME']
            args = [lower_if_possible(name)]

            kw = {}
            # get the data types and lengths
            args.append(column_func[row['FTYPE']](row))

            # is it a primary key?
            kw['primary_key'] = name in pkfields

            table.append_column(schema.Column(*args, **kw))
            row = c.fetchone()

        # get the foreign keys
        c = connection.execute(fkqry, ["FOREIGN KEY", table.name.upper()])
        fks = {}
        while True:
            row = c.fetchone()
            if not row: break

            cname = lower_if_possible(row['CNAME'])
            try:
                fk = fks[cname]
            except KeyError:
                fks[cname] = fk = ([], [])
            rname = lower_if_possible(row['RNAME'])
            schema.Table(rname, table.metadata, autoload=True, autoload_with=connection)
            fname = lower_if_possible(row['FNAME'])
            refspec = rname + '.' + lower_if_possible(row['SENAME'])
            fk[0].append(fname)
            fk[1].append(refspec)

        for name,value in fks.iteritems():
            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
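reflecttable() is not normally called directly; it runs when a Table is autoloaded, exactly as the snippet itself does for referenced tables. A hedged usage sketch for the SQLAlchemy versions this dialect targets; the connection URL and table name are placeholders:

from sqlalchemy import create_engine, schema

engine = create_engine('firebird://user:password@localhost/example.fdb')  # placeholder URL
metadata = schema.MetaData()
# autoload=True triggers the dialect's reflecttable(), filling in columns, primary keys and foreign keys
employees = schema.Table('employees', metadata, autoload=True, autoload_with=engine)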
Example #3
def test_encode_datetime():
    now = datetime.datetime.now()
    e = encode(types.DateTime())
    d = decode(types.DateTime())
    # microseconds are lost, but that's ok
    assert now.timetuple()[:6] == d(e(now)).timetuple()[:6]
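A small illustration of why the test compares timetuple()[:6]: the first six fields stop at whole seconds, so equality still holds when the codec drops microseconds. This only sketches the comparison trick, not the encode/decode helpers themselves.

import datetime

original = datetime.datetime(2021, 5, 4, 12, 30, 45, 123456)
truncated = original.replace(microsecond=0)  # what a lossy round-trip might hand back
assert original.timetuple()[:6] == truncated.timetuple()[:6]  # year..second still match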
Example #4
class SurveySession(BaseObject):
    """Information about a user's session."""

    __tablename__ = "session"

    id = schema.Column(types.Integer(), primary_key=True, autoincrement=True)
    brand = schema.Column(types.String(64))
    account_id = schema.Column(
        types.Integer(),
        schema.ForeignKey(Account.id, onupdate="CASCADE", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    last_modifier_id = schema.Column(
        types.Integer(),
        schema.ForeignKey(Account.id, onupdate="CASCADE", ondelete="CASCADE"),
        nullable=True,
        index=False,
    )
    last_publisher_id = schema.Column(
        types.Integer(),
        schema.ForeignKey(Account.id, onupdate="CASCADE", ondelete="CASCADE"),
        nullable=True,
        index=False,
    )
    group_id = schema.Column(
        types.Unicode(32),
        schema.ForeignKey("group.group_id"),
    )
    title = schema.Column(types.Unicode(512))
    created = schema.Column(
        types.DateTime,
        nullable=False,
        default=functions.now(),
    )
    modified = schema.Column(
        types.DateTime,
        nullable=False,
        default=functions.now(),
    )
    refreshed = schema.Column(
        types.DateTime,
        nullable=False,
        default=functions.now(),
    )

    published = schema.Column(
        types.DateTime,
        nullable=True,
        default=None,
    )

    archived = schema.Column(
        types.DateTime(timezone=True),
        nullable=True,
        default=None,
    )

    completion_percentage = schema.Column(
        types.Integer,
        nullable=True,
        default=0,
    )

    zodb_path = schema.Column(types.String(512), nullable=False)

    report_comment = schema.Column(types.UnicodeText())

    account = orm.relation(
        Account,
        backref=orm.backref(
            "sessions",
            order_by=modified,
            cascade="all, delete, delete-orphan",
        ),
        foreign_keys=[account_id],
    )
    last_modifier = orm.relation(
        Account,
        foreign_keys=[last_modifier_id],
    )
    last_publisher = orm.relation(
        Account,
        foreign_keys=[last_publisher_id],
    )

    group = orm.relation(
        Group,
        backref=orm.backref("sessions",
                            order_by=modified,
                            cascade="all, delete, delete-orphan"),
    )

    migrated = schema.Column(
        types.DateTime,
        nullable=False,
        default=functions.now(),
    )

    # Allow this class to be subclassed in other projects
    __mapper_args__ = {
        "polymorphic_identity": "euphorie",
        "polymorphic_on": brand,
        "with_polymorphic": "*",
    }

    @property
    def is_archived(self):
        archived = self.archived
        if not archived:
            return False
        return archived <= localized_now()

    @property
    def review_state(self):
        """Check if it the published column.
        If it has return 'published' otherwise return 'private'
        """
        return "published" if self.published else "private"

    def hasTree(self):
        return bool(
            Session.query(SurveyTreeItem).filter(
                SurveyTreeItem.session == self).count())

    def reset(self):
        Session.query(SurveyTreeItem).filter(
            SurveyTreeItem.session == self).delete()
        self.created = self.modified = datetime.datetime.now()

    def touch(self):
        self.last_modifier = get_current_account()
        self.modified = datetime.datetime.now()

    def refresh_survey(self, survey=None):
        """Mark the session with the current date to indicate that is has been
        refreshed with the latest version of the Survey (from Zope).
        If survey is passed, update all titles in the tree, based on the CMS
        version of the survey, i.e. update all titles of modules and risks.
        Those are used in the navigation. If a title change is the only change
        in the CMS, the survey session is not re-created. Therefore this
        method ensures that the titles are updated where necessary.
        """
        if survey:
            query = Session.query(SurveyTreeItem).filter(
                SurveyTreeItem.session_id == self.id)
            tree = query.all()
            for item in tree:
                if item.zodb_path.find("custom-risks") >= 0:
                    continue
                zodb_item = survey.restrictedTraverse(
                    item.zodb_path.split("/"), None)
                if zodb_item and zodb_item.title != item.title:
                    item.title = zodb_item.title
        self.refreshed = datetime.datetime.now()

    def addChild(self, item):
        sqlsession = Session()
        query = (sqlsession.query(SurveyTreeItem.path).filter(
            SurveyTreeItem.session_id == self.id).filter(
                SurveyTreeItem.depth == 1).order_by(
                    SurveyTreeItem.path.desc()))
        last = query.first()
        if not last:
            index = 1
        else:
            index = int(last[0][-3:]) + 1

        item.session = self
        item.depth = 1
        item.path = "%03d" % index
        item.parent_id = None
        sqlsession.add(item)
        self.touch()
        return item

    def children(self, filter=None):
        query = (Session.query(SurveyTreeItem).filter(
            SurveyTreeItem.session_id == self.id).filter(
                SurveyTreeItem.depth == 1))
        if filter is not None:
            query = query.filter(filter)
        return query.order_by(SurveyTreeItem.path)

    def copySessionData(self, other):
        """Copy all user data from another session to this one."""
        session = Session()

        # Copy all tree data to the new session (skip_children and postponed)
        old_tree = orm.aliased(SurveyTreeItem, name="old_tree")
        in_old_tree = sql.and_(
            old_tree.session_id == other.id,
            SurveyTreeItem.zodb_path == old_tree.zodb_path,
            SurveyTreeItem.profile_index == old_tree.profile_index,
        )
        skip_children = sql.select([old_tree.skip_children],
                                   in_old_tree).limit(1)
        postponed = sql.select([old_tree.postponed], in_old_tree).limit(1)
        new_items = (session.query(SurveyTreeItem).filter(
            SurveyTreeItem.session == self).filter(
                sql.exists(sql.select([old_tree.id]).where(in_old_tree))))
        new_items.update(
            {
                "skip_children": skip_children,
                "postponed": postponed
            },
            synchronize_session=False,
        )

        # Mandatory modules must have skip_children=False. It's possible that
        # the module was optional with skip_children=True and now after the
        # update it's mandatory. So we must check and correct.
        # In case a risk was marked as "always present", be sure its
        # identification gets set to 'no'
        preset_to_no = []
        survey = getSite()["client"].restrictedTraverse(self.zodb_path)
        for item in new_items.all():
            if item.type == "risk":
                if item.identification == "no":
                    preset_to_no.append(item.risk_id)

            elif item.type == "module":
                module = survey.restrictedTraverse(item.zodb_path.split("/"))
                if not module.optional:
                    item.skip_children = False

        # Copy all risk data to the new session
        # This triggers a "Only update via a single table query is currently
        # supported" error with SQLAlchemy 0.6.6
        # old_risk = orm.aliased(Risk.__table__, name='old_risk')
        # is_old_risk = sql.and_(in_old_tree, old_tree.id == old_risk.id)
        # identification = sql.select([old_risk.identification], is_old_risk)
        # new_risks = session.query(Risk)\
        #         .filter(Risk.session == self)\
        #         .filter(sql.exists(
        #             sql.select([SurveyTreeItem.id]).where(sql.and_(
        #                     SurveyTreeItem.id == Risk.id,
        #                     sql.exists([old_tree.id]).where(sql.and_(
        #                         in_old_tree, old_tree.type == 'risk'))))))
        # new_risks.update({'identification': identification},
        #         synchronize_session=False)

        skip_preset_to_no_clause = ""
        if len(preset_to_no):
            skip_preset_to_no_clause = "old_risk.risk_id not in %s AND" % (str(
                [str(x)
                 for x in preset_to_no]).replace("[", "(").replace("]", ")"))
        statement = """\
        UPDATE RISK
        SET identification = old_risk.identification,
            frequency = old_risk.frequency,
            effect = old_risk.effect,
            probability = old_risk.probability,
            priority = old_risk.priority,
            existing_measures = old_risk.existing_measures,
            comment = old_risk.comment
        FROM risk AS old_risk JOIN tree AS old_tree ON old_tree.id=old_risk.id, tree
        WHERE tree.id=risk.id AND
              %(skip_preset_to_no_clause)s
              tree.session_id=%(new_sessionid)s AND
              old_tree.session_id=%(old_sessionid)s AND
              old_tree.zodb_path=tree.zodb_path AND
              old_tree.profile_index=tree.profile_index;
        """ % dict(  # noqa: E501
            old_sessionid=other.id,
            new_sessionid=self.id,
            skip_preset_to_no_clause=skip_preset_to_no_clause,
        )
        session.execute(statement)

        statement = """\
        INSERT INTO action_plan (risk_id, action_plan, prevention_plan, action,
                                        requirements, responsible, budget, plan_type,
                                        planning_start, planning_end,
                                        solution_id, used_in_training)
               SELECT new_tree.id,
                      action_plan.action_plan,
                      action_plan.prevention_plan,
                      action_plan.action,
                      action_plan.requirements,
                      action_plan.responsible,
                      action_plan.budget,
                      action_plan.plan_type,
                      action_plan.planning_start,
                      action_plan.planning_end,
                      action_plan.solution_id,
                      action_plan.used_in_training
               FROM action_plan JOIN risk ON action_plan.risk_id=risk.id
                                JOIN tree ON tree.id=risk.id,
                    tree AS new_tree
               WHERE tree.session_id=%(old_sessionid)d AND
                     new_tree.session_id=%(new_sessionid)d AND
                     tree.zodb_path=new_tree.zodb_path AND
                     tree.profile_index=new_tree.profile_index;
            """ % {
            "old_sessionid": other.id,
            "new_sessionid": self.id,
        }
        session.execute(statement)

        # Copy over previous session metadata. Specifically, we don't want to
        # create a new modification timestamp, just because the underlying
        # survey was updated.
        statement = """\
        UPDATE session
        SET
            modified = old_session.modified,
            created = old_session.created,
            last_modifier_id = old_session.last_modifier_id
        FROM session as old_session
        WHERE
            old_session.id=%(old_sessionid)d AND
            session.id=%(new_sessionid)d
        """ % {
            "old_sessionid": other.id,
            "new_sessionid": self.id,
        }
        session.execute(statement)

        session.query(Company).filter(Company.session == other).update(
            {"session_id": self.id}, synchronize_session=False)

    @classmethod
    def get_account_filter(cls, account=None):
        """Filter only the sessions for the given account

        :param account: True means the current account.
            A falsy value means do not filter.
            Otherwise try to interpret the user input:
            a string or an int means the account_id should be that value,
            an Account object will be used to extract the account id,
            and from an iterable we will try to extract the account ids
        """
        if account is True:
            account = get_current_account()

        if isinstance(account, Account):
            account = account.id

        if not account:
            return False

        if isinstance(account, (int, six.string_types)):
            return cls.account_id == account

        try:
            account_ids = {getattr(item, "id", item) for item in account}
        except TypeError:
            log.error("Cannot understand the account parameter: %r", account)
            raise

        account_ids = {
            item
            for item in account_ids
            if item and isinstance(item, (int, six.string_types))
        }
        if not account_ids:
            return False

        if len(account_ids) == 1:
            for account_id in account_ids:
                return cls.get_account_filter(account_id)

        return cls.account_id.in_(account_ids)

    @classmethod
    def get_group_filter(cls, group=None):
        """Filter only the sessions for the given group

        :param group: True means the current account's group.
            A falsy value means do not filter.
            Otherwise try to interpret the user input:
            a string or an int means the group_id should be that value,
            a Group object will be used to extract the group id,
            and from an iterable we will try to extract the group ids
        """
        if group is True:
            group = getattr(get_current_account(), "group_id", None)

        if isinstance(group, Group):
            group = group.group_id

        if not group:
            return False

        if isinstance(group, (int, six.string_types)):
            return cls.group_id == group

        try:
            group_ids = {getattr(item, "group_id", item) for item in group}
        except TypeError:
            log.error("Cannot understand the group parameter: %r", group)
            raise

        group_ids = {
            item
            for item in group_ids
            if item and isinstance(item, (int, six.string_types))
        }
        if not group_ids:
            return False

        if len(group_ids) == 1:
            for group_id in group_ids:
                return cls.get_group_filter(group_id)

        return cls.group_id.in_(group_ids)

    @classmethod
    def get_archived_filter(cls):
        """Filter sessions that are archived"""
        return sql.or_(
            cls.archived >= localized_now(),
            cls.archived == None  # noqa: E711
        )

    @classmethod
    def _get_context_tools(cls, context):
        """Return the set of tools we can find under this context"""
        if not context:
            return set()

        # Check the path relative to the client folder
        if context.portal_type == "Plone Site":
            context = context.client

        if context.portal_type == "euphorie.survey":
            return {context}

        portal_type_filter = {
            "portal_type": [
                "euphorie.clientcountry",
                "euphorie.clientsector",
                "euphorie.survey",
            ]
        }

        surveys = set()

        def _add_survey(container):
            for obj in container.listFolderContents(portal_type_filter):
                if obj.portal_type == "euphorie.survey":
                    surveys.add(obj)
                else:
                    _add_survey(obj)

        _add_survey(context)
        return surveys

    @classmethod
    def get_context_filter(cls, context):
        """Filter sessions under this context using the zodb_path column"""
        surveys = cls._get_context_tools(context)
        if not surveys:
            return False

        return cls.zodb_path.in_({
            safe_unicode("/".join(survey.getPhysicalPath()[-3:]))
            for survey in surveys
        })

    @property
    def tool(self):
        client = api.portal.get().client
        return client.restrictedTraverse(str(self.zodb_path), None)

    @property
    def traversed_session(self):
        return self.tool.restrictedTraverse("++session++%s" % self.id)

    @property
    def country(self):
        return str(self.zodb_path).split("/")[0]
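A hedged sketch of how the filter classmethods above might be combined in a query; Session, get_current_account and the ordering come from the class, everything else is assumed:

query = Session.query(SurveySession)
filters = (
    SurveySession.get_account_filter(account=True),  # sessions of the current account
    SurveySession.get_group_filter(group=True),      # sessions of the current account's group
    SurveySession.get_archived_filter(),             # keep only non-archived sessions
)
for clause in filters:
    if clause is not False:  # the account/group helpers return False to mean "do not filter"
        query = query.filter(clause)
sessions = query.order_by(SurveySession.modified.desc()).all()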
Example #5
 def test_setup_epoch(self):
     column_info = {"type": types.DateTime()}
     bulkdata.setup_epoch(mock.Mock(), mock.Mock(), column_info)
     self.assertIsInstance(column_info["type"], bulkdata.EpochType)
Example #6
class Project(Base):
    __tablename__ = "projects"
    #__mapper_args__ = dict(order_by="date desc")

    id = Column(types.Integer, primary_key=True)
    #leader_id = Column(types.Integer, ForeignKey('user.id'))

    # If initial_round_size > 0, this will point to
    # the entry in the citations_tasks table that
    # maps to the initial assignment associated with
    # this project.
    initial_assignment_id = Column(types.Integer)

    name = Column(types.Unicode(255))
    description = Column(types.Unicode(10000))

    # This is used to identify the project when
    # the leader invites others to the project
    code = Column(types.Unicode(10), index=True, unique=True)

    # `single', `double', or `advanced'
    screening_mode = Column(types.Unicode(50))

    # True := tags are private
    # False := tags are public
    tag_privacy = Column(types.Boolean)

    # the number of labels to be procured for each abstract
    num_labels_thus_far = Column(types.Integer)

    # basically, the AI criteria (or random)
    sort_by = Column(types.Unicode(255))

    # If >0, this represents a fixed set of
    # citations that will (potentially) be
    # screened by everyone on the project
    initial_round_size = Column(types.Integer)

    # Minimum and maximum number of citations a user should screen for this project
    min_citations = Column(types.Integer)
    max_citations = Column(types.Integer)

    # Bookkeeping
    date_created = Column(types.DateTime())
    date_modified = Column(types.DateTime())

    priorities = relationship('Priority',
                              order_by='Priority.id',
                              backref='project')
    citations = relationship('Citation',
                             order_by='Citation.id',
                             backref='project')
    assignments = relationship('Assignment',
                               order_by='Assignment.id',
                               backref='project')
    labels = relationship('Label', order_by='Label.id', backref='project')
    members = relationship('User',
                           secondary=users_projects_table,
                           backref='member_of_projects')
    leaders = relationship('User',
                           secondary=projects_leaders_table,
                           backref='leader_of_projects')
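A minimal usage sketch for the model above; DBSession is an assumed, already-configured SQLAlchemy session and the field values are purely illustrative:

import datetime

project = Project(
    name=u'Example screening project',
    description=u'Created for illustration only',
    code=u'EX0001',
    screening_mode=u'single',
    tag_privacy=False,
    initial_round_size=0,
    date_created=datetime.datetime.now(),
    date_modified=datetime.datetime.now(),
)
DBSession.add(project)  # DBSession: assumed session object, not defined in the snippet
DBSession.commit()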
Example #7
class BQDialect(default.DefaultDialect):
    colspecs = {
        types.Unicode: BQString,
        types.Integer: BQInteger,
        types.SmallInteger: BQInteger,
        types.Numeric: BQFloat,
        types.Float: BQFloat,
        types.DateTime: BQTimestamp,
        types.Date: BQTimestamp,
        types.String: BQString,
        types.LargeBinary: BQBytes,
        types.Boolean: BQBoolean,
        types.Text: BQString,
        types.CHAR: BQString,
        types.TIMESTAMP: BQTimestamp,
        types.VARCHAR: BQString
    }

    __TYPE_MAPPINGS = {
        'TIMESTAMP': types.DateTime(),
        'STRING': types.String(),
        'FLOAT': types.Float(),
        'INTEGER': types.Integer(),
        'BOOLEAN': types.Boolean()
    }

    name = 'bigquery'
    driver = 'bq1'
    poolclass = pool.SingletonThreadPool
    statement_compiler = BQSQLCompiler
    ddl_compiler = BQDDLCompiler
    preparer = BQIdentifierPreparer
    execution_ctx_cls = BQExecutionContext

    supports_alter = False
    supports_unicode_statements = True
    supports_sane_multi_rowcount = False
    supports_sane_rowcount = False
    supports_sequences = False
    supports_native_enum = False

    positional = False
    paramstyle = 'named'

    default_sequence_base = 0
    default_schema_name = None

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def __init__(self, **kw):
        #
        # Create a dialect object
        #
        super(BQDialect, self).__init__(**kw)

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def create_connect_args(self, url):
        #
        # This function recovers connection parameters from the connection string
        #
        return [], {}

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def initialize(self, connection):
        """disable all dialect initialization"""

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    @classmethod
    def dbapi(cls):
        return dbapi

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def do_execute(self, cursor, statement, parameters, context=None):
        cursor.execute(statement, parameters)

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def do_executemany(self, cursor, statement, parameters, context=None):
        cursor.executemany(statement, parameters)

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def get_schema_names(self, engine, **kw):
        return engine.connect().connection.get_schema_names()

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def get_view_names(self, connection, schema=None, **kw):
        raise NotImplementedError()

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def get_view_definition(self, connection, viewname, schema=None, **kw):
        raise NotImplementedError()

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def has_table(self, connection, table_name, schema=None):
        return table_name in connection.connection.get_table_names()

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    @reflection.cache
    def get_table_names(self, engine, schema=None, **kw):
        return engine.connect().connection.get_table_names()

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def get_columns(self, engine, table_name, schema=None, **kw):
        cols = engine.connect().connection.get_columns(table_name)

        get_coldef = lambda x, y: {
            "name": x,
            "type": BQDialect.__TYPE_MAPPINGS.get(y, types.Binary()),
            "nullable": True,
            "default": None
        }

        return [get_coldef(*col) for col in cols]

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def get_primary_keys(self, engine, table_name, schema=None, **kw):
        return []

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def get_foreign_keys(self, engine, table_name, schema=None, **kw):
        return []

    #  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -
    def get_indexes(self, connection, table_name, schema=None, **kw):
        return []
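A small sketch of the fallback behaviour inside get_columns(): known BigQuery type names map to SQLAlchemy types, anything unrecognised degrades to a binary column. The mapping is re-declared locally because __TYPE_MAPPINGS is name-mangled on the class, and the sample rows are made up:

from sqlalchemy import types

type_mappings = {
    'TIMESTAMP': types.DateTime(),
    'STRING': types.String(),
    'FLOAT': types.Float(),
    'INTEGER': types.Integer(),
    'BOOLEAN': types.Boolean(),
}
for name, bq_type in [('created_at', 'TIMESTAMP'), ('payload', 'RECORD')]:
    coldef = {
        'name': name,
        'type': type_mappings.get(bq_type, types.Binary()),  # unknown 'RECORD' falls back to Binary (LargeBinary in newer SQLAlchemy)
        'nullable': True,
        'default': None,
    }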
Example #8
    def reflecttable(self, connection, table, include_columns):
        #TODO: map these better
        column_func = {
            14 : lambda r: sqltypes.String(r['FLEN']), # TEXT
            7  : lambda r: sqltypes.Integer(), # SHORT
            8  : lambda r: r['FPREC']==0 and sqltypes.Integer() or sqltypes.Numeric(precision=r['FPREC'], length=r['FSCALE'] * -1),  #INT or NUMERIC
            9  : lambda r: sqltypes.Float(), # QUAD
            10 : lambda r: sqltypes.Float(), # FLOAT
            27 : lambda r: sqltypes.Float(), # DOUBLE
            35 : lambda r: sqltypes.DateTime(), # TIMESTAMP
            37 : lambda r: sqltypes.String(r['FLEN']), # VARYING
            261: lambda r: sqltypes.TEXT(), # BLOB
            40 : lambda r: sqltypes.Char(r['FLEN']), # CSTRING
            12 : lambda r: sqltypes.Date(), # DATE
            13 : lambda r: sqltypes.Time(), # TIME
            16 : lambda r: sqltypes.Numeric(precision=r['FPREC'], length=r['FSCALE'] * -1)  #INT64
            }
        tblqry = """
        SELECT DISTINCT R.RDB$FIELD_NAME AS FNAME,
                  R.RDB$NULL_FLAG AS NULL_FLAG,
                  R.RDB$FIELD_POSITION,
                  F.RDB$FIELD_TYPE AS FTYPE,
                  F.RDB$FIELD_SUB_TYPE AS STYPE,
                  F.RDB$FIELD_LENGTH AS FLEN,
                  F.RDB$FIELD_PRECISION AS FPREC,
                  F.RDB$FIELD_SCALE AS FSCALE
        FROM RDB$RELATION_FIELDS R
             JOIN RDB$FIELDS F ON R.RDB$FIELD_SOURCE=F.RDB$FIELD_NAME
        WHERE F.RDB$SYSTEM_FLAG=0 and R.RDB$RELATION_NAME=?
        ORDER BY R.RDB$FIELD_POSITION"""
        keyqry = """
        SELECT SE.RDB$FIELD_NAME SENAME
        FROM RDB$RELATION_CONSTRAINTS RC
             JOIN RDB$INDEX_SEGMENTS SE
               ON RC.RDB$INDEX_NAME=SE.RDB$INDEX_NAME
        WHERE RC.RDB$CONSTRAINT_TYPE=? AND RC.RDB$RELATION_NAME=?"""
        fkqry = """
        SELECT RC.RDB$CONSTRAINT_NAME CNAME,
               CSE.RDB$FIELD_NAME FNAME,
               IX2.RDB$RELATION_NAME RNAME,
               SE.RDB$FIELD_NAME SENAME
        FROM RDB$RELATION_CONSTRAINTS RC
             JOIN RDB$INDICES IX1
               ON IX1.RDB$INDEX_NAME=RC.RDB$INDEX_NAME
             JOIN RDB$INDICES IX2
               ON IX2.RDB$INDEX_NAME=IX1.RDB$FOREIGN_KEY
             JOIN RDB$INDEX_SEGMENTS CSE
               ON CSE.RDB$INDEX_NAME=IX1.RDB$INDEX_NAME
             JOIN RDB$INDEX_SEGMENTS SE
               ON SE.RDB$INDEX_NAME=IX2.RDB$INDEX_NAME AND SE.RDB$FIELD_POSITION=CSE.RDB$FIELD_POSITION
        WHERE RC.RDB$CONSTRAINT_TYPE=? AND RC.RDB$RELATION_NAME=?
        ORDER BY SE.RDB$INDEX_NAME, SE.RDB$FIELD_POSITION"""

        # get primary key fields
        c = connection.execute(keyqry, ["PRIMARY KEY", self._denormalize_name(table.name)])
        pkfields =[self._normalize_name(r['SENAME']) for r in c.fetchall()]

        # get all of the fields for this table
        c = connection.execute(tblqry, [self._denormalize_name(table.name)])

        found_table = False
        while True:
            row = c.fetchone()
            if row is None:
                break
            found_table = True

            name = self._normalize_name(row['FNAME'])
            if include_columns and name not in include_columns:
                continue
            args = [name]

            kw = {}
            # get the data types and lengths
            coltype = column_func.get(row['FTYPE'], None)
            if coltype is None:
                warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (str(row['FTYPE']), name)))
                coltype = sqltypes.NULLTYPE
            else:
                coltype = coltype(row)
            args.append(coltype)

            # is it a primary key?
            kw['primary_key'] = name in pkfields

            # is it nullable ?
            kw['nullable'] = not bool(row['NULL_FLAG'])

            table.append_column(schema.Column(*args, **kw))

        if not found_table:
            raise exceptions.NoSuchTableError(table.name)

        # get the foreign keys
        c = connection.execute(fkqry, ["FOREIGN KEY", self._denormalize_name(table.name)])
        fks = {}
        while True:
            row = c.fetchone()
            if not row: break

            cname = self._normalize_name(row['CNAME'])
            try:
                fk = fks[cname]
            except KeyError:
                fks[cname] = fk = ([], [])
            rname = self._normalize_name(row['RNAME'])
            schema.Table(rname, table.metadata, autoload=True, autoload_with=connection)
            fname = self._normalize_name(row['FNAME'])
            refspec = rname + '.' + self._normalize_name(row['SENAME'])
            fk[0].append(fname)
            fk[1].append(refspec)

        for name,value in fks.iteritems():
            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
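A small illustration of the FTYPE 8 branch above, which yields Integer for a plain INTEGER column and Numeric once a precision is defined. The reflection rows are invented, and Numeric is called positionally (precision, scale) rather than with the legacy keyword signature used by the dialect:

from sqlalchemy import types as sqltypes

plain_int = {'FPREC': 0, 'FSCALE': 0}
scaled = {'FPREC': 9, 'FSCALE': -2}
pick = lambda r: r['FPREC'] == 0 and sqltypes.Integer() or sqltypes.Numeric(r['FPREC'], r['FSCALE'] * -1)
int_type = pick(plain_int)  # sqltypes.Integer()
num_type = pick(scaled)     # sqltypes.Numeric(9, 2)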
Example #9
class BaseEngineSpec:  # pylint: disable=too-many-public-methods
    """Abstract class for database engine specific configurations"""

    engine = "base"  # str as defined in sqlalchemy.engine.engine
    engine_aliases: Optional[Tuple[str]] = None
    engine_name: Optional[
        str
    ] = None  # used for user messages, overridden in child classes
    _date_trunc_functions: Dict[str, str] = {}
    _time_grain_expressions: Dict[Optional[str], str] = {}
    column_type_mappings: Tuple[
        Tuple[
            Pattern[str],
            Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
            GenericDataType,
        ],
        ...,
    ] = (
        (
            re.compile(r"^smallint", re.IGNORECASE),
            types.SmallInteger(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^int.*", re.IGNORECASE),
            types.Integer(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^bigint", re.IGNORECASE),
            types.BigInteger(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^decimal", re.IGNORECASE),
            types.Numeric(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^numeric", re.IGNORECASE),
            types.Numeric(),
            GenericDataType.NUMERIC,
        ),
        (re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,),
        (
            re.compile(r"^smallserial", re.IGNORECASE),
            types.SmallInteger(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^serial", re.IGNORECASE),
            types.Integer(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^bigserial", re.IGNORECASE),
            types.BigInteger(),
            GenericDataType.NUMERIC,
        ),
        (
            re.compile(r"^string", re.IGNORECASE),
            types.String(),
            utils.GenericDataType.STRING,
        ),
        (
            re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
            UnicodeText(),
            utils.GenericDataType.STRING,
        ),
        (
            re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
            String(),
            utils.GenericDataType.STRING,
        ),
        (
            re.compile(r"^datetime", re.IGNORECASE),
            types.DateTime(),
            GenericDataType.TEMPORAL,
        ),
        (re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,),
        (
            re.compile(r"^timestamp", re.IGNORECASE),
            types.TIMESTAMP(),
            GenericDataType.TEMPORAL,
        ),
        (
            re.compile(r"^interval", re.IGNORECASE),
            types.Interval(),
            GenericDataType.TEMPORAL,
        ),
        (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
        (
            re.compile(r"^bool.*", re.IGNORECASE),
            types.Boolean(),
            GenericDataType.BOOLEAN,
        ),
    )
    time_groupby_inline = False
    limit_method = LimitMethod.FORCE_LIMIT
    time_secondary_columns = False
    allows_joins = True
    allows_subqueries = True
    allows_alias_in_select = True
    allows_alias_in_orderby = True
    allows_sql_comments = True
    force_column_alias_quotes = False
    arraysize = 0
    max_column_name_length = 0
    try_remove_schema_from_table_name = True  # pylint: disable=invalid-name
    run_multiple_statements_as_one = False

    @classmethod
    def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
        """
        Each engine can implement and convert its own specific exceptions into
        Superset DBAPI exceptions

        Note: On python 3.9 this method can be changed to a classmethod property
        without the need of implementing a metaclass type

        :return: A map of driver specific exception to superset custom exceptions
        """
        return {}

    @classmethod
    def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
        """
        Get a superset custom DBAPI exception from the driver specific exception.

        Override if the engine needs to perform extra changes to the exception, for
        example change the exception message or implement more complex custom logic

        :param exception: The driver specific exception
        :return: Superset custom DBAPI exception
        """
        new_exception = cls.get_dbapi_exception_mapping().get(type(exception))
        if not new_exception:
            return exception
        return new_exception(str(exception))

    @classmethod
    def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
        return False

    @classmethod
    def get_engine(
        cls,
        database: "Database",
        schema: Optional[str] = None,
        source: Optional[str] = None,
    ) -> Engine:
        user_name = utils.get_username()
        return database.get_sqla_engine(
            schema=schema, nullpool=True, user_name=user_name, source=source
        )

    @classmethod
    def get_timestamp_expr(
        cls,
        col: ColumnClause,
        pdf: Optional[str],
        time_grain: Optional[str],
        type_: Optional[str] = None,
    ) -> TimestampExpression:
        """
        Construct a TimestampExpression to be used in a SQLAlchemy query.

        :param col: Target column for the TimestampExpression
        :param pdf: date format (seconds or milliseconds)
        :param time_grain: time grain, e.g. P1Y for 1 year
        :param type_: the source column type
        :return: TimestampExpression object
        """
        if time_grain:
            time_expr = cls.get_time_grain_expressions().get(time_grain)
            if not time_expr:
                raise NotImplementedError(
                    f"No grain spec for {time_grain} for database {cls.engine}"
                )
            if type_ and "{func}" in time_expr:
                date_trunc_function = cls._date_trunc_functions.get(type_)
                if date_trunc_function:
                    time_expr = time_expr.replace("{func}", date_trunc_function)
            if type_ and "{type}" in time_expr:
                date_trunc_function = cls._date_trunc_functions.get(type_)
                if date_trunc_function:
                    time_expr = time_expr.replace("{type}", type_)
        else:
            time_expr = "{col}"

        # if epoch, translate to DATE using db specific conf
        if pdf == "epoch_s":
            time_expr = time_expr.replace("{col}", cls.epoch_to_dttm())
        elif pdf == "epoch_ms":
            time_expr = time_expr.replace("{col}", cls.epoch_ms_to_dttm())

        return TimestampExpression(time_expr, col, type_=DateTime)
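        # Illustrative trace (added, not part of the original method): with a hypothetical
        # grain expression "DATE_TRUNC('year', {col})" and pdf == "epoch_s" on an engine whose
        # epoch_to_dttm() returns "FROM_UNIXTIME({col})", the template composes to
        # "DATE_TRUNC('year', FROM_UNIXTIME({col}))" before being wrapped in the
        # TimestampExpression above.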

    @classmethod
    def get_time_grains(cls) -> Tuple[TimeGrain, ...]:
        """
        Generate a tuple of supported time grains.

        :return: All time grains supported by the engine
        """

        ret_list = []
        time_grains = builtin_time_grains.copy()
        time_grains.update(config["TIME_GRAIN_ADDONS"])
        for duration, func in cls.get_time_grain_expressions().items():
            if duration in time_grains:
                name = time_grains[duration]
                ret_list.append(TimeGrain(name, _(name), func, duration))
        return tuple(ret_list)

    @classmethod
    def _sort_time_grains(
        cls, val: Tuple[Optional[str], str], index: int
    ) -> Union[float, int, str]:
        """
        Return an ordered time-based value of a portion of a time grain
        for sorting.
        Values are expected to be either None or to start with P or PT,
        have a numerical value in the middle, and end with
        a value for the time interval.
        They can also start or end with an epoch start time denoting a range,
        i.e. a week beginning or ending with a day.
        """
        pos = {
            "FIRST": 0,
            "SECOND": 1,
            "THIRD": 2,
            "LAST": 3,
        }

        if val[0] is None:
            return pos["FIRST"]

        prog = re.compile(r"(.*\/)?(P|PT)([0-9\.]+)(S|M|H|D|W|M|Y)(\/.*)?")
        result = prog.match(val[0])

        # for any time grains that don't match the format, put them at the end
        if result is None:
            return pos["LAST"]

        second_minute_hour = ["S", "M", "H"]
        day_week_month_year = ["D", "W", "M", "Y"]
        is_less_than_day = result.group(2) == "PT"
        interval = result.group(4)
        epoch_time_start_string = result.group(1) or result.group(5)
        has_starting_or_ending = bool(len(epoch_time_start_string or ""))

        def sort_day_week() -> int:
            if has_starting_or_ending:
                return pos["LAST"]
            if is_less_than_day:
                return pos["SECOND"]
            return pos["THIRD"]

        def sort_interval() -> float:
            if is_less_than_day:
                return second_minute_hour.index(interval)
            return day_week_month_year.index(interval)

        # 0: all "PT" values should come before "P" values (i.e, PT10M)
        # 1: order values within the above arrays ("D" before "W")
        # 2: sort by numeric value (PT10M before PT15M)
        # 3: sort by any week starting/ending values
        plist = {
            0: sort_day_week(),
            1: pos["SECOND"] if is_less_than_day else pos["THIRD"],
            2: sort_interval(),
            3: float(result.group(3)),
        }

        return plist.get(index, 0)
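        # Worked example (added for illustration, not part of the original method): for
        # "PT10M" the four keys evaluate to (1, 1, 1, 10.0) and for "P1D" to (2, 2, 0, 1.0),
        # so None sorts first, sub-day "PT" grains precede "P" grains, and grains carrying
        # a "/" range boundary sort after the plain ones.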

    @classmethod
    def get_time_grain_expressions(cls) -> Dict[Optional[str], str]:
        """
        Return a dict of all supported time grains including any potential added grains
        but excluding any potentially disabled grains in the config file.

        :return: All time grain expressions supported by the engine
        """
        # TODO: use @memoize decorator or similar to avoid recomputation on every call
        time_grain_expressions = cls._time_grain_expressions.copy()
        grain_addon_expressions = config["TIME_GRAIN_ADDON_EXPRESSIONS"]
        time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
        denylist: List[str] = config["TIME_GRAIN_DENYLIST"]
        for key in denylist:
            time_grain_expressions.pop(key)

        return dict(
            sorted(
                time_grain_expressions.items(),
                key=lambda x: (
                    cls._sort_time_grains(x, 0),
                    cls._sort_time_grains(x, 1),
                    cls._sort_time_grains(x, 2),
                    cls._sort_time_grains(x, 3),
                ),
            )
        )

    @classmethod
    def make_select_compatible(
        cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement]
    ) -> List[ColumnElement]:
        """
        Some databases will just return the group-by field into the select, but don't
        allow the group-by field to be put into the select list.

        :param groupby_exprs: mapping between column name and column object
        :param select_exprs: all columns in the select clause
        :return: columns to be included in the final select clause
        """
        return select_exprs

    @classmethod
    def fetch_data(
        cls, cursor: Any, limit: Optional[int] = None
    ) -> List[Tuple[Any, ...]]:
        """

        :param cursor: Cursor instance
        :param limit: Maximum number of rows to be returned by the cursor
        :return: Result of query
        """
        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        try:
            if cls.limit_method == LimitMethod.FETCH_MANY and limit:
                return cursor.fetchmany(limit)
            return cursor.fetchall()
        except Exception as ex:
            raise cls.get_dbapi_mapped_exception(ex)

    @classmethod
    def expand_data(
        cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
    ) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
        """
        Some engines support expanding nested fields. See implementation in Presto
        spec for details.

        :param columns: columns selected in the query
        :param data: original data set
        :return: list of all columns(selected columns and their nested fields),
                 expanded data set, listed of nested fields
        """
        return columns, data, []

    @classmethod
    def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
        """Allow altering default column attributes when first detected/added

        For instance a special column like `__time` for Druid can be
        set to is_dttm=True. Note that this only gets called when new
        columns are detected/created"""
        # TODO: Fix circular import caused by importing TableColumn

    @classmethod
    def epoch_to_dttm(cls) -> str:
        """
        SQL expression that converts epoch (seconds) to datetime that can be used in a
        query. The reference column should be denoted as `{col}` in the return
        expression, e.g. "FROM_UNIXTIME({col})"

        :return: SQL Expression
        """
        raise NotImplementedError()

    @classmethod
    def epoch_ms_to_dttm(cls) -> str:
        """
        SQL expression that converts epoch (milliseconds) to datetime that can be used
        in a query.

        :return: SQL Expression
        """
        return cls.epoch_to_dttm().replace("{col}", "({col}/1000)")
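    # A hedged illustration (not part of the original class): a concrete engine spec would
    # typically supply something like
    #
    #     class MySQLLikeSpec(BaseEngineSpec):
    #         @classmethod
    #         def epoch_to_dttm(cls) -> str:
    #             return "FROM_UNIXTIME({col})"  # expression form taken from the docstring above
    #
    # epoch_ms_to_dttm() then derives "FROM_UNIXTIME(({col}/1000))" automatically.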

    @classmethod
    def get_datatype(cls, type_code: Any) -> Optional[str]:
        """
        Change column type code from cursor description to string representation.

        :param type_code: Type code from cursor description
        :return: String representation of type code
        """
        if isinstance(type_code, str) and type_code != "":
            return type_code.upper()
        return None

    @classmethod
    def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Normalizes indexes for more consistency across db engines

        noop by default

        :param indexes: Raw indexes as returned by SQLAlchemy
        :return: cleaner, more aligned index definition
        """
        return indexes

    @classmethod
    def extra_table_metadata(
        cls, database: "Database", table_name: str, schema_name: str
    ) -> Dict[str, Any]:
        """
        Returns engine-specific table metadata

        :param database: Database instance
        :param table_name: Table name
        :param schema_name: Schema name
        :return: Engine-specific table metadata
        """
        # TODO: Fix circular import caused by importing Database
        return {}

    @classmethod
    def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
        """
        Alters the SQL statement to apply a LIMIT clause

        :param sql: SQL query
        :param limit: Maximum number of rows to be returned by the query
        :param database: Database instance
        :return: SQL query with limit clause
        """
        # TODO: Fix circular import caused by importing Database
        if cls.limit_method == LimitMethod.WRAP_SQL:
            sql = sql.strip("\t\n ;")
            qry = (
                select("*")
                .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry"))
                .limit(limit)
            )
            return database.compile_sqla_query(qry)

        if cls.limit_method == LimitMethod.FORCE_LIMIT:
            parsed_query = sql_parse.ParsedQuery(sql)
            sql = parsed_query.set_or_update_query_limit(limit)

        return sql

    @classmethod
    def get_limit_from_sql(cls, sql: str) -> Optional[int]:
        """
        Extract limit from SQL query

        :param sql: SQL query
        :return: Value of limit clause in query
        """
        parsed_query = sql_parse.ParsedQuery(sql)
        return parsed_query.limit

    @classmethod
    def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
        """
        Create a query based on original query but with new limit clause

        :param sql: SQL query
        :param limit: New limit to insert/replace into query
        :return: Query with new limit
        """
        parsed_query = sql_parse.ParsedQuery(sql)
        return parsed_query.set_or_update_query_limit(limit)

    @staticmethod
    def csv_to_df(**kwargs: Any) -> pd.DataFrame:
        """Read csv into Pandas DataFrame
        :param kwargs: params to be passed to pandas.read_csv
        :return: Pandas DataFrame containing data from csv
        """
        kwargs["encoding"] = "utf-8"
        kwargs["iterator"] = True
        chunks = pd.read_csv(**kwargs)
        df = pd.concat(chunk for chunk in chunks)
        return df

    @classmethod
    def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None:
        """Upload data from a Pandas DataFrame to a database. For
        regular engines this calls the DataFrame.to_sql() method. Can be
        overridden for engines that don't work well with to_sql(), e.g.
        BigQuery.
        :param df: Dataframe with data to be uploaded
        :param kwargs: kwargs to be passed to to_sql() method
        """
        df.to_sql(**kwargs)

    @classmethod
    def create_table_from_csv(  # pylint: disable=too-many-arguments
        cls,
        filename: str,
        table: Table,
        database: "Database",
        csv_to_df_kwargs: Dict[str, Any],
        df_to_sql_kwargs: Dict[str, Any],
    ) -> None:
        """
        Create table from contents of a csv. Note: this method does not create
        metadata for the table.
        """
        df = cls.csv_to_df(filepath_or_buffer=filename, **csv_to_df_kwargs)
        engine = cls.get_engine(database)
        if table.schema:
            # only add schema when it is present and non-empty
            df_to_sql_kwargs["schema"] = table.schema
        if engine.dialect.supports_multivalues_insert:
            df_to_sql_kwargs["method"] = "multi"
        cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs)

    @classmethod
    def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
        """
        Convert Python datetime object to a SQL expression

        :param target_type: The target type of expression
        :param dttm: The datetime object
        :return: The SQL expression
        """
        return None

    @classmethod
    def create_table_from_excel(  # pylint: disable=too-many-arguments
        cls,
        filename: str,
        table: Table,
        database: "Database",
        excel_to_df_kwargs: Dict[str, Any],
        df_to_sql_kwargs: Dict[str, Any],
    ) -> None:
        """
        Create table from contents of an Excel file. Note: this method does not create
        metadata for the table.
        """
        df = pd.read_excel(io=filename, **excel_to_df_kwargs)
        engine = cls.get_engine(database)
        if table.schema:
            # only add schema when it is present and non-empty
            df_to_sql_kwargs["schema"] = table.schema
        if engine.dialect.supports_multivalues_insert:
            df_to_sql_kwargs["method"] = "multi"
        cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs)

    @classmethod
    def get_all_datasource_names(
        cls, database: "Database", datasource_type: str
    ) -> List[utils.DatasourceName]:
        """Returns a list of all tables or views in database.

        :param database: Database instance
        :param datasource_type: Datasource_type can be 'table' or 'view'
        :return: List of all datasources in database or schema
        """
        # TODO: Fix circular import caused by importing Database
        schemas = database.get_all_schema_names(
            cache=database.schema_cache_enabled,
            cache_timeout=database.schema_cache_timeout,
            force=True,
        )
        all_datasources: List[utils.DatasourceName] = []
        for schema in schemas:
            if datasource_type == "table":
                all_datasources += database.get_all_table_names_in_schema(
                    schema=schema,
                    force=True,
                    cache=database.table_cache_enabled,
                    cache_timeout=database.table_cache_timeout,
                )
            elif datasource_type == "view":
                all_datasources += database.get_all_view_names_in_schema(
                    schema=schema,
                    force=True,
                    cache=database.table_cache_enabled,
                    cache_timeout=database.table_cache_timeout,
                )
            else:
                raise Exception(f"Unsupported datasource_type: {datasource_type}")
        return all_datasources

    @classmethod
    def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
        """Handle a live cursor between the execute and fetchall calls

        The flow works without this method doing anything, but it allows
        for handling the cursor and updating progress information in the
        query object"""
        # TODO: Fix circular import error caused by importing sql_lab.Query

    @classmethod
    def extract_error_message(cls, ex: Exception) -> str:
        return f"{cls.engine} error: {cls._extract_error_message(ex)}"

    @classmethod
    def _extract_error_message(cls, ex: Exception) -> str:
        """Extract error message for queries"""
        return utils.error_msg_from_exception(ex)

    @classmethod
    def extract_errors(cls, ex: Exception) -> List[Dict[str, Any]]:
        return [
            dataclasses.asdict(
                SupersetError(
                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                    message=cls._extract_error_message(ex),
                    level=ErrorLevel.ERROR,
                    extra={"engine_name": cls.engine_name},
                )
            )
        ]

    @classmethod
    def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None:
        """
        Mutate the database component of the SQLAlchemy URI.

        The URI here represents the URI as entered when saving the database;
        ``selected_schema`` is the schema currently active, presumably in
        the SQL Lab dropdown. Based on that, for some database engines,
        we can return a new altered URI that connects straight to the
        active schema, meaning the users won't have to prefix the object
        names with the schema name.

        Some database engines have two levels of namespacing: database and
        schema (Postgres, Oracle, MSSQL, ...). For those it's probably better
        not to alter the database component of the URI with the schema name,
        as it won't work.

        Some database drivers like Presto accept '{catalog}/{schema}' in
        the database component of the URL; that can be handled here.
        """

    @classmethod
    def patch(cls) -> None:
        """
        TODO: Improve docstring and refactor implementation in Hive
        """

    @classmethod
    def get_schema_names(cls, inspector: Inspector) -> List[str]:
        """
        Get all schemas from database

        :param inspector: SqlAlchemy inspector
        :return: All schemas in the database
        """
        return sorted(inspector.get_schema_names())

    @classmethod
    def get_table_names(
        cls, database: "Database", inspector: Inspector, schema: Optional[str]
    ) -> List[str]:
        """
        Get all tables from schema

        :param inspector: SqlAlchemy inspector
        :param schema: Schema to inspect. If omitted, uses default schema for database
        :return: All tables in schema
        """
        tables = inspector.get_table_names(schema)
        if schema and cls.try_remove_schema_from_table_name:
            tables = [re.sub(f"^{schema}\\.", "", table) for table in tables]
        return sorted(tables)

    @classmethod
    def get_view_names(
        cls, database: "Database", inspector: Inspector, schema: Optional[str]
    ) -> List[str]:
        """
        Get all views from schema

        :param inspector: SqlAlchemy inspector
        :param schema: Schema name. If omitted, uses default schema for database
        :return: All views in schema
        """
        views = inspector.get_view_names(schema)
        if schema and cls.try_remove_schema_from_table_name:
            views = [re.sub(f"^{schema}\\.", "", view) for view in views]
        return sorted(views)

    @classmethod
    def get_table_comment(
        cls, inspector: Inspector, table_name: str, schema: Optional[str]
    ) -> Optional[str]:
        """
        Get comment of table from a given schema and table

        :param inspector: SqlAlchemy Inspector instance
        :param table_name: Table name
        :param schema: Schema name. If omitted, uses default schema for database
        :return: comment of table
        """
        comment = None
        try:
            comment = inspector.get_table_comment(table_name, schema)
            comment = comment.get("text") if isinstance(comment, dict) else None
        except NotImplementedError:
            # It's expected that some dialects don't implement the comment method
            pass
        except Exception as ex:  # pylint: disable=broad-except
            logger.error("Unexpected error while fetching table comment")
            logger.exception(ex)
        return comment

    @classmethod
    def get_columns(
        cls, inspector: Inspector, table_name: str, schema: Optional[str]
    ) -> List[Dict[str, Any]]:
        """
        Get all columns from a given schema and table

        :param inspector: SqlAlchemy Inspector instance
        :param table_name: Table name
        :param schema: Schema name. If omitted, uses default schema for database
        :return: All columns in table
        """
        return inspector.get_columns(table_name, schema)

    @classmethod
    def where_latest_partition(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        schema: Optional[str],
        database: "Database",
        query: Select,
        columns: Optional[List[Dict[str, str]]] = None,
    ) -> Optional[Select]:
        """
        Add a where clause to a query to reference only the most recent partition

        :param table_name: Table name
        :param schema: Schema name
        :param database: Database instance
        :param query: SqlAlchemy query
        :param columns: List of TableColumns
        :return: SqlAlchemy query with additional where clause referencing latest
        partition
        """
        # TODO: Fix circular import caused by importing Database, TableColumn
        return None

    @classmethod
    def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
        return [column(c["name"]) for c in cols]

    @classmethod
    def select_star(  # pylint: disable=too-many-arguments,too-many-locals
        cls,
        database: "Database",
        table_name: str,
        engine: Engine,
        schema: Optional[str] = None,
        limit: int = 100,
        show_cols: bool = False,
        indent: bool = True,
        latest_partition: bool = True,
        cols: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """
        Generate a "SELECT * from [schema.]table_name" query with appropriate limit.

        WARNING: expects only unquoted table and schema names.

        :param database: Database instance
        :param table_name: Table name, unquoted
        :param engine: SqlAlchemy Engine instance
        :param schema: Schema, unquoted
        :param limit: limit to impose on query
        :param show_cols: Show columns in query; otherwise use "*"
        :param indent: Add indentation to query
        :param latest_partition: Only query latest partition
        :param cols: Columns to include in query
        :return: SQL query
        """
        fields: Union[str, List[Any]] = "*"
        cols = cols or []
        if (show_cols or latest_partition) and not cols:
            cols = database.get_columns(table_name, schema)

        if show_cols:
            fields = cls._get_fields(cols)
        quote = engine.dialect.identifier_preparer.quote
        if schema:
            full_table_name = quote(schema) + "." + quote(table_name)
        else:
            full_table_name = quote(table_name)

        qry = select(fields).select_from(text(full_table_name))

        if limit:
            qry = qry.limit(limit)
        if latest_partition:
            partition_query = cls.where_latest_partition(
                table_name, schema, database, qry, columns=cols
            )
            if partition_query is not None:
                qry = partition_query
        sql = database.compile_sqla_query(qry)
        if indent:
            sql = sqlparse.format(sql, reindent=True)
        return sql
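
        # Example of the generated SQL (illustrative only; actual identifier
        # quoting and formatting depend on the dialect and on sqlparse):
        #
        #     SELECT col_a,
        #            col_b
        #     FROM my_schema.my_table
        #     LIMIT 100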

    @classmethod
    def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
        """
        Generate a SQL query that estimates the cost of a given statement.

        :param statement: A single SQL statement
        :param cursor: Cursor instance
        :return: Dictionary with different costs
        """
        raise Exception("Database does not support cost estimation")

    @classmethod
    def query_cost_formatter(
        cls, raw_cost: List[Dict[str, Any]]
    ) -> List[Dict[str, str]]:
        """
        Format cost estimate.

        :param raw_cost: Raw estimate from `estimate_query_cost`
        :return: Human readable cost estimate
        """
        raise Exception("Database does not support cost estimation")

    @classmethod
    def process_statement(
        cls, statement: str, database: "Database", user_name: str
    ) -> str:
        """
        Process a SQL statement by stripping and mutating it.

        :param statement: A single SQL statement
        :param database: Database instance
        :param user_name: Effective username
        :return: Stripped and mutated SQL statement
        """
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()
        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
        if sql_query_mutator:
            sql = sql_query_mutator(sql, user_name, security_manager, database)

        return sql

    @classmethod
    def estimate_query_cost(
        cls, database: "Database", schema: str, sql: str, source: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Estimate the cost of a multiple statement SQL query.

        :param database: Database instance
        :param schema: Database schema
        :param sql: SQL query with possibly multiple statements
        :param source: Source of the query (e.g., "sql_lab")
        """
        extra = database.get_extra() or {}
        if not cls.get_allow_cost_estimate(extra):
            raise Exception("Database does not support cost estimation")

        user_name = g.user.username if g.user else None
        parsed_query = sql_parse.ParsedQuery(sql)
        statements = parsed_query.get_statements()

        engine = cls.get_engine(database, schema=schema, source=source)
        costs = []
        with closing(engine.raw_connection()) as conn:
            cursor = conn.cursor()
            for statement in statements:
                processed_statement = cls.process_statement(
                    statement, database, user_name
                )
                costs.append(cls.estimate_statement_cost(processed_statement, cursor))
        return costs

    @classmethod
    def modify_url_for_impersonation(
        cls, url: URL, impersonate_user: bool, username: Optional[str]
    ) -> None:
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.
        :param url: SQLAlchemy URL object
        :param impersonate_user: Flag indicating if impersonation is enabled
        :param username: Effective username
        """
        if impersonate_user and username is not None:
            url.username = username

    @classmethod
    def update_impersonation_config(
        cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
    ) -> None:
        """
        Update a configuration dictionary that can set the correct properties
        for impersonating users.

        :param connect_args: config to be updated
        :param uri: URI
        :param username: Effective username
        :return: None
        """

    @classmethod
    def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None:
        """
        Execute a SQL query

        :param cursor: Cursor instance
        :param query: Query to execute
        :param kwargs: kwargs to be passed to cursor.execute()
        :return:
        """
        if not cls.allows_sql_comments:
            query = sql_parse.strip_comments_from_sql(query)

        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        try:
            cursor.execute(query)
        except Exception as ex:
            raise cls.get_dbapi_mapped_exception(ex)

    @classmethod
    def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
        """
        Conditionally mutate and/or quote a sqlalchemy expression label. If
        force_column_alias_quotes is set to True, return the label as a
        sqlalchemy.sql.elements.quoted_name object to ensure that the select query
        and query results have the same case. Otherwise return the mutated label as
        a regular string. If the maximum supported column name length is exceeded,
        generate a truncated label by calling _truncate_label().

        :param label: expected expression label/alias
        :return: conditionally mutated label supported by the db engine
        """
        label_mutated = cls._mutate_label(label)
        if (
            cls.max_column_name_length
            and len(label_mutated) > cls.max_column_name_length
        ):
            label_mutated = cls._truncate_label(label)
        if cls.force_column_alias_quotes:
            label_mutated = quoted_name(label_mutated, True)
        return label_mutated

    @classmethod
    def get_sqla_column_type(
        cls,
        column_type: Optional[str],
        column_type_mappings: Tuple[
            Tuple[
                Pattern[str],
                Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
                GenericDataType,
            ],
            ...,
        ] = column_type_mappings,
    ) -> Union[Tuple[TypeEngine, GenericDataType], None]:
        """
        Return a sqlalchemy native column type that corresponds to the column type
        defined in the data source (return None to use default type inferred by
        SQLAlchemy). Override `column_type_mappings` for specific needs
        (see MSSQL for example of NCHAR/NVARCHAR handling).

        :param column_type: Column type returned by inspector
        :return: Tuple of SqlAlchemy column type and generic type, or None if no match
        """
        if not column_type:
            return None
        for regex, sqla_type, generic_type in column_type_mappings:
            match = regex.match(column_type)
            if match:
                if callable(sqla_type):
                    return sqla_type(match), generic_type
                return sqla_type, generic_type
        return None

    @staticmethod
    def _mutate_label(label: str) -> str:
        """
        Most engines support mixed case aliases that can include numbers
        and special characters, like commas, parentheses etc. For engines that
        have restrictions on what types of aliases are supported, this method
        can be overridden to ensure that labels conform to the engine's
        limitations. Mutated labels should be deterministic (input label A always
        yields output label X) and unique (input labels A and B don't yield the same
        output label X).

        :param label: Preferred expression label
        :return: Conditionally mutated label
        """
        return label

    @classmethod
    def _truncate_label(cls, label: str) -> str:
        """
        In the case that a label exceeds the max length supported by the engine,
        this method is used to construct a deterministic and unique label based on
        the original label. By default this returns an md5 hash of the original label,
        conditionally truncated if the length of the hash exceeds the max column length
        of the engine.

        :param label: Expected expression label
        :return: Truncated label
        """
        label = hashlib.md5(label.encode("utf-8")).hexdigest()
        # truncate hash if it exceeds max length
        if cls.max_column_name_length and len(label) > cls.max_column_name_length:
            label = label[: cls.max_column_name_length]
        return label

    @classmethod
    def column_datatype_to_string(
        cls, sqla_column_type: TypeEngine, dialect: Dialect
    ) -> str:
        """
        Convert sqlalchemy column type to string representation.
        By default removes collation and character encoding info to avoid unnecessarily
        long datatypes.

        :param sqla_column_type: SqlAlchemy column type
        :param dialect: Sqlalchemy dialect
        :return: Compiled column type
        """
        sqla_column_type = sqla_column_type.copy()
        if hasattr(sqla_column_type, "collation"):
            sqla_column_type.collation = None
        if hasattr(sqla_column_type, "charset"):
            sqla_column_type.charset = None
        return sqla_column_type.compile(dialect=dialect).upper()

    @classmethod
    def get_function_names(cls, database: "Database") -> List[str]:
        """
        Get a list of function names that are able to be called on the database.
        Used for SQL Lab autocomplete.

        :param database: The database to get functions for
        :return: A list of function names useable in the database
        """
        return []

    @staticmethod
    def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
        """
        Convert pyodbc.Row objects from `fetch_data` to tuples.

        :param data: List of tuples or pyodbc.Row objects
        :return: List of tuples
        """
        if data and type(data[0]).__name__ == "Row":
            data = [tuple(row) for row in data]
        return data

    @staticmethod
    def mutate_db_for_connection_test(database: "Database") -> None:
        """
        Some databases require passing additional parameters for validating database
        connections. This method makes it possible to mutate the database instance prior
        to testing if a connection is ok.

        :param database: instance to be mutated
        """
        return None

    @staticmethod
    def get_extra_params(database: "Database") -> Dict[str, Any]:
        """
        Some databases require adding elements to connection parameters,
        like passing certificates to `extra`. This can be done here.

        :param database: database instance from which to extract extras
        :raises CertificateException: If certificate is not valid/unparseable
        """
        extra: Dict[str, Any] = {}
        if database.extra:
            try:
                extra = json.loads(database.extra)
            except json.JSONDecodeError as ex:
                logger.error(ex)
                raise ex
        return extra

    @classmethod
    def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
        """Pessimistic readonly, 100% sure statement won't mutate anything"""
        return (
            parsed_query.is_select()
            or parsed_query.is_explain()
            or parsed_query.is_show()
        )
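
        # Illustrative behaviour, assuming sql_parse.ParsedQuery as above:
        #   is_readonly_query(ParsedQuery("SELECT 1"))          -> True
        #   is_readonly_query(ParsedQuery("EXPLAIN SELECT 1"))  -> True
        #   is_readonly_query(ParsedQuery("DROP TABLE foo"))    -> False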

    @classmethod
    def get_column_spec(
        cls,
        native_type: Optional[str],
        source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
        column_type_mappings: Tuple[
            Tuple[
                Pattern[str],
                Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
                GenericDataType,
            ],
            ...,
        ] = column_type_mappings,
    ) -> Union[ColumnSpec, None]:
        """
        Converts native database type to sqlalchemy column type.
        :param native_type: Native database type
        :param source: Type coming from the database table or cursor description
        :return: ColumnSpec object
        """
        col_types = cls.get_sqla_column_type(
            native_type, column_type_mappings=column_type_mappings
        )
        if col_types:
            column_type, generic_type = col_types
            # wrap temporal types in custom type that supports literal binding
            # using datetimes
            if generic_type == GenericDataType.TEMPORAL:
                column_type = literal_dttm_type_factory(
                    type(column_type), cls, native_type or ""
                )
            is_dttm = generic_type == GenericDataType.TEMPORAL
            return ColumnSpec(
                sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm
            )
        return None
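
As a hedged illustration of how get_sqla_column_type and get_column_spec consume column_type_mappings, the sketch below builds a small custom mapping and resolves a native type string. MyEngineSpec, the regexes, and the sample type strings are assumptions made for this example; GenericDataType is assumed to be the enum referenced in the snippet above.

import re

from sqlalchemy import types

# Hedged sketch: MyEngineSpec is assumed to subclass the engine spec shown
# above, and GenericDataType to be importable from the surrounding module.
custom_column_type_mappings = (
    (
        re.compile(r"^varchar\((\d+)\)$", re.IGNORECASE),
        # factory: build the SQLAlchemy type from the regex match
        lambda match: types.VARCHAR(int(match.group(1))),
        GenericDataType.STRING,
    ),
    (
        re.compile(r"^timestamp", re.IGNORECASE),
        types.TIMESTAMP(),
        GenericDataType.TEMPORAL,
    ),
)

spec = MyEngineSpec.get_column_spec(
    "VARCHAR(255)",
    column_type_mappings=custom_column_type_mappings,
)
# spec.sqla_type    -> VARCHAR(255) instance
# spec.generic_type -> GenericDataType.STRING
# spec.is_dttm      -> False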
Exemple #10
0
def init_schema(url):
    """
    Set up the songs database connection and initialise the database schema.

    :param url: The database to set up

    The song database contains the following tables:

        * authors
        * authors_songs
        * media_files
        * media_files_songs
        * song_books
        * songs
        * songs_topics
        * topics

    **authors** Table
        This table holds the names of all the authors. It has the following
        columns:

        * id
        * first_name
        * last_name
        * display_name

    **authors_songs Table**
        This is a bridging table between the *authors* and *songs* tables, which
        serves to create a many-to-many relationship between the two tables. It
        has the following columns:

        * author_id
        * song_id
        * author_type

    **media_files Table**
        * id
        * file_name
        * type

    **song_books Table**
        The *song_books* table holds a list of books that a congregation gets
        their songs from, or old hymnals now no longer used. This table has the
        following columns:

        * id
        * name
        * publisher

    **songs Table**
        This table contains the songs, and each song has a list of attributes.
        The *songs* table has the following columns:

        * id
        * song_book_id
        * title
        * alternate_title
        * song_key
        * transpose_by
        * lyrics
        * verse_order
        * copyright
        * comments
        * ccli_number
        * song_number
        * theme_name
        * search_title
        * search_lyrics

    **songs_topics Table**
        This is a bridging table between the *songs* and *topics* tables, which
        serves to create a many-to-many relationship between the two tables. It
        has the following columns:

        * song_id
        * topic_id

    **topics Table**
        The topics table holds a selection of topics that songs can cover. This
        is useful when a worship leader wants to select songs with a certain
        theme. This table has the following columns:

        * id
        * name
    """
    session, metadata = init_db(url)

    # Definition of the "authors" table
    authors_table = Table(
        'authors', metadata,
        Column('id', types.Integer(), primary_key=True),
        Column('first_name', types.Unicode(128)),
        Column('last_name', types.Unicode(128)),
        Column('display_name', types.Unicode(255), index=True, nullable=False)
    )

    # Definition of the "media_files" table
    media_files_table = Table(
        'media_files', metadata,
        Column('id', types.Integer(), primary_key=True),
        Column('song_id', types.Integer(), ForeignKey('songs.id'), default=None),
        Column('file_name', types.Unicode(255), nullable=False),
        Column('type', types.Unicode(64), nullable=False, default='audio'),
        Column('weight', types.Integer(), default=0)
    )

    # Definition of the "song_books" table
    song_books_table = Table(
        'song_books', metadata,
        Column('id', types.Integer(), primary_key=True),
        Column('name', types.Unicode(128), nullable=False),
        Column('publisher', types.Unicode(128))
    )

    # Definition of the "songs" table
    songs_table = Table(
        'songs', metadata,
        Column('id', types.Integer(), primary_key=True),
        Column('song_book_id', types.Integer(), ForeignKey('song_books.id'), default=None),
        Column('title', types.Unicode(255), nullable=False),
        Column('alternate_title', types.Unicode(255)),
        Column('song_key', types.Unicode(3)),
        Column('transpose_by', types.Integer(), default=0),
        Column('chords', types.UnicodeText),
        Column('lyrics', types.UnicodeText, nullable=False),
        Column('verse_order', types.Unicode(128)),
        Column('copyright', types.Unicode(255)),
        Column('comments', types.UnicodeText),
        Column('ccli_number', types.Unicode(64)),
        Column('song_number', types.Unicode(64)),
        Column('theme_name', types.Unicode(128)),
        Column('search_title', types.Unicode(255), index=True, nullable=False),
        Column('search_lyrics', types.UnicodeText, nullable=False),
        Column('create_date', types.DateTime(), default=func.now()),
        Column('last_modified', types.DateTime(), default=func.now(), onupdate=func.now()),
        Column('temporary', types.Boolean(), default=False)
    )

    # Definition of the "topics" table
    topics_table = Table(
        'topics', metadata,
        Column('id', types.Integer(), primary_key=True),
        Column('name', types.Unicode(128), index=True, nullable=False)
    )

    # Definition of the "authors_songs" table
    authors_songs_table = Table(
        'authors_songs', metadata,
        Column('author_id', types.Integer(), ForeignKey('authors.id'), primary_key=True),
        Column('song_id', types.Integer(), ForeignKey('songs.id'), primary_key=True),
        Column('author_type', types.Unicode(255), primary_key=True, nullable=False, server_default=text('""'))
    )

    # Definition of the "songs_topics" table
    songs_topics_table = Table(
        'songs_topics', metadata,
        Column('song_id', types.Integer(), ForeignKey('songs.id'), primary_key=True),
        Column('topic_id', types.Integer(), ForeignKey('topics.id'), primary_key=True)
    )

    mapper(Author, authors_table, properties={
        'songs': relation(Song, secondary=authors_songs_table, viewonly=True)
    })
    mapper(AuthorSong, authors_songs_table, properties={
        'author': relation(Author)
    })
    mapper(Book, song_books_table)
    mapper(MediaFile, media_files_table)
    mapper(Song, songs_table, properties={
        # Use the authors_songs relation when you need access to the 'author_type' attribute
        # or when creating new relations
        'authors_songs': relation(AuthorSong, cascade="all, delete-orphan"),
        # Use lazy='joined' to always load authors when the song is fetched from the database (bug 1366198)
        'authors': relation(Author, secondary=authors_songs_table, viewonly=True, lazy='joined'),
        'book': relation(Book, backref='songs'),
        'media_files': relation(MediaFile, backref='songs', order_by=media_files_table.c.weight),
        'topics': relation(Topic, backref='songs', secondary=songs_topics_table)
    })
    mapper(Topic, topics_table)

    metadata.create_all(checkfirst=True)
    return session
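
A minimal usage sketch for init_schema(), assuming the Song mapped class from the surrounding module and a local SQLite URL; the field values are placeholders only.

# Hedged usage sketch: Song comes from the surrounding module; the URL and
# the field values are illustrative.
session = init_schema('sqlite:///songs.sqlite')

song = Song()
song.title = 'Amazing Grace'
song.search_title = 'amazing grace'
song.lyrics = 'Amazing grace, how sweet the sound'
song.search_lyrics = 'amazing grace how sweet the sound'
session.add(song)
session.commit()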
Exemple #11
0
    Column('serial', types.Integer, Sequence("serial_seq")),
    Column('name', types.String(60), nullable=False, default=name_default),
    Column('t_list',
           types.ARRAY(types.String(60)),
           nullable=False,
           default=t_list_default),
    Column('t_enum',
           types.Enum(MyEnum),
           nullable=False,
           default=t_enum_default),
    Column('t_int_enum',
           types.Enum(MyIntEnum),
           nullable=False,
           default=t_int_enum_default),
    Column('t_datetime',
           types.DateTime(),
           nullable=False,
           default=t_datetime_default),
    Column('t_date', types.DateTime(), nullable=False, default=t_date_default),
    Column('t_date_2',
           types.DateTime(),
           nullable=False,
           default=t_date_2_default),
    Column('t_interval',
           types.Interval(),
           nullable=False,
           default=t_interval_default),
    Column('t_boolean', types.Boolean(), nullable=False, default=True),
    Column('version', PG_UUID, default=uuid4, onupdate=uuid4))
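
The fragment above relies on Python-side column defaults: constants (such as True), callables (such as uuid4), and the *_default helpers defined elsewhere in that module. As a hedged, self-contained reminder of how such defaults behave, the sketch below shows a constant and a callable default being applied on INSERT; the table name and values are illustrative.

from uuid import uuid4

from sqlalchemy import Column, MetaData, Table, create_engine, types

metadata = MetaData()
defaults_demo = Table(
    'defaults_demo', metadata,
    Column('id', types.Integer, primary_key=True),
    # constant default: used for every INSERT that omits the column
    Column('name', types.String(60), nullable=False, default='unnamed'),
    # callable default: invoked once per INSERT that omits the column
    Column('version', types.String(36), default=lambda: str(uuid4())),
)

engine = create_engine('sqlite://')
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(defaults_demo.insert(), {'name': 'first row'})
    row = conn.execute(defaults_demo.select()).first()
    print(row.name, row.version)  # 'first row', generated UUID string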

Exemple #12
0
class BearerToken(Base):
    __tablename__ = 'oauth2_bearer_token'

    access_token = Column(types.Unicode(100), primary_key=True)
    client_pk = Column(GUID,
                       ForeignKey(Client.pk,
                                  onupdate='CASCADE',
                                  ondelete='CASCADE'),
                       nullable=False)
    user_pk = Column(GUID,
                     ForeignKey(User.pk,
                                onupdate='CASCADE',
                                ondelete='CASCADE'),
                     nullable=False)
    scopes = Column(postgresql.ARRAY(types.Unicode(80)), nullable=False)
    expires_at = Column(types.DateTime(timezone=True), nullable=True)
    refresh_token = Column(types.Unicode(100), nullable=True, unique=True)

    client = orm.relationship(Client)
    user = orm.relationship(User)

    @property
    def client_id(self):
        return self.client_pk.hex

    @hybrid_property
    def expires(self):
        """

        .. seealso:: https://flask-oauthlib.readthedocs.org/en/latest/api.html#flask_oauthlib.provider.OAuth2Provider.tokengetter

        """
        return self.expires_at.astimezone(tzutc()).replace(tzinfo=None)

    def __init__(self,
                 client,
                 user,
                 access_token,
                 scopes,
                 expires_in,
                 refresh_token=None):
        """

        :param client:
        :type client: Client
        :param user:
        :type user: User
        :param access_token:
        :type access_token: basestring
        :keyword scopes:
        :type scopes: collections.Iterable
        :keyword expires_in:
        :type expires_in: datetime.timedelta

        """
        if not isinstance(expires_in, timedelta):
            expires_in = timedelta(seconds=expires_in)
        self.client = client
        self.user = user
        self.access_token = access_token
        self.scopes = scopes
        self.expires_at = datetime.now(tzutc()) + expires_in
        self.refresh_token = refresh_token
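
A brief hedged usage sketch for the constructor above; client and user are assumed to be already-persisted Client and User instances, and the token strings are placeholders.

token = BearerToken(
    client=client,
    user=user,
    access_token=u'0123456789abcdef',
    scopes=[u'read', u'write'],
    expires_in=3600,              # plain seconds are converted to a timedelta
    refresh_token=u'fedcba9876543210',
)
# token.expires_at is timezone-aware (UTC); token.expires strips the tzinfo
# for flask-oauthlib, per the hybrid_property above.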
Exemple #13
0
class User(db.Model, UserMixin, DeploymentMixin):  #pylint: disable=no-init,too-few-public-methods
    '''
    User
    '''
    __tablename__ = 'users'

    id = Column(types.Integer, autoincrement=True, primary_key=True)  #pylint: disable=invalid-name

    picture_id = Column(
        types.String,
        default=lambda: base64.urlsafe_b64encode(os.urandom(20))[0:-2])

    has_picture = Column(types.Boolean, default=False)

    first_name = Column(types.String,
                        info={
                            'label': lazy_gettext('First Name'),
                        })
    last_name = Column(types.String,
                       info={
                           'label': lazy_gettext('Last Name'),
                       })

    email = Column(EmailType,
                   nullable=False,
                   info={
                       'label': lazy_gettext('Email'),
                   })

    password = Column(types.String,
                      nullable=False,
                      info={
                          'label': lazy_gettext('Password'),
                      })
    active = Column(types.Boolean, nullable=False)

    last_login_at = Column(types.DateTime())
    current_login_at = Column(types.DateTime())
    confirmed_at = Column(types.DateTime())
    last_login_ip = Column(types.Text)
    current_login_ip = Column(types.Text)
    login_count = Column(types.Integer)

    position = Column(types.String, info={
        'label': lazy_gettext('Position'),
    })
    organization = Column(types.String,
                          info={
                              'label': lazy_gettext('Organization'),
                          })
    organization_type = Column(
        types.String,
        info={
            'label': lazy_gettext('Type of Organization'),
            'description':
            lazy_gettext('The type of organization you work for'),
            'choices': [(k, v) for k, v in ORG_TYPES.iteritems()]
        })
    country = Column(CountryType, info={
        'label': lazy_gettext('Country'),
    })

    city = Column(types.String, info={'label': lazy_gettext('City')})

    latlng = Column(types.String,
                    info={
                        'label': lazy_gettext('Location'),
                        'description': lazy_gettext('Enter your location')
                    })

    projects = Column(
        types.Text,
        info={
            'label':
            lazy_gettext('Projects'),
            'description':
            lazy_gettext(
                'Add name and url or short description of any current work projects.'
            )
        })

    tutorial_step = Column(types.Integer())

    created_at = Column(types.DateTime(), default=datetime.datetime.now)
    updated_at = Column(types.DateTime(),
                        default=datetime.datetime.now,
                        onupdate=datetime.datetime.now)

    def is_admin(self):
        return self.email in current_app.config.get('ADMIN_UI_USERS', [])

    @property
    def full_name(self):
        return u"%s %s" % (self.first_name, self.last_name)

    @property
    def full_location(self):
        loc = []
        if self.city:
            loc.append(self.city)
        if self.country and self.country.code != 'ZZ':
            loc.append(self.country.name)
        return ", ".join(loc)

    @property
    def display_in_search(self):
        '''
        Determine whether the user has filled out the bare minimum to be
        displayed in search results.

        Specifically, we want to make sure that the first and last name
        are both non-NULL and non-blank.
        '''
        return bool(self.first_name and self.last_name)

    @property
    def picture_path(self):
        '''
        Path where picture would be found (hosted on S3).
        '''
        return "{}/static/pictures/{}/{}".format(
            current_app.config['NOI_DEPLOY'], self.id, self.picture_id)

    @property
    def picture_url(self):
        '''
        Full path to picture.
        '''
        return 'https://s3.amazonaws.com/{bucket}/{path}'.format(
            bucket=current_app.config['S3_BUCKET_NAME'],
            path=self.picture_path)

    def remove_picture(self):
        conn = S3Connection(current_app.config['S3_ACCESS_KEY_ID'],
                            current_app.config['S3_SECRET_ACCESS_KEY'])
        bucket = conn.get_bucket(current_app.config['S3_BUCKET_NAME'])

        if bucket.get_key(self.picture_path):
            bucket.delete_key(self.picture_path)

        self.has_picture = False

    def upload_picture(self, fileobj, mimetype):
        '''
        Upload the given file object with the given mime type to S3 and
        mark the user as having a picture.
        '''

        conn = S3Connection(current_app.config['S3_ACCESS_KEY_ID'],
                            current_app.config['S3_SECRET_ACCESS_KEY'])
        bucket = conn.get_bucket(current_app.config['S3_BUCKET_NAME'])
        bucket.make_public(recursive=False)

        k = bucket.new_key(self.picture_path)
        k.set_metadata('Content-Type', mimetype)
        k.set_contents_from_file(fileobj)
        k.make_public()

        self.has_picture = True

    @property
    def helpful_users(self, limit=10):
        '''
        Returns a list of (user, score) tuples with matching positive skills,
        ordered by the most helpful (highest score) descending.
        '''
        my_skills = aliased(UserSkill, name='my_skills', adapt_on_names=True)
        their_skills = aliased(UserSkill,
                               name='their_skills',
                               adapt_on_names=True)

        return User.query_in_deployment().\
                add_column(func.sum(their_skills.level - my_skills.level)).\
                filter(their_skills.user_id != my_skills.user_id).\
                filter(User.id == their_skills.user_id).\
                filter(their_skills.name == my_skills.name).\
                filter(my_skills.user_id == self.id).\
                filter(my_skills.level == LEVELS['LEVEL_I_WANT_TO_LEARN']['score']).\
                group_by(User).\
                order_by(desc(func.sum(their_skills.level - my_skills.level))).\
                limit(limit)

    @property
    def nearest_neighbors(self, limit=10):
        '''
        Returns a list of (user, score) tuples with the closest matching
        skills.  If they haven't answered the equivalent skill question, we
        consider that a very big difference (12).

        Order is closest to least close, which is an ascending score.
        '''
        my_skills = aliased(UserSkill, name='my_skills', adapt_on_names=True)
        their_skills = aliased(UserSkill,
                               name='their_skills',
                               adapt_on_names=True)

        # difference we assume for user that has not answered question
        unanswered_difference = (LEVELS['LEVEL_I_CAN_DO_IT']['score'] -
                                 LEVELS['LEVEL_I_WANT_TO_LEARN']['score']) * 2

        return User.query_in_deployment().\
                add_column(((len(self.skills) - func.count(func.distinct(their_skills.id))) *
                            unanswered_difference) + \
                       func.sum(func.abs(their_skills.level - my_skills.level))).\
                filter(their_skills.user_id != my_skills.user_id).\
                filter(User.id == their_skills.user_id).\
                filter(their_skills.name == my_skills.name).\
                filter(my_skills.user_id == self.id).\
                group_by(User).\
                order_by(((len(self.skills) - func.count(func.distinct(their_skills.id)))
                          * unanswered_difference) + \
                     func.sum(func.abs(their_skills.level - my_skills.level))).\
                limit(limit)

    @property
    def has_fully_registered(self):
        '''
        Returns whether the user has fully completed the registration/signup
        flow.
        '''

        return db.session.query(UserJoinedEvent).\
               filter_by(user_id=self.id).\
               first() is not None

    def set_fully_registered(self):
        '''
        Marks the user as having fully completed the registration/signup
        flow, if they haven't already.
        '''

        if self.has_fully_registered:
            return
        db.session.add(UserJoinedEvent.from_user(self))

    def match(self, level, limit=10):
        '''
        Returns a list of UserSkillMatch objects, in descending order of number
        of skills matched for each user.
        '''

        skills_to_learn = [
            s.name for s in self.skills
            if s.level == LEVELS['LEVEL_I_WANT_TO_LEARN']['score']
        ]
        if skills_to_learn:
            matched_users = User.query_in_deployment().\
                            add_column(func.string_agg(UserSkill.name, ',')).\
                            add_column(func.count(UserSkill.id)).\
                            filter(UserSkill.name.in_(skills_to_learn)).\
                            filter(User.id == UserSkill.user_id).\
                            filter(UserSkill.level == level).\
                            filter(UserSkill.user_id != self.id).\
                            group_by(User).\
                            order_by(func.count().desc()).\
                            limit(limit)
        else:
            matched_users = []

        for user, question_ids_by_comma, count in matched_users:
            yield UserSkillMatch(user, question_ids_by_comma.split(','))

    def match_against(self, user):
        '''
        Returns a list of three-tuples in the format:

        (<questionnaire id>, <count of matching questions>, <skill dict>, )

        In descending order of <count of matching questions>.

        <skill dict> is keyed by the skill level of the other user, with each
        value being a set of questions they can answer at that level.
        '''
        skills = UserSkill.query.\
                filter(UserSkill.user_id == user.id).\
                filter(UserSkill.name.in_(
                    [s.name for s in
                     self.skills if s.level == LEVELS['LEVEL_I_WANT_TO_LEARN']['score']
                    ])).all()

        resp = {}
        for skill in skills:
            question = QUESTIONS_BY_ID[skill.name]
            questionnaire_id = question['questionnaire']['id']
            if questionnaire_id not in resp:
                resp[questionnaire_id] = dict()

            if skill.level not in resp[questionnaire_id]:
                resp[questionnaire_id][skill.level] = set()
            resp[questionnaire_id][skill.level].add(skill.name)

        resp = [(qname,
                 sum([len(questions)
                      for questions in skill_levels.values()]), skill_levels)
                for qname, skill_levels in resp.items()]

        return sorted(resp, lambda a, b: a[1] - b[1], reverse=True)

    def match_against_with_progress(self, user):
        '''
        Like match_against(), but also includes information about
        areas of expertise the target user has that we don't match
        on.
        '''

        progress = user.questionnaire_progress
        matches = self.match_against(user)
        match_areas = {}

        for questionnaire_id, _, _ in matches:
            match_areas[questionnaire_id] = True

        for questionnaire_id, progress in progress.items():
            if (questionnaire_id not in match_areas
                    and progress['answered'] > 0):
                matches.append((questionnaire_id, 0, {}))

        return matches

    def match_against_with_progress_in_area(self, user, areaid):
        '''
        Return a tuple of (matched_skill_dict, unmatched_skill_dict)
        for the given area.

        matched_skill_dict contains information about skills that
        the target user has which we want to learn, while
        unmatched_skill_dict contains all other skills the target
        user has.

        Each dict is keyed by the skill level of the other user,
        with each value being a set of questions they can answer at that
        level.
        '''

        questionnaire = QUESTIONNAIRES_BY_ID[areaid]
        matches = self.match_against(user)

        matched_skill_dict = {}
        matched_skills = {}

        for questionnaire_id, _, skill_dict in matches:
            if questionnaire_id == areaid:
                matched_skill_dict = skill_dict
                break

        for question_ids in matched_skill_dict.values():
            for question_id in question_ids:
                matched_skills[question_id] = True

        unmatched_skill_dict = {}
        skill_levels = user.skill_levels
        for topic in questionnaire.get('topics', []):
            for question in topic['questions']:
                qid = question['id']
                if (qid in skill_levels and qid not in matched_skills):
                    level = skill_levels[qid]
                    if level not in unmatched_skill_dict:
                        unmatched_skill_dict[level] = []
                    unmatched_skill_dict[level].append(qid)

        return (matched_skill_dict, unmatched_skill_dict)

    @property
    def questionnaire_progress(self):
        '''
        Return a dictionary mapping top-level skill area IDs (e.g.,
        'opendata', 'prizes') to information about how many questions
        the user has answered in that skill area.
        '''

        skill_levels = self.skill_levels
        progress = {}
        for questionnaire in QUESTIONNAIRES:
            topic_progress = {'answered': 0, 'total': 0}
            progress[questionnaire['id']] = topic_progress
            for topic in questionnaire.get('topics', []):
                for question in topic['questions']:
                    topic_progress['total'] += 1
                    if question['id'] in skill_levels:
                        topic_progress['answered'] += 1
        return progress

    @property
    def skill_levels(self):
        '''
        Dictionary of this user's entered skills, keyed by the id of the skill.
        '''
        return dict([(skill.name, skill.level) for skill in self.skills])

    @property
    def connections(self):
        '''
        Count the number of distinct email addresses this person has sent or
        received messages from in the deployment.
        '''
        sent = db.session.query(func.count(func.distinct(Email.to_user_id))).\
                filter(Email.to_user_id != self.id).\
                filter(Email.from_user_id == self.id).first()[0]
        received = db.session.query(func.count(func.distinct(Email.from_user_id))).\
                filter(Email.from_user_id != self.id).\
                filter(Email.to_user_id == self.id).first()[0]
        return sent + received

    def set_skill(self, skill_name, skill_level):
        '''
        Set the level of a single skill by name.
        '''
        if skill_name not in QUESTIONS_BY_ID:
            return
        try:
            if int(skill_level) not in VALID_SKILL_LEVELS:
                return
        except ValueError:
            return
        for skill in self.skills:
            if skill_name == skill.name:
                skill.level = skill_level
                db.session.add(skill)
                return
        db.session.add(
            UserSkill(user_id=self.id, name=skill_name, level=skill_level))

    def email_connect(self, users):
        '''
        Indicate that this user has opened an email window with this list of
        users as recipients.
        '''
        event = ConnectionEvent()
        for user in users:
            event.emails.append(Email(from_user_id=self.id,
                                      to_user_id=user.id))

        db.session.add(event)
        return event

    roles = orm.relationship('Role',
                             secondary='role_users',
                             backref=orm.backref('users', lazy='dynamic'))

    expertise_domains = orm.relationship('UserExpertiseDomain',
                                         cascade='all,delete-orphan',
                                         backref='user')
    languages = orm.relationship('UserLanguage',
                                 cascade='all,delete-orphan',
                                 backref='user')
    skills = orm.relationship('UserSkill',
                              cascade='all,delete-orphan',
                              backref='user')

    @classmethod
    def get_most_complete_profiles(cls, limit=10):
        '''
        Obtain a list of most complete profiles, as (User, score) tuples.
        '''
        return User.query_in_deployment().\
                add_column(func.count(UserSkill.id)).\
                filter(User.id == UserSkill.user_id).\
                group_by(User).\
                order_by(func.count(UserSkill.id).desc()).\
                limit(limit)

    @classmethod
    def get_most_connected_profiles(cls, limit=10):
        '''
        Obtain a list of most connected profiles, as descending (User, score)
        tuples.
        '''
        count_of_unique_emails = func.count(
            func.distinct(
                cast(Email.to_user_id, String) + '-' +
                cast(Email.from_user_id, String)))
        return User.query_in_deployment().\
                add_column(count_of_unique_emails).\
                filter((User.id == Email.from_user_id) | (User.id == Email.to_user_id)).\
                group_by(User).\
                order_by(count_of_unique_emails.desc()).\
                limit(limit)

    @hybrid_property
    def expertise_domain_names(self):
        '''
        Convenient list of expertise domains by name.
        '''
        return [ed.name for ed in self.expertise_domains]

    @expertise_domain_names.setter
    def _expertise_domains_setter(self, values):
        '''
        Update expertise domains in bulk. Values are an array of names.
        '''
        # Only add new expertise
        for val in values:
            if val not in self.expertise_domain_names:
                db.session.add(UserExpertiseDomain(name=val, user_id=self.id))
        # delete expertise no longer found
        expertise_to_remove = []
        for exp in self.expertise_domains:
            if exp.name not in values:
                expertise_to_remove.append(exp)

        for exp in expertise_to_remove:
            self.expertise_domains.remove(exp)

    def get_area_scores(self):
        skill_levels = self.skill_levels
        result = {}

        NO_ANSWER = -999

        for questionnaire in QUESTIONNAIRES:
            if not questionnaire['questions']:
                continue
            max_score = NO_ANSWER
            answers_with_score = {}
            for question in questionnaire['questions']:
                if question['id'] in skill_levels:
                    score = skill_levels[question['id']]
                    if score > max_score:
                        max_score = score
                    if score not in answers_with_score:
                        answers_with_score[score] = 0
                    answers_with_score[score] += 1
            reported_max_score = None
            if max_score != NO_ANSWER:
                reported_max_score = max_score
            result[questionnaire['id']] = {
                'skills': scores_to_skills(answers_with_score),
                'max_score': reported_max_score
            }

        return result

    @hybrid_property
    def locales(self):
        '''
        Convenient list of locales for this user.
        '''
        return [l.locale for l in self.languages]

    @locales.setter
    def _languages_setter(self, values):
        '''
        Update locales for this user in bulk.  Values are an array of language
        codes.
        '''
        locale_codes = [l.language for l in self.locales]
        # only add new languages
        for val in values:
            if val not in locale_codes:
                db.session.add(UserLanguage(locale=val, user_id=self.id))

        # delete languages no longer found
        languages_to_remove = []
        for lan in self.languages:
            if lan.locale.language not in values:
                languages_to_remove.append(lan)

        for lan in languages_to_remove:
            self.languages.remove(lan)

    __table_args__ = (UniqueConstraint('deployment', 'email'), )
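
A small hedged usage sketch for the skills API above; user is assumed to be a persisted User instance, and the skill id is assumed to be a key in QUESTIONS_BY_ID.

# Hedged usage sketch; the skill id is an assumption for illustration.
user.set_skill('opendata-open-formats', LEVELS['LEVEL_I_CAN_DO_IT']['score'])
db.session.commit()

print(user.skill_levels)            # {'opendata-open-formats': <level>, ...}
print(user.questionnaire_progress)  # per-questionnaire answered/total counts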
Exemple #14
0
df['HOSPITAL'] = df['HP_ID'] + " " + df['HP_NAME']
df['STORE'] = df['STORE_ID'] + " " + df['STORE_NAME']
df['RM_POS_NAME'] = df['RM'] + " " + df['RM_NAME']
df['DSM_POS_NAME'] = df['DSM'] + " " + df['DSM_NAME']
df['RSP_POS_NAME'] = df['RSP'] + " " + df['RSP_NAME']
print(df)

print("start importing...")
df.to_sql(
    "data",
    con=engine,
    if_exists="replace",
    index=False,
    dtype={
        "YEAR": t.INTEGER(),
        "DATE": t.DateTime(),
        "MONTH": t.INTEGER(),
        "QUARTER": t.INTEGER(),
        "HP_ID": t.NVARCHAR(length=10),
        "HP_NAME": t.NVARCHAR(length=100),
        "HOSPITAL": t.NVARCHAR(length=110),
        "STORE_ID": t.NVARCHAR(length=10),
        "STORE_NAME": t.NVARCHAR(length=100),
        "STORE": t.NVARCHAR(length=110),
        "PROVINCE": t.NVARCHAR(length=3),
        "CITY": t.NVARCHAR(length=30),
        "COUNTY": t.NVARCHAR(length=30),
        "LEVEL": t.NVARCHAR(length=4),
        "IF_COMMUNITY": t.Boolean(),
        "IF_DUALCALL": t.Boolean(),
        "PRODUCT": t.NVARCHAR(length=10),
Exemple #15
0
    def create(
        cls,
        user_version,
        user_version_nullable=False,
        table_name='alembic_version_history',
        metadata=None,
        extra_columns=(),
        user_version_column_name='user_version',
        user_version_type=types.String(255),
        direction_column_name='operation_direction',
        operation_column_name='operation_type',
        alembic_version_separator='##',
        alembic_version_column_name='alembic_version',
        prev_alembic_version_column_name='prev_alembic_version',
        change_time_column_name='changed_at',
    ):
        """Autocreate a history table.

        This table contains columns for:

        * user version
        * alembic version(s) prior to upgrade
        * alembic version(s) after upgrade
        * :paramref:`operation type <.Auditor.create.operation_column_name>`
        * :paramref:`operation direction <.Auditor.create.direction_column_name>`
        * upgrade time

        The user may add their own columns and, to some extent, customize
        those provided. See the parameter list for details.

        :param user_version: a constant value or callable giving the user
            version to be stored with each migration step. If callable, it
            accepts the kwargs provided by alembic's on_version_apply
            callback. A good value for this might be your application's git
            version or current version tag.

        .. note::

            :paramref:`.Auditor.create.user_version` does not have to be
            provided. It can be null. It is intended to tie an alembic version
            to a specific point in your version control, so that you may
            consult that history and know exactly the content of the patch
            that was executed. Failing to include this information may
            seriously dilute the value of keeping these records at all. For
            that reason, we highly recommend providing *something* here.

            If you pass ``None``, a warning will be raised.

        :param user_version_nullable: Suppresses the above-mentioned warning.
        :param table_name: The name of the version history table.
        :param metadata: The SQLAlchemy MetaData object with which the table
            is to be created. If not provided, a new one will be created.
        :param extra_columns: A sequence of extra columns to add to the
            default table. Each element of the sequence is a 2-tuple
            ``(col, val)`` where ``col`` is a SQLAlchemy ``Column`` and
            ``val`` is a value for it, expressed the same way as
            :paramref:`~.Auditor.create.user_version`: as a constant,
            type-appropriate value, or a function of kwargs returning such a
            value.
        :param user_version_column_name: the name used for the column
            storing the value of :paramref:`~.Auditor.create.user_version`.
        :param user_version_type: the SQL type of
            :paramref:`~.Auditor.create.user_version`. If not specified, this
            is assumed to be VARCHAR(255).
        :param operation_column_name: the name of the column storing operation
            type. Currently supported values are ``migrate`` and ``stamp``,
            which indicate respectively that database changes are made, or
            that the version is changed without effecting any true changes.
            The field is nonnullable but unconstrained in case future Alembic
            versions support other migration types.
        :param direction_column_name: the name of the column storing the
            operation direction. This column is an enum (native on backends that
            support it, or given as a varchar with constraints) of the string
            values ``up`` and ``down``. It is left nullable in case future
            Alembic versions support migration types without an up/down
            direction.
        :param alembic_version_column_name: The name of the column storing
            the "new" (i.e. after the operation is complete) alembic
            version(s) of the database. Note that this is distinct in theory
            from the "up" version of the migration operation, and distinct
            in practice when the operation is a downgrade.
        :param prev_alembic_version_column_name: The name of the column
            storing the "old" (i.e before the operation is complete) alembic
            version(s) of the database. Note that this is distinct in theory
            from the "down" version of the migration operation, and distinct
            in practice when the operation is a downgrade.
        :param alembic_version_separator: if multiple alembic versions are
            given for one of the alembic version columns, they are joined
            together with ``alembic_version_separator`` as the delimiter.
        :param change_time_column_name: the name of the column storing the
            time of this migration

        """
        if not user_version_nullable:
            if user_version is None:
                cls.version_warn()
            elif callable(user_version):
                orig_user_version = user_version

                def user_version(**kw):
                    val = orig_user_version(**kw)
                    if val is None:
                        cls.version_warn(stacklevel=1)
                    return val

        if metadata is None:
            metadata = MetaData()

        alembic_version_type = types.String(255)

        columns = [
            Column('id',
                   types.BIGINT().with_variant(types.Integer, 'sqlite'),
                   primary_key=True),
            Column(alembic_version_column_name, alembic_version_type),
            Column(prev_alembic_version_column_name, alembic_version_type),
            CheckConstraint('coalesce(%s, %s) IS NOT NULL' %
                            (alembic_version_column_name,
                             prev_alembic_version_column_name),
                            name='alembic_versions_nonnull'),
            Column(operation_column_name, types.String(32), nullable=False),
            Column(direction_column_name, types.String(32), nullable=False),
            Column(user_version_column_name, user_version_type),
            Column(change_time_column_name, types.DateTime())
        ]

        def alembic_vers(f):
            return functools.partial(f, separator=alembic_version_separator)

        col_vals = {
            alembic_version_column_name: alembic_vers(ccv.new_alembic_version),
            prev_alembic_version_column_name:
            alembic_vers(ccv.old_alembic_version),
            operation_column_name: ccv.operation_type,
            direction_column_name: ccv.operation_direction,
            user_version_column_name: user_version,
            change_time_column_name: ccv.change_time,
        }
        for col, val in extra_columns:
            columns.append(col)
            if col.name in col_vals:
                raise exc.AuditCreateError('value %s used twice' % col.name)
            col_vals[col.name] = val

        auditor = cls(Table(table_name, metadata, *columns), col_vals)
        return auditor
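A minimal call sketch for the factory above. The keyword names follow the docstring and code shown here; the table name, the extra column, and the kwargs read at migration time are illustrative assumptions, not part of the original API.

from sqlalchemy import Column, types

auditor = Auditor.create(
    table_name='migration_audit',                         # hypothetical table name
    user_version=lambda **kw: kw.get('config_version'),   # resolved from kwargs at run time
    extra_columns=[
        (Column('deployed_by', types.String(64)),         # illustrative extra column
         lambda **kw: kw.get('deployed_by')),
    ],
)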
Exemple #16
0
class AuthUser(Base):
    """ Table name: auth_users

::

    id = Column(types.Integer(), primary_key=True)
    login = Column(Unicode(80), default=u'', index=True)
    _password = Column('password', Unicode(80), default=u'')
    email = Column(Unicode(80), default=u'', index=True)
    active = Column(types.Enum(u'Y',u'N',u'D'), default=u'Y')
    """
    __tablename__ = 'auth_users'
    __table_args__ = {"sqlite_autoincrement": True}

    id = Column(types.Integer(), primary_key=True)
    auth_id = Column(types.Integer, ForeignKey(AuthID.id), index=True)
    provider = Column(Unicode(80), default=u'local', index=True)
    login = Column(Unicode(80), default=u'', index=True)
    salt = Column(Unicode(24))
    _password = Column('password', Unicode(80), default=u'')
    email = Column(Unicode(80), default=u'', index=True)
    created = Column(types.DateTime(), default=func.now())
    active = Column(types.Enum(u'Y',u'N',u'D', name=u"active"), default=u'Y')

    def _set_password(self, password):
        self.salt = self.get_salt(24)
        password = password + self.salt
        self._password = BCRYPTPasswordManager().encode(password, rounds=12)

    def _get_password(self):
        return self._password

    password = synonym('_password', descriptor=property(_get_password, \
                       _set_password))

    def get_salt(self, length):
        m = hashlib.sha256()
        word = ''

        for i in xrange(length):
            word += random.choice(string.ascii_letters)

        m.update(word)

        return unicode(m.hexdigest()[:length])

    @classmethod
    def get_by_id(cls, id):
        """ 
        Returns AuthUser object or None by id

        .. code-block:: python

           from apex.models import AuthUser

           user = AuthUser.get_by_id(1)
        """
        return DBSession.query(cls).filter(cls.id==id).first()    

    @classmethod
    def get_by_login(cls, login):
        """ 
        Returns AuthUser object or None by login

        .. code-block:: python

           from apex.models import AuthUser

           user = AuthUser.get_by_login('login')
        """
        return DBSession.query(cls).filter(cls.login==login).first()

    @classmethod
    def get_by_email(cls, email):
        """ 
        Returns AuthUser object or None by email

        .. code-block:: python

           from apex.models import AuthUser

           user = AuthUser.get_by_email('*****@*****.**')
        """
        return DBSession.query(cls).filter(cls.email==email).first()

    @classmethod
    def check_password(cls, **kwargs):
        user = None
        if 'id' in kwargs:
            user = cls.get_by_id(kwargs['id'])
        if 'login' in kwargs:
            user = cls.get_by_login(kwargs['login'])

        if not user:
            return False
        try:
            if BCRYPTPasswordManager().check(user.password,
                '%s%s' % (kwargs['password'], user.salt)):
                return True
        except TypeError:
            pass

        request = get_current_request()
        fallback_auth = request.registry.settings.get('apex.fallback_auth')
        if fallback_auth:
            resolver = DottedNameResolver(fallback_auth.split('.', 1)[0])
            fallback = resolver.resolve(fallback_auth)
            return fallback().check(DBSession, request, user, \
                       kwargs['password'])

        return False
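A brief, illustrative use of the password synonym and the classmethod above (assumes a configured DBSession; the login, email, and password values are made up):

user = AuthUser(login=u'alice', email=u'alice@example.com')
user.password = u'secret'      # hashed with bcrypt and a fresh per-user salt
DBSession.add(user)

AuthUser.check_password(login=u'alice', password=u'secret')  # True on a match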
Exemple #17
0
                                     landmines_df.longitude,
                                     landmines_df.latitude))

# Drop the lat/lng column
landmines_gdf = landmines_gdf.drop(['latitude', 'longitude'], axis=1)

# Reset the data frame's index
landmines_gdf = landmines_gdf.reset_index(drop=True)

# add to database
landmines_gdf.to_postgis(con=engine,
                         name='incidents',
                         if_exists='append',
                         dtype={
                             'info': types.Text(),
                             'datetime': types.DateTime(),
                             'geometry': geoTypes(geometry_type='POINT',
                                                  srid=4326)
                         })
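The call that builds landmines_gdf is truncated above; it most likely wraps geopandas.points_from_xy, and the `geoTypes` alias is presumably geoalchemy2.Geometry. A self-contained sketch under those assumptions:

import geopandas as gpd

landmines_gdf = gpd.GeoDataFrame(
    landmines_df,
    geometry=gpd.points_from_xy(landmines_df.longitude, landmines_df.latitude),
    crs='EPSG:4326',  # WGS84, matching srid=4326 in the to_postgis call above
)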

################################################################################
#                                                                              #
#                            upload the imagery                                #
#                                                                              #
################################################################################

# list of imagery
imagery_paths = [os.path.join(figures_dir_path, 'kandahar-compressed.tif')]

# init the imagery collection (a list, with one entry per image)
imagery_dict = []
Exemple #18
0
class AuthID(Base):
    """ Table name: auth_id

::

    id = Column(types.Integer(), primary_key=True)
    display_name = Column(Unicode(80), default=u'')
    active = Column(types.Enum(u'Y',u'N',u'D', name=u"active"), default=u'Y')
    created = Column(types.DateTime(), default=func.now())

    """

    __tablename__ = 'auth_id'
    __table_args__ = {"sqlite_autoincrement": True}

    id = Column(types.Integer(), primary_key=True)
    display_name = Column(Unicode(80), default=u'')
    active = Column(types.Enum(u'Y',u'N',u'D', name=u"active"), default=u'Y')
    created = Column(types.DateTime(), default=func.now())

    groups = relationship('AuthGroup', secondary=auth_group_table, \
                      backref='auth_users')

    users = relationship('AuthUser')

    """
    Fix this to use association_proxy
    groups = association_proxy('auth_group_table', 'authgroup')
    """

    last_login = relationship('AuthUserLog', \
                         order_by='AuthUserLog.id.desc()')
    login_log = relationship('AuthUserLog', \
                         order_by='AuthUserLog.id')

    def in_group(self, group):
        """
        Returns True or False if the user is or isn't in the group.
        """
        return group in [g.name for g in self.groups]

    @classmethod
    def get_by_id(cls, id):
        """ 
        Returns AuthID object or None by id

        .. code-block:: python

           from apex.models import AuthID

           user = AuthID.get_by_id(1)
        """
        return DBSession.query(cls).filter(cls.id==id).first()    

    def get_profile(self, request=None):
        """
        Returns the auth profile object for this AuthID, creating the record if it doesn't exist.

        .. code-block:: python

           from apex.models import AuthID

           user = AuthID.get_by_id(1)
           profile = user.get_profile(request)

        in **development.ini**

        .. code-block:: python

           apex.auth_profile = 
        """
        if not request:
            request = get_current_request()

        auth_profile = request.registry.settings.get('apex.auth_profile')
        if auth_profile:
            resolver = DottedNameResolver(auth_profile.split('.')[0])
            profile_cls = resolver.resolve(auth_profile)
            return get_or_create(DBSession, profile_cls, user_id=self.id)

    @property
    def group_list(self):
        group_list = []
        if self.groups:
            for group in self.groups:
                group_list.append(group.name)
        return ','.join( map( str, group_list ) )
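Illustrative reads against the model above (assumes an existing row whose groups relationship is populated):

auth = AuthID.get_by_id(1)
auth.in_group('admins')   # True if 'admins' is among the related group names
auth.group_list           # e.g. 'admins,editors' (comma-joined group names)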
Exemple #19
0
class AbstractModel(Base):
    __abstract__ = True

    date_created = schema.Column(types.DateTime(timezone=True),
                                 default=get_utc_now)
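get_utc_now is referenced but not shown; a plausible definition it could resolve to (an assumption, not the original helper):

import datetime

def get_utc_now():
    # timezone-aware "now" in UTC, suitable as a column default callable
    return datetime.datetime.now(datetime.timezone.utc)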
Exemple #20
0
class User(db.Model, TimestampMixin, UserMixin):
    """User model class"""

    __tablename__ = 'users'
    id = db.Column(types.Integer, primary_key=True)
    username = db.Column(types.String(50), unique=True, nullable=False)
    email = db.Column(types.String(50), unique=True, nullable=False)
    _password = db.Column('password', types.String(64), nullable=False)
    is_active = db.Column(types.Boolean, nullable=False, default=True)
    last_login = db.Column(types.DateTime(timezone=True),
                           onupdate=datetime.utcnow)

    def __init__(self, username, email, password=None, **kwargs):
        """Create instance."""
        db.Model.__init__(self, username=username, email=email, **kwargs)
        if password:
            self.set_password(password)
        else:
            self._password = None  # bypass hashing when no password is given

    def __repr__(self):
        return '<User {0}>'.format(self.username)

    def is_authenticated(self):
        return True

    # ``is_active`` is the Boolean column defined above; redefining it as a
    # method here would shadow the column, so Flask-Login reads it directly.

    def is_anonymous(self):
        return False

    def get_id(self):
        return str(self.id)

    def _get_password(self):
        return self._password

    def _set_password(self, password):
        self._password = generate_password_hash(password)

    # Hide password encryption by exposing password field only.
    password = orm.synonym('_password',
                           descriptor=property(_get_password, _set_password))

    def check_password(self, password):
        if self.password is None:
            return False
        return check_password_hash(self.password, password)

    @classmethod
    def search(cls, keywords):
        criteria = []
        for keyword in keywords.split():
            keyword = '%' + keyword + '%'
            criteria.append(
                db.or_(  # or_/and_ live on the SQLAlchemy handle, not on sqlalchemy.types
                    User.username.ilike(keyword),
                    User.email.ilike(keyword),
                ))
        q = reduce(db.and_, criteria)
        return cls.query.filter(q)
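Illustrative use of the search helper above (the keyword string is made up; every keyword must match the username or email):

matches = User.search('alice example.com').all()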
Exemple #21
0
def test_should_datetime_convert_string():
    assert_column_conversion(types.DateTime(), graphene.String)
Exemple #22
0
class Flight(Base):
    __tablename__ = 'flights'
    id = Column(Integer, primary_key=True)
    hexident = Column(String(6), nullable=False)
    callsign = Column(String(10))
    # gen_date_time timestamp of the first ADSb message processed for this hexident
    first_seen = Column(types.DateTime(timezone=True), nullable=False)
    # gen_date_time timestamp of (any) last ADSb message of this hexident
    last_seen = Column(types.DateTime(timezone=True))
    # https://gis.stackexchange.com/questions/4467/how-to-handle-time-in-gis
    # flightpath = Column(Geometry('LINESTRINGZ', srid=SRID, dimension=3))
    intention = Column(Enum(Intention), default=Intention.unknown)

    #https://stackoverflow.com/questions/5033547/sqlalchemy-cascade-delete
    positions: Position = relationship('Position',
                                       backref=backref('flight', lazy=True),
                                       passive_deletes=True,
                                       order_by="asc(Position.time)")

    def __init__(self, hexident: str):
        self.hexident = hexident
        self.squawk = None
        self.__flightpath = []
        self.__times = []
        self._transmission_type_count = dict.fromkeys(range(1, 9, 1), 0)

        self._on_landing_subscribers = []
        self._on_takeoff_subscribers = []

        self._onground = None
        self._last_event = datetime.datetime.now(datetime.timezone.utc)

    def __str__(self):
        return "Flight {hexident}: last seen: {last_seen}".format(
            **self.__dict__)

    @property
    def age(self) -> datetime.timedelta:
        """
        Computes the age in seconds since last seen.
        :return: Age in seconds since last seen
        """
        return datetime.datetime.now(datetime.timezone.utc) - self.last_seen

    @property
    def interpolated_track(self):
        """Compute flight heading from last known 2 positions."""
        if len(self.positions) >= 2:
            return interpolate_track(self.positions[-2:])
        else:
            return None

    def update(self, adsb: adsb_parser.AdsbMessage):
        """
        Updates the instance attributes with values from an ADSb message object and returns the updated instance.

        MSG types and contained info:
        - 1: callsign & onground
        - 2: speed & latitude & longitude & onground
        - 3: altitude & latitude & longitude
        - 4: speed & track & verticalrate & onground
        - 5: altitude OR altitude & vertical_rate OR altitude & speed & track
        - 6: (speed & track) (verticalrate) squawk & alert & emergency & spi
        - 7: altitude
        - 8: onground

        :param adsb: Instance of AdsbMessage
        :returns Updated version of self
        """

        # Upon landing MSG type changes from 3 to 2 (no altitude is transmitted after landing)
        MSG_FIELDS = {
            1: ('callsign', 'onground'),
            2: ('speed', 'latitude', 'longitude', 'onground'),
            3: ('altitude', 'latitude', 'longitude', 'onground'),
            4: ('speed', 'track', 'verticalrate', 'onground'),
            5: ('altitude', 'verticalrate'),
            8: ('onground', )
        }

        if adsb.hexident != self.hexident:
            log.error(
                "Trying to update flight '{}' with ADSb message of flight '{}'"
                .format(self.hexident, adsb.hexident))
            return self

        self._transmission_type_count[adsb.transmission_type] += 1

        if not self.first_seen:
            self.first_seen = adsb.gen_date_time

        # Note: last_seen timestamp gets updated from any MSG type, regardless whether the message content will be used
        # to update the object attributes or not.
        self.last_seen = adsb.gen_date_time

        # Process only message types defined as keys in MSG_FIELDS
        try:
            log.debug("Updating flight {} with MSG type: {}".format(
                self.hexident, adsb.transmission_type))
            for field in MSG_FIELDS[adsb.transmission_type]:
                setattr(self, field, getattr(adsb, field))
                log.debug("Updating field: {}={}".format(
                    field, getattr(adsb, field)))
        except KeyError:
            log.debug(
                "Skipping updating flight with transmission type {:d}: {}".
                format(adsb.transmission_type, adsb))

        # Update flight path geometry only if msg includes coordinates
        # ATTENTION: x: longitude (easting), y: latitude (northing)
        if adsb.transmission_type == 3:
            if adsb.longitude is not None and adsb.latitude is not None and adsb.altitude is not None:
                position = Position(time=adsb.gen_date_time,
                                    coordinates=from_shape(Point(
                                        adsb.longitude, adsb.latitude,
                                        feet2m(adsb.altitude)),
                                                           srid=SRID),
                                    onground=adsb.onground)
                self.positions.append(position)
                self.update_onground(adsb.onground)

            else:
                log.debug(
                    "Cannot update position as MSG3 did not include lon/lat: {}"
                    .format(str(adsb)))
        # First MSG2 of aircraft at terminal does not contain coordinates, only 'onground'
        # Also, the altitude is not included in MSG2, and is being set here to GND_ALTITUDE (0m AGL)
        elif adsb.transmission_type == 2 and adsb.longitude is not None and adsb.latitude is not None:
            self.positions.append(
                Position(time=adsb.gen_date_time,
                         coordinates=from_shape(Point(adsb.longitude,
                                                      adsb.latitude,
                                                      GND_ALTITUDE),
                                                srid=SRID),
                         onground=adsb.onground))
            self.update_onground(adsb.onground)

        return self

    def update_onground(self, onground):
        """Flip the onground attribute of this flight and broadcast event."""

        if self._onground is None:  # First position of new flight
            self._onground = onground
        elif self._onground and self._onground != onground:  # takeoff
            self._onground = onground
            self._broadcast_takeoff(self.positions[-1])
        elif not self._onground and self._onground != onground:  # landing
            self._onground = onground
            self._broadcast_landing(self.positions[-1])

    def register_on_landing(self, subscriber):
        """Register an on-landing subscriber."""
        self._on_landing_subscribers.append(subscriber)

    def register_on_takeoff(self, subscriber):
        """Register an on-takeoff subscriber."""
        self._on_takeoff_subscribers.append(subscriber)

    def _broadcast_landing(self, position):
        """Call the callback of landing subscribers."""
        if self._valid_event(position):
            for subscriber in self._on_landing_subscribers:
                subscriber(position, self)
        else:
            log.warning(
                "Suppressing landing event broadcast for flight {} at {}".
                format(position.flight_id, position.time))

    def _broadcast_takeoff(self, position):
        """Call the callback of takeoff subscribers."""
        if self._valid_event(position):
            for subscriber in self._on_takeoff_subscribers:
                subscriber(position, self)
        else:
            log.warning(
                "Suppressing landing event broadcast for flight {} at {}".
                format(position.flight_id, position.time))

    def _valid_event(self, position):
        """Return true if the identified event (landing or takeoff) lies sufficiently apart from previous event.

        Sometimes a landing airplane 'bounces' off the runway triggering several flips between onground statuses.
        This leads to identification of several landing and takeoff events within 1-2sec. To prevent this, the position
        time stamp of the current event is compared to any previous event, and if the time difference is below a short
        threshold the new event should be discarded In this case False is returned by this method.

        The new event timestamp overwrites the last event timestamp, as the time difference between the bounces matters,
        and not between first and any following bounce.
        """

        valid = position.time - self._last_event > datetime.timedelta(
            seconds=2.0)
        self._last_event = position.time
        return valid
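A small wiring sketch for the callbacks above. The hexident and the message source are hypothetical; only Flight, register_on_landing/register_on_takeoff, update, and log come from the code shown here.

flight = Flight('3C6444')  # made-up hexident
flight.register_on_landing(lambda position, flight: log.info('landing: %s', flight.hexident))
flight.register_on_takeoff(lambda position, flight: log.info('takeoff: %s', flight.hexident))
for msg in adsb_parser.message_stream():  # hypothetical iterator of AdsbMessage objects
    flight.update(msg)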
Exemple #23
0
"""샘플쿼리 모듈
:filename:          - query.py
:modified:          - 2017.08.24
:note:              - 이 모듈에서는 자주사용하는 샘플쿼리를 미리 정의함

"""
'''모듈 불러오기'''
from sqlalchemy import types  #ALCHEMY for engine
'''쿼리 인스턴스 임포트'''
#기본상품정보조회
#컬럼 데이터타입
basic_col = \
 {'ISIN_NO'  :types.NVARCHAR(length=50),
    'STD_DATE'  :types.DateTime(),
    'FIRST_AMT' :types.BigInteger(),
    'REMAIN_AMT' :types.BigInteger(),
    'EFF_DATE'  :types.DateTime(),
    'MAT_DATE'  :types.DateTime(),
    'PRSV_RATE' :types.Float()}
# query statement
basic_sql = \
 (
    "select ISIN_NO,"
    "		to_date(STD_DATE,'yyyymmdd') STD_DATE, "
    "		FIRST_AMT, REMAIN_AMT, "
    "		EFF_DATE, MAT_DATE, PRSV_RATE "
    "from "
    "( "
    "	select 	tblLATEST.ISIN_NO, "                    #ISIN번호
    "			greatest(tblLATEST.STND_DATE, nvl(tblREFUND.STND_DATE,0)) STD_DATE, "      #처리일자
    "			tblLATEST.FIRST_AMT/1000000 FIRST_AMT, "              #최초발행금액 (백만원)
Exemple #24
0
class Content(AutoTimestampMixin, BaseContentModel):
    """
    Content
    """
    URI_TYPE = UriType.PRIMARY

    _title = Column(types.String(512))
    _author_names = Column(types.String(128))  # Model TODO: make 256

    # TODO: split out ContentPublication as 1:1?
    _publication = Column(types.String(64))  # e.g. Nature
    _published_timestamp = Column(types.DateTime(), index=True)  # UTC
    _granularity_published = Column(types.SmallInteger())
    _tzinfo_published = Column(types.String(64))
    publisher = Column(types.String(64))  # e.g. Macmillan Science & Education
    summary = Column(types.Text())
    full_text = Column(types.Text())
    # problems
    # orgs
    # geos
    # communities
    # ratings
    # comments

    __table_args__ = (
        Index(
            'ux_content',
            # ux for unique index
            '_title',
            '_author_names',
            '_publication',
            '_published_timestamp',
            unique=True), )

    Key = namedtuple(
        'ContentKey',  # add publisher?
        'title, author_names, publication, published_timestamp')

    @classmethod
    def create_key(cls, title, author_names, publication, published_timestamp,
                   **kwds):
        """Create Trackable key"""
        lowered_title = title.lower()
        normalized_authors = cls.normalize_author_names(author_names)
        flex_dt = FlexTime.cast(published_timestamp)
        dt_utc = flex_dt.astimezone(UTC).deflex(native=True)
        return cls.Key(lowered_title, normalized_authors, publication, dt_utc)

    def derive_key(self, **kwds):
        """Derive Trackable key from instance"""
        return self.model_class.Key(
            self.title.lower(), self.author_names, self.publication,
            self.published_timestamp.astimezone(UTC).deflex(native=True))

    @property
    def title(self):
        return self._title.capitalize() if self._title else self._title

    @title.setter
    def title(self, val):
        if val is None:
            raise ValueError('Cannot be set to None')
        # During __init__()
        if self._title is None:
            self._title = val.lower()
            return
        # Not during __init__()
        key = self.model_class.create_key(
            val, self.author_names, self.publication,
            self.published_timestamp.astimezone(UTC))
        self.register_update(key)

    title = orm.synonym('_title', descriptor=title)

    @property
    def author_names(self):
        return self._author_names

    @author_names.setter
    def author_names(self, val):
        if val is None:
            raise ValueError('Cannot be set to None')
        normalized_val = self.normalize_author_names(val)
        # During __init__()
        if self._author_names is None:
            self._author_names = normalized_val
            return
        # Not during __init__()
        key = self.model_class.create_key(
            self.title, normalized_val, self.publication,
            self.published_timestamp.astimezone(UTC))
        self.register_update(key)

    author_names = orm.synonym('_author_names', descriptor=author_names)

    @classmethod
    def normalize_author_names(cls, author_names):
        author_name_list = [
            cls.normalize_author_name(an) for an in author_names.split(';')
        ]
        return '; '.join(author_name_list)

    @classmethod
    def normalize_author_name(cls, author_name):
        author_name = author_name.strip()
        author_name_components = [c.strip() for c in author_name.split(',')]
        credential = None
        if (len(author_name_components) > 1
                and cls.is_credential(author_name_components[-1])):
            credential = author_name_components[-1]
            author_name_components = author_name_components[:-1]

        if len(author_name_components) == 1:
            author_name_components = [
                c.strip() for c in author_name_components[0].split()
            ]
            if (len(author_name_components) > 1
                    and cls.is_credential(author_name_components[-1])):
                credential = author_name_components[-1]
                author_name_components = author_name_components[:-1]
            author_name = ' '.join(author_name_components)

        elif len(author_name_components) == 2:
            last_name, first_name = author_name_components
            author_name = ' '.join((first_name.strip(), last_name.strip()))
        else:
            raise ValueError('Invalid author name')

        author_name = author_name.title()

        if credential:
            author_name = ', '.join((author_name, credential))

        return author_name

    @classmethod
    def is_credential(cls, component):
        return (component.upper() == component or component == 'PhD')

    @property
    def publication(self):
        return self._publication

    @publication.setter
    def publication(self, val):
        if val is None:
            val = ''
        # During __init__()
        if self._publication is None:
            self._publication = val
            return
        # Not during __init__()
        key = self.model_class.create_key(
            self.title, self.author_names, val,
            self.published_timestamp.astimezone(UTC))
        self.register_update(key)

    publication = orm.synonym('_publication', descriptor=publication)

    @property
    def published_timestamp(self):
        flex_dt = FlexTime.instance(self._published_timestamp)
        localized = flex_dt.astimezone(self.tzinfo_published)
        return FlexTime.instance(localized,
                                 granularity=self.granularity_published,
                                 truncate=True)

    @published_timestamp.setter
    def published_timestamp(self, val):
        """Set published timestamp given FlexTime datetime or TimestampInfo"""
        self.set_published_timestamp_info(val)

    published_timestamp = orm.synonym('_published_timestamp',
                                      descriptor=published_timestamp)

    @property
    def granularity_published(self):
        return FlexTime.Granularity(self._granularity_published)

    granularity_published = orm.synonym('_granularity_published',
                                        descriptor=granularity_published)

    @property
    def tzinfo_published(self):
        return self._tzinfo_published

    tzinfo_published = orm.synonym('_tzinfo_published',
                                   descriptor=tzinfo_published)

    @property
    def published_timestamp_info(self):
        """Get publication datetime info namedtuple"""
        return self.published_timestamp.info

    jsonified_published_timestamp_info = JsonProperty(
        name='published_timestamp_info', after='tzinfo_published')

    def set_published_timestamp_info(self, dt, granularity=None, geo=None):
        """
        Set publication timestamp info

        Set all fields related to published timestamp. This is the only
        method that actually sets the private member variables.

        Content may be published with varying levels of granularity,
        ranging from year to microseconds. This is supported by the
        FlexTime factory class which endows datetime instances with
        granularity and info members.

        I/O:
        dt: FlexTime or regular datetime instance or DatetimeInfo tuple
        granularity=None: FlexTime granularity or associated int value
        geo=None: geo where the content was originally published
        """

        # TODO: if geo and dt is naive, set timezone based on geo

        flex_dt = FlexTime.cast(dt, granularity)
        granularity, info = flex_dt.granularity, flex_dt.info

        now = pendulum.now(UTC)
        if flex_dt > now:
            raise ValueError('Publication date may not be in the future')

        # During __init__()
        if self._published_timestamp is None:
            self._published_timestamp = (flex_dt.astimezone(UTC).deflex(
                native=True))
        else:  # Not during __init__()
            key = self.model_class.create_key(self.title, self.author_names,
                                              self.publication, flex_dt)
            self.register_update(key)

        self._granularity_published = granularity.value
        self._tzinfo_published = info.tzinfo

    def __init__(self,
                 title,
                 author_names,
                 publication,
                 published_timestamp,
                 granularity_published=None,
                 geo=None,
                 publisher=None,
                 summary=None,
                 full_text=None):
        self.set_published_timestamp_info(published_timestamp,
                                          granularity_published, geo)
        self.title = title
        self.author_names = author_names
        self.publication = publication
        self.publisher = publisher
        self.summary = summary
        self.full_text = full_text
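The author-name normalizers above are pure functions of their input, so their behavior can be shown directly (the names are made up):

Content.normalize_author_name('Curie, Marie')                 # -> 'Marie Curie'
Content.normalize_author_names('Doe, Jane, PhD; john smith')
# -> 'Jane Doe, PhD; John Smith'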
Exemple #25
0
    def load_dialect_impl(self, dialect):
        if dialect.name == "sqlite":
            return dialect.type_descriptor(CHAR(25))
        else:
            return dialect.type_descriptor(types.DateTime())
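This method is typically part of a TypeDecorator that stores datetimes as ISO-8601 strings on SQLite and as native DATETIME elsewhere. A self-contained sketch under that assumption (the class name and the string round-tripping are illustrative, not the original code):

import datetime

from sqlalchemy import types
from sqlalchemy.types import CHAR, TypeDecorator


class ISODateTime(TypeDecorator):  # hypothetical name
    impl = types.DateTime
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == "sqlite":
            return dialect.type_descriptor(CHAR(25))
        return dialect.type_descriptor(types.DateTime())

    def process_bind_param(self, value, dialect):
        # store an ISO-8601 string on SQLite, pass datetimes through elsewhere
        if value is not None and dialect.name == "sqlite":
            return value.isoformat()
        return value

    def process_result_value(self, value, dialect):
        if value is not None and dialect.name == "sqlite":
            return datetime.datetime.fromisoformat(value)
        return value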
Exemple #26
0
def random_string():
    import string
    import random
    return ''.join(random.choice(string.ascii_uppercase + string.digits) for x in range(12))


metadata = schema.MetaData()


tags_table = schema.Table('tags', metadata,
    schema.Column('id', types.Integer, primary_key=True),
    #schema.Column('image_id', types.Integer, nullable=False),    
    schema.Column('tag', types.Unicode(255), nullable=False, unique=True),
    #schema.Column('relevancy', types.Integer, default=0),
    schema.Column('created_at', types.DateTime(), default=now()),
)

image_table = schema.Table('images', metadata,
    schema.Column('id', types.Integer, primary_key=True), #this would be crc32 of absolute url of the image
    schema.Column('image_holder', types.Unicode(255)), #web page where image is. also written as 'identifier' at some places
    schema.Column('image', types.Unicode(255), nullable=False), #absolute url of image
    schema.Column('width', types.Integer),
    schema.Column('height', types.Integer),
    schema.Column('rights', types.Unicode(255)),
    schema.Column('creator', types.Unicode(255)),
    schema.Column('source', types.Unicode(255)),  
    schema.Column('created_at', types.DateTime(), default=now()),
)

relation_table = schema.Table('relation', metadata,
Exemple #27
0
def test_should_datetime_convert_datetime():
    assert get_field(types.DateTime()).type == DateTime
Exemple #28
0
class utcnow(expression.FunctionElement):
    type = types.DateTime()
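This FunctionElement is the first half of SQLAlchemy's compiler-extension recipe for a server-side UTC timestamp. The usual continuation looks like the sketch below; the PostgreSQL expression is the documented one, while the generic fallback is an assumption.

from sqlalchemy.ext.compiler import compiles


@compiles(utcnow, 'postgresql')
def pg_utcnow(element, compiler, **kw):
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


@compiles(utcnow)
def default_utcnow(element, compiler, **kw):
    # generic fallback (note: not every backend returns UTC for CURRENT_TIMESTAMP)
    return "CURRENT_TIMESTAMP"

# typical use: Column(types.DateTime(), server_default=utcnow())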
Exemple #29
0
class Allocation(TimestampMixin, ORMBase, OtherModels):
    """Describes a timespan within which one or many timeslots can be
    reserved.

    There's an important concept to understand before working with allocations.
    The resource uuid of an allocation does not always point to the actual
    resource.

    A resource may in fact be a real resource, or an imaginary resource with
    a uuid derived from the real resource. This is a somewhat historical
    artifact.

    If you need to know which allocations belong to a real resource, the
    mirror_of field is what's relevant. The allocation originally created for
    the real resource is also called the master allocation, and it is the one
    allocation whose mirror_of and resource are equal.

    When in doubt look at the managed_* functions of seantis.reservation.db's
    Scheduler class.

    """

    __tablename__ = 'allocations'

    id = Column(types.Integer(), primary_key=True, autoincrement=True)
    resource = Column(customtypes.GUID(), nullable=False)
    mirror_of = Column(customtypes.GUID(), nullable=False)
    group = Column(customtypes.GUID(), nullable=False)
    quota = Column(types.Integer(), default=1)
    partly_available = Column(types.Boolean(), default=False)
    approve_manually = Column(types.Boolean(), default=False)

    reservation_quota_limit = Column(
        types.Integer(), default=0, nullable=False
    )

    # The dates are stored without any timezone information (naive).
    # Therefore the times are implicitly stored in the timezone the resource
    # resides in.

    # This is fine and dandy as long as all resources are in the same timezone.
    # If they are not, problems arise. So in the future the resource should
    # carry a timezone property which is applied to the dates, which will then
    # be stored in UTC.

    # => TODO
    _start = Column(types.DateTime(), nullable=False)
    _end = Column(types.DateTime(), nullable=False)
    _raster = Column(types.Integer(), nullable=False)

    recurrence_id = Column(types.Integer(),
                           ForeignKey('recurrences.id',
                                      onupdate='cascade',
                                      ondelete='cascade'))
    recurrence = relation('Recurrence', lazy='joined')

    __table_args__ = (
        Index('mirror_resource_ix', 'mirror_of', 'resource'),
        UniqueConstraint('resource', '_start', name='resource_start_ix')
    )

    def copy(self):
        allocation = Allocation()
        allocation.resource = self.resource
        allocation.mirror_of = self.mirror_of
        allocation.group = self.group
        allocation.quota = self.quota
        allocation.partly_available = self.partly_available
        allocation.approve_manually = self.approve_manually
        allocation._start = self._start
        allocation._end = self._end
        allocation._raster = self._raster
        allocation.recurrence_id = self.recurrence_id
        return allocation

    def get_start(self):
        return self._start

    def set_start(self, start):
        self._start = rasterize_start(start, self.raster)

    start = property(get_start, set_start)

    def get_end(self):
        return self._end

    def set_end(self, end):
        self._end = rasterize_end(end, self.raster)

    end = property(get_end, set_end)

    def get_raster(self):
        return self._raster

    def set_raster(self, raster):
        # the raster can only be set once!
        assert(not self._raster)
        self._raster = raster

    raster = property(get_raster, set_raster)

    @property
    def display_start(self):
        """Does nothing but to form a nice pair to display_end."""
        return self.start

    @property
    def display_end(self):
        """Returns the end plus one microsecond (nicer display)."""
        return self.end + timedelta(microseconds=1)

    @property
    def whole_day(self):
        """True if the allocation is a whole-day allocation.

        A whole-day allocation is not really special. It's just an allocation
        which starts at 0:00 and ends at 24:00 (or 23:59:59'999).

        As such it can actually also span multiple days; only the hours and
        minutes count.

        The use of this is to display allocations spanning days differently.
        """

        s, e = self.display_start, self.display_end
        assert s != e  # this can never be, except when caused by cosmic rays

        return utils.whole_day(s, e)

    def overlaps(self, start, end):
        """ Returns true if the current timespan overlaps with the given
        start and end date.

        """
        start, end = rasterize_span(start, end, self.raster)
        return utils.overlaps(start, end, self.start, self.end)

    def contains(self, start, end):
        """ Returns true if the current timespan contains the given start
        and end date.

        """
        start, end = rasterize_span(start, end, self.raster)
        return self.start <= start and end <= self.end

    def free_slots(self, start=None, end=None):
        """ Returns the slots which are not yet reserved."""
        reserved = [slot.start for slot in self.reserved_slots]

        slots = []
        for start, end in self.all_slots(start, end):
            if start not in reserved:
                slots.append((start, end))

        return slots

    def align_dates(self, start=None, end=None):
        """ Aligns the given dates to the start and end date of the
        allocation.

        """

        start = start or self.start
        start = start < self.start and self.start or start

        end = end or self.end
        end = end > self.end and self.end or end

        return start, end

    def all_slots(self, start=None, end=None):
        """ Returns the slots which exist with this timespan. Reserved or free.

        """
        start, end = self.align_dates(start, end)

        if self.partly_available:
            for start, end in iterate_span(start, end, self.raster):
                yield start, end
        else:
            yield self.start, self.end

    def is_available(self, start=None, end=None):
        """ Returns true if the given daterange is completely available. """

        if not (start and end):
            start, end = self.start, self.end

        assert(self.overlaps(start, end))

        if self.is_blocked(start, end):
            return False

        reserved = [slot.start for slot in self.reserved_slots]
        for start, end in self.all_slots(start, end):
            if start in reserved:
                return False

        return True

    def is_blocked(self, start=None, end=None):
        if not (start and end):
            start, end = self.start, self.end
        else:
            start, end = utils.as_machine_date(start, end)

        BlockedPeriod = self.models.BlockedPeriod
        query = self._query_blocked_periods()
        query = query.filter(BlockedPeriod.start <= end)
        query = query.filter(BlockedPeriod.end >= start)

        return query.first() is not None

    def _query_blocked_periods(self):
        query = Session.query(self.models.BlockedPeriod)
        query = query.filter_by(resource=self.resource)
        return query

    @property
    def pending_reservations(self):
        """ Returns the pending reservations query for this allocation.
        As the pending reservations target the group and not a specific
        allocation this function returns the same value for masters and
        mirrors.

        """
        Reservation = self.models.Reservation
        query = Session.query(Reservation.id)
        query = query.filter(Reservation.target == self.group)
        query = query.filter(Reservation.status == u'pending')

        return query

    @property
    def waitinglist_length(self):
        return self.pending_reservations.count()

    @property
    def availability(self):
        """Returns the availability in percent."""

        if self.partly_available:
            total = sum(1 for s in self.all_slots())
        else:
            total = 1

        count = len(self.reserved_slots)
        for blocked_period in self._query_blocked_periods():
            count += len(list(iterate_span(blocked_period.start,
                                           blocked_period.end,
                                           self.raster)))

        if total == count:
            return 0.0

        if count == 0:
            return 100.0

        return 100.0 - (float(count) / float(total) * 100.0)

    @property
    def in_group(self):
        """True if the event is in any group."""

        query = Session.query(Allocation.id)
        query = query.filter(Allocation.resource == self.resource)
        query = query.filter(Allocation.group == self.group)
        query = query.limit(2)

        return len(query.all()) > 1

    @property
    def in_recurrence(self):
        """True if the event is attached to a recurrence."""

        return self.recurrence_id is not None

    @property
    def is_separate(self):
        """True if available separately (as opposed to available only as
        part of a group)."""
        if self.partly_available:
            return True

        if self.in_group:
            return False

        return True

    def availability_partitions(self, scheduler):
        """Partitions the space between start and end into blocks of either
        free, blocked or reserved time. Each block has a percentage
        representing the space the block occupies compared to the size of the
        whole allocation.

        The blocks are ordered from start to end. Each block is an item with
        two values. The first being the percentage, the second being the type.
        The type can be one of None, 'reserved' or 'blocked'.

        So given an allocation that goes from 8 to 9 and a reservation that
        goes from 8:15 until 8:30 and a block that goes from 8:30 to 9:00
        we get the following blocks:

        [
            (25%, None),
            (25%, 'reserved'),
            (50%, 'blocked')
        ]

        This is useful to divide an allocation block into different divs on the
        frontend, indicating to the user which parts of an allocation are
        available for reservation.

        Makes sure to only display slots that are within its resource's
        first_hour/last_hour timespan.

        """

        resource = get_resource_by_uuid(scheduler.uuid).getObject()
        min_start_resource = datetime.combine(self.start,
                                              time(resource.first_hour))
        max_end_resource = datetime.combine(self.end,
                                            time(resource.last_hour))

        display_start = max(min_start_resource, self.start)
        display_end = min(max_end_resource, self.end)

        reserved = dict((r.start, r) for r in self.reserved_slots if
                        r.start >= display_start and r.end <= display_end)
        blocked = set()
        for blocked_period in self._query_blocked_periods():
            blocked.update(start for start, end in
                           iterate_span(max(blocked_period.start,
                                            display_start),
                                        min(blocked_period.end,
                                            display_end),
                                        self.raster))

        if not (reserved or blocked):
            return [(100.0, None)]

        # Get the percentage one slot represents
        slots = list(self.all_slots(display_start, display_end))
        step = 100.0 / float(len(slots))

        # Create an entry for each slot with either True or False
        pieces = []
        for slot in slots:
            piece = None
            if slot[0] in reserved:
                reserved_slot = reserved[slot[0]]
                token = reserved_slot.reservation_token
                reservation = scheduler.reservation_by_token(token).one()
                piece = ('reserved', reservation.description, reservation.id)
            elif slot[0] in blocked:
                piece = ('blocked', None)
            pieces.append(piece)

        # Group by the None/'reserved'/'blocked' values in the pieces and sum
        # up the percentage
        partitions = []
        for flag, group in groupby(pieces, key=lambda p: p):
            percentage = len(list(group)) * step
            partitions.append([percentage, flag])

        # Make sure to get rid of floating point rounding errors
        total = sum([p[0] for p in partitions])
        diff = 100.0 - total
        partitions[-1][0] -= diff

        return partitions

    @property
    def is_transient(self):
        """True if the allocation does not exist in the database, and is not
        about to be written to the database. If an allocation is transient it
        means that the given instance only exists in memory.

        See:
        http://www.sqlalchemy.org/docs/orm/session.html
        #quickie-intro-to-object-states
        http://stackoverflow.com/questions/3885601/
        sqlalchemy-get-object-instance-state

        """

        return object_session(self) is None and not has_identity(self)

    @property
    def is_master(self):
        """True if the allocation is a master allocation."""

        return self.resource == self.mirror_of

    def siblings(self, imaginary=True):
        """Returns the master/mirrors group this allocation is part of.

        If 'imaginary' is true, nonexistent mirrors are created on the fly.
        Those mirrors are transient (see self.is_transient).

        """

        # this function should always have itself in the result
        if not imaginary and self.is_transient:
            assert False, \
                'the resulting list would not contain this allocation'

        if self.quota == 1:
            assert(self.is_master)
            return [self]

        query = Session.query(Allocation)
        query = query.filter(Allocation.mirror_of == self.mirror_of)
        query = query.filter(Allocation._start == self._start)

        existing = dict(((e.resource, e) for e in query))

        master = self.is_master and self or existing[self.mirror_of]
        existing[master.resource] = master

        uuids = utils.generate_uuids(master.resource, master.quota)
        imaginary = imaginary and (master.quota - len(existing)) or 0

        siblings = [master]
        for uuid in uuids:
            if uuid in existing:
                siblings.append(existing[uuid])
            elif imaginary > 0:
                allocation = master.copy()
                allocation.resource = uuid
                siblings.append(allocation)

                imaginary -= 1

        return siblings
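A small sketch of how the raster drives all_slots for a partly available allocation. It assumes the rasterize_*/iterate_span helpers behave as their names suggest; no Session is needed for this in-memory use, and the dates are made up.

from datetime import datetime

allocation = Allocation()
allocation.raster = 15                       # 15-minute raster; may only be set once
allocation.partly_available = True
allocation.start = datetime(2024, 5, 1, 8, 0)
allocation.end = datetime(2024, 5, 1, 9, 0)

list(allocation.all_slots())
# -> four (start, end) datetime pairs: 08:00-08:15, 08:15-08:30, 08:30-08:45, 08:45-09:00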
Exemple #30
0
df1 = df_volume[df_volume['TEMP'].str.isnumeric()].copy()
df2 = df_volume[~df_volume['TEMP'].str.isnumeric()].copy()
df1['TEMP'] = df1['TEMP'].apply(np.int64)
df1['AMOUNT'] = df1['AMOUNT'] * df1['TEMP']
df_volume = pd.concat([df1, df2])
df_volume.drop('TEMP', axis=1, inplace=True)
df_combined = pd.concat([df, df_volume])
print(df_combined)

print('start importing...')
df_combined.to_sql('data',
                   con=engine,
                   if_exists='replace',
                   index=False,
                   dtype={
                       'DATE': t.DateTime(),
                       'AMOUNT': t.FLOAT(),
                       'TC I': t.NVARCHAR(length=200),
                       'TC II': t.NVARCHAR(length=200),
                       'TC III': t.NVARCHAR(length=200),
                       'TC IV': t.NVARCHAR(length=200),
                       'MOLECULE': t.NVARCHAR(length=200),
                       'PRODUCT': t.NVARCHAR(length=200),
                       'PACKAGE': t.NVARCHAR(length=200),
                       'CORPORATION': t.NVARCHAR(length=200),
                       'MANUF_TYPE': t.NVARCHAR(length=20),
                       'FORMULATION': t.NVARCHAR(length=50),
                       'STRENGTH': t.NVARCHAR(length=20),
                       'UNIT': t.NVARCHAR(length=25),
                       'PERIOD': t.NVARCHAR(length=3),
                       'MOLECULE_TC': t.NVARCHAR(length=255),