コード例 #1
0
ファイル: test_objects.py プロジェクト: gitfred/fuel-web
    def test_removing_from_cluster(self):
        """Removing a node from its cluster leaves it looking like an unassigned node.

        After ``objects.Node.remove_from_cluster`` the node must have no
        ``cluster_id`` and empty ``roles``/``pending_roles``. Every other mapped
        attribute, except those listed in ``exclude_fields``, must equal the
        corresponding attribute of a second node created with
        ``self.env.create_node()`` defaults.
        """
        self.env.create(cluster_kwargs={}, nodes_kwargs=[{"role": "controller"}])
        node_db = self.env.nodes[0]
        # Baseline node created with default arguments; presumably never
        # assigned to a cluster (TODO confirm create_node() defaults).
        node2_db = self.env.create_node()
        objects.Node.remove_from_cluster(node_db)
        self.assertIsNone(node_db.cluster_id)
        self.assertEqual(node_db.roles, [])
        self.assertEqual(node_db.pending_roles, [])

        # Attributes expected to differ between any two distinct nodes, so they
        # are excluded from the equality comparison below.
        exclude_fields = {
            "group_id",
            "id",
            "hostname",
            "fqdn",
            "mac",
            "meta",
            "name",
            "agent_checksum",
            "uuid",
            "timestamp",
            "nic_interfaces",
            "attributes",
        }
        fields = set(c.key for c in sqlalchemy_inspect(objects.Node.model).attrs) - exclude_fields

        # Iterate in sorted order so failures are reproducible, and name the
        # offending attribute in the assertion message.
        for f in sorted(fields):
            self.assertEqual(
                getattr(node_db, f),
                getattr(node2_db, f),
                "Node attribute {0!r} differs after removal from cluster".format(f),
            )
コード例 #2
0
    def has_translation(self, language: str) -> 'bool':
        """Return whether text is available for this object in ``language``.

        The object's ``original_language`` always counts as available. For any
        other language, a ``Translation`` row with ``translatable_id == self.id``
        and the given ``language`` must exist.

        :param language: language code to look up; must be non-empty.
        :returns: ``True`` if available, otherwise the scalar result of the
            EXISTS query.
        :raises ValueError: if ``language`` is empty, or if the instance is
            transient according to ``sqlalchemy_inspect``.
        """
        if not language:
            raise ValueError('language expected (got {})'.format(language))

        if sqlalchemy_inspect(self).transient:
            raise ValueError('Translatable is transient and cannot have translations')

        if language == self.original_language:
            return True

        # EXISTS subquery: asks the database for a boolean instead of loading a row.
        return db.session.query(Translation.query.filter_by(translatable_id=self.id,
                                                            language=language).exists()).scalar()
コード例 #3
0
    def _merge_task_properties_and_drm_options(self, task, drm_options):
        """Return a copy of ``drm_options`` overlaid with values taken from ``task``.

        ``self.drm_options_to_task_properties`` maps each DRM option name to an
        object exposing ``.key`` (presumably a mapped task attribute). For each
        entry, the task's current value for that attribute replaces the option
        value if it is truthy. Falsy values leave the option as it was.
        The caller's ``drm_options`` is not modified; a new dict is returned.
        """
        merged = dict(drm_options)
        state = sqlalchemy_inspect(task)

        mapping = self.drm_options_to_task_properties
        for option_name, prop in mapping.iteritems():
            # Value looked up via the instance state's attribute collection.
            current = state.attrs[prop.key].value
            if not current:
                continue
            merged[option_name] = current

        return merged
コード例 #4
0
    def add_translation(self, language: str, text: str, provider: str = None) -> 'Translation':
        """Return the translation into ``language``, creating it if it does not exist.

        * If ``language`` is the object's ``original_language``, the result of
          ``self._get_dummy_translation()`` is returned.
        * If a ``Translation`` row for ``(self.id, language)`` already exists, it
          is returned unchanged; ``text`` and ``provider`` are not applied to it.
        * Otherwise a new ``Translation(self.id, language, text, provider)`` is
          added to ``db.session``. No commit happens here.

        :raises ValueError: if ``language`` is empty, or if the instance is
            transient according to ``sqlalchemy_inspect``.
        """
        # Validate the language like has_translation()/get_translation() do;
        # without this an empty language would be stored as a Translation row.
        if not language:
            raise ValueError('language expected (got {})'.format(language))

        if sqlalchemy_inspect(self).transient:
            raise ValueError('Translatable is transient and cannot have translations')

        if language == self.original_language:
            return self._get_dummy_translation()

        translation = Translation.query.filter_by(translatable_id=self.id, language=language).first()

        if translation is None:
            translation = Translation(self.id, language, text, provider)
            db.session.add(translation)

        return translation
コード例 #5
0
def row_to_dict(r, keep_relationships=False):
    '''Convert an SQLAlchemy record into a plain Python dict.

    :param r: the record to convert.
    :param keep_relationships: when false (the default), the result holds only
        the values of the columns of ``r.__table__``, keyed by column name.
        When true, the result is a deep copy of ``r.__dict__``, which keeps
        every attribute present there, including relationship objects and
        SQLAlchemy's ``_sa_instance_state`` entry.
    :return: a new dict.
    '''
    if keep_relationships:
        return copy.deepcopy(r.__dict__)

    # Column-only mode: read each table column's value off the record.
    column_names = [column.name for column in list(sqlalchemy_inspect(r.__table__).columns)]
    return {name: getattr(r, name) for name in column_names}
コード例 #6
0
    def _merge_task_properties_and_drm_options(self, task, drm_options):
        """Return a copy of ``drm_options`` overlaid with values derived from ``task``.

        Each entry of ``self.drm_options_to_task_properties`` maps a DRM option
        name either to a callable, which is called with ``task``, or to an
        object exposing ``.key``, whose value is read from the task's
        SQLAlchemy instance state. Only truthy values are applied. For the
        ``"memory"`` option the value is stringified and suffixed with ``"K"``;
        the original comment describes the source unit as kilobytes. The
        caller's ``drm_options`` is not modified.
        """
        merged = dict(drm_options)
        state = sqlalchemy_inspect(task)

        for option_name, mapping in list(self.drm_options_to_task_properties.items()):
            value = mapping(task) if callable(mapping) else state.attrs[mapping.key].value
            if not value:
                continue
            # Memory is passed on as a string with a "K" (kilo) unit suffix.
            merged[option_name] = str(value) + "K" if option_name == "memory" else value

        return merged
コード例 #7
0
	def __init__(self, logger, settings=None, globalSettings=None, env=None, poolSize=35, maxOverflow=5, poolRecycle=1800):
		"""Constructor for the database client connection.

		Builds an engine through the external database library, verifies the
		connection by listing schema names, then creates the scoped session.
		Any exception is logged through ``logger`` and re-raised.

		Arguments:
		  logger                : handler for the database log
		  settings (json)       : database connection parameters
		  globalSettings (json) : global settings
		  env                   : module containing the env paths
		  poolSize (int)        : pool_size param for sqlalchemy.create_engine()
		  maxOverflow (int)     : max_overflow param for sqlalchemy.create_engine()
		  poolRecycle (int)     : pool_recycle param for sqlalchemy.create_engine()

		Raises:
		  exc.OperationalError : logged (message only) and re-raised
		  EnvironmentError     : if no schema names could be read from the database
		  Exception            : anything else is logged with its traceback and re-raised

		"""
		try:
			self.session = None
			self.settings = settings
			if self.settings is None:
				self.settings = {}

			## Load an external database library for the wrapped connection;
			## the provided library must contain a 'createEngine' function that
			## takes a previously created dictionary with the database
			## settings (if empty, it will be filled), the globalSettings which
			## contains named references to modules in the external directory,
			## and parameters to use in the call for sqlalchemy.create_engine().
			externalLibrary = utils.loadExternalLibrary('externalDatabaseLibrary', env, globalSettings)
			self.engine = externalLibrary.createEngine(self.settings, globalSettings, poolSize, maxOverflow, poolRecycle)

			## Validate we are actually talking to the DB
			engineInspector = sqlalchemy_inspect(self.engine)
			schemas = engineInspector.get_schema_names()
			if not schemas:
				raise EnvironmentError('Could not connect to database')
			logger.debug('Schemas found from connection pool inspection: {}'.format(str(schemas)))
			self.createScopedSession(logger)

		except exc.OperationalError:
			## Intentionally catch database connection errors
			logger.error('Exception in DatabaseClient: {}'.format(str(sys.exc_info()[1])))
			raise
		except Exception:
			## Log the full traceback for any other failure, then propagate.
			## Catching Exception rather than using a bare except lets
			## KeyboardInterrupt/SystemExit pass through without being
			## reported as a database failure.
			stacktrace = traceback.format_exception(*sys.exc_info())
			logger.error('Exception in DatabaseClient: {}'.format(str(stacktrace)))
			raise
コード例 #8
0
    def get_translation(self, language: str, translator: 'TranslationService' = None) -> 'Translation':
        """Return this object's translation into ``language``, translating on demand.

        Lookup order:
        1. ``original_language`` returns ``self._get_dummy_translation()``.
        2. An existing ``Translation`` row for ``(self.id, language)`` is returned.
        3. Otherwise ``translator.translate(...)`` produces the text, and the
           result of ``self.add_translation(...)`` is returned.

        :param language: target language; must be non-empty.
        :param translator: optional ``TranslationService``; required only when
            no stored translation exists.
        :raises ValueError: empty ``language``, a transient instance, or a
            missing translation with no translator supplied.
        :raises: the exception built by ``expected_or_none`` if ``translator`` is
            neither ``None`` nor a ``TranslationService``.
        """
        if not language:
            raise ValueError('language expected (got {})'.format(language))
        translator_ok = translator is None or isinstance(translator, TranslationService)
        if not translator_ok:
            raise expected_or_none('translator', translator, TranslationService)

        if sqlalchemy_inspect(self).transient:
            raise ValueError('Translatable is transient and cannot have translations')

        if language == self.original_language:
            return self._get_dummy_translation()

        stored = Translation.query.filter_by(translatable_id=self.id, language=language).first()
        if stored is not None:
            return stored

        if translator is None:
            raise ValueError('Cannot translate without translator function')

        translated_text = translator.translate(self.original_text, self.original_language, language)
        return self.add_translation(language, translated_text, translator.identifier)
コード例 #9
0
ファイル: search.py プロジェクト: XZVB12/faraday-pentest
    def create_query(session, model, search_params, _ignore_order_by=False):
        """Builds an SQLAlchemy query instance based on the search parameters
        present in ``search_params``, an instance of :class:`SearchParameters`.

        This method returns a SQLAlchemy query in which all matched instances
        meet the requirements specified in ``search_params``.

        `model` is SQLAlchemy declarative model on which to create a query.

        `search_params` is an instance of :class:`SearchParameters` which
        specify the filters, order, limit, offset, etc. of the query.

        If `_ignore_order_by` is ``True``, no ``order_by`` method will be
        called on the query, regardless of whether the search parameters
        indicate that there should be an ``order_by``. (This is used internally
        by Flask-Restless to work around a limitation in SQLAlchemy.)

        When ``search_params.group_by`` is set, the query selects ``count()``
        followed by the group-by columns instead of whole model instances.

        Building the query proceeds in this order:
        1. filtering
        2. ordering
        3. grouping
        4. limiting
        5. offsetting

        Ordering by a related field (``relation__field``) joins the related
        model with a LEFT OUTER JOIN, so rows without a related object are
        still returned. Ordering must never filter results.

        Raises one of :exc:`AttributeError`, :exc:`KeyError`, or
        :exc:`TypeError` if there is a problem creating the query. See the
        documentation for :func:`_create_operation` for more information.

        """
        if search_params.group_by:
            select_fields = [func.count()]
            for groupby in search_params.group_by:
                select_fields.append(getattr(model, groupby.field))

            query = session.query(*select_fields)
        else:
            query = session.query(model)
        # For the sake of brevity, rename this method.
        create_filt = QueryBuilder._create_filter
        # This function call may raise an exception.
        # str(prop) is "Model.attr"; keep the attribute name. A set gives O(1)
        # membership checks in the filter loop below.
        valid_model_fields = {
            str(algo).split('.')[1] for algo in sqlalchemy_inspect(model).attrs
        }

        filters = []
        for filt in search_params.filters:
            if not getattr(filt, 'fieldname',
                           False) or filt.fieldname in valid_model_fields:
                try:
                    filters.append(create_filt(model, filt))
                except AttributeError:
                    # Best effort: the model or submodel does not have the
                    # attribute (usually mapper), so the filter is skipped.
                    pass

        # Multiple filter criteria at the top level of the provided search
        # parameters are interpreted as a conjunction (AND).
        query = query.filter(*filters)

        # Order the search. If no order field is specified in the search
        # parameters, order by primary key.
        if not _ignore_order_by:
            if search_params.order_by:
                joined_models = set()
                for val in search_params.order_by:
                    field_name = val.field
                    if '__' in field_name:
                        field_name, field_name_in_relation = \
                            field_name.split('__')
                        relation = getattr(model, field_name)
                        relation_model = relation.mapper.class_
                        field = getattr(relation_model, field_name_in_relation)
                        direction = getattr(field, val.direction)
                        if relation_model not in joined_models:
                            # Outer join: an inner join here would silently
                            # drop rows whose relationship is NULL merely
                            # because the results are sorted on it.
                            query = query.join(relation_model, isouter=True)
                        joined_models.add(relation_model)
                        query = query.order_by(direction())
                    else:
                        field = getattr(model, val.field)
                        direction = getattr(field, val.direction)
                        query = query.order_by(direction())
            else:
                # Default primary-key ordering is skipped for grouped queries,
                # whose result rows are aggregates rather than instances.
                if not search_params.group_by:
                    pks = primary_key_names(model)
                    pk_order = (getattr(model, field).asc() for field in pks)
                    query = query.order_by(*pk_order)

        # Group the query.
        if search_params.group_by:
            for groupby in search_params.group_by:
                field = getattr(model, groupby.field)
                query = query.group_by(field)

        # Apply limit and offset to the query.
        if search_params.limit:
            query = query.limit(search_params.limit)
        if search_params.offset:
            query = query.offset(search_params.offset)

        return query
コード例 #10
0
def get_or_create_in_transaction(tsession, model, values, missing_columns=None, variable_columns=None, updatable_columns=None, only_use_supplied_columns=False, read_only=False):
    '''
    Uses the SQLAlchemy model to retrieve an existing record based on the supplied field values or, if there is no
    existing record, to create a new database record.

    :param tsession: An SQLAlchemy transactioned session
    :param model: The SQLAlchemy class representing the table
    :param values: A dict of values which will be used to populate the fields of the model
    :param missing_columns: Elements of missing_columns are expected to be fields in the model but are left blank regardless of whether they exist in values. This is useful for auto_increment fields. Defaults to none.
    :param variable_columns: If these are specified, they are treated as missing columns in the record matching but are not updated. A good use of these are for datetime fields which default to the current datetime. Defaults to none.
    :param updatable_columns: If these are specified, they are treated as missing columns in the record matching and if a record is found, these fields will be updated (their values must be present in values). Defaults to none.
    :param only_use_supplied_columns: If set, only the columns present in values are used for matching/creation and unexpected keys are ignored.
    :param read_only: If this is set then we query the database and return an instance if one exists but we do not create a new record.
    :return: the matching or newly created instance, or None when read_only is set and no record matches.
    :raises ValueError: if a name in missing_columns or updatable_columns is not a column of the model.
    :raises Exception: on unexpected fields, multiple matching records, or missing required fields when creating.

    Note: This function is a convenience function and is NOT efficient. The "tsession.query(model).filter_by(**pruned_values)"
          call is only (sometimes) efficient if an index exists on the keys of pruned_values. If any of the fields of pruned_values are
          large (even if otherwise deferred/loaded lazily) then you will incur a performance hit on lookup. You may need
          to reconsider any calls to this function in inner loops of your code.'''

    # None defaults instead of shared mutable [] defaults; the lists are only read.
    missing_columns = missing_columns or []
    variable_columns = variable_columns or []
    updatable_columns = updatable_columns or []

    # NOTE(review): values is only read here, but the deep copy also stops a newly
    # created instance from sharing mutable objects with the caller's dict.
    values = copy.deepcopy(values)

    fieldnames = [c.name for c in list(sqlalchemy_inspect(model).columns)]
    for c in missing_columns:
        fieldnames.remove(c)
    for c in updatable_columns:
        fieldnames.remove(c)
    for c in variable_columns:
        if c in fieldnames:
            fieldnames.remove(c)

    if only_use_supplied_columns:
        fieldnames = sorted(set(fieldnames).intersection(set(values.keys())))
    else:
        unexpected_fields = set(values.keys()).difference(set(fieldnames)).difference(set(variable_columns)).difference(set(updatable_columns))
        if unexpected_fields:
            # getattr also covers inherited/declared_attr table names; indexing the
            # class __dict__ raised KeyError and masked the real error.
            table_name = getattr(model, '__tablename__', getattr(model, '__name__', model))
            raise Exception("The fields '{0}' were passed but not found in the schema for table {1}.".format("', '".join(sorted(unexpected_fields)), table_name))

    pruned_values = {k: values[k] for k in set(values.keys()).intersection(set(fieldnames))}

    # Fetch at most two rows: enough to detect duplicates in a single round trip
    # (the previous count() + first() issued two queries).
    matches = tsession.query(model).filter_by(**pruned_values).limit(2).all()
    if len(matches) > 1:
        raise Exception('Multiple records were found with the search criteria.')
    instance = matches[0] if matches else None

    if instance is not None:
        if not read_only:
            for c in updatable_columns:
                setattr(instance, c, values[c])
            tsession.flush()
        return instance

    if read_only:
        return None

    if sorted(pruned_values.keys()) != sorted(fieldnames):
        # When adding new records, we require that all necessary fields are present
        raise Exception('Some required fields are missing: {0}. Either supply these fields or add them to the missing_columns list.'.format(set(fieldnames).difference(pruned_values.keys())))
    instance = model(**pruned_values)
    tsession.add(instance)
    tsession.flush()
    return instance
コード例 #11
0
 def object_session(self):
     """Return the ``session`` attribute of this object's SQLAlchemy instance state."""
     state = sqlalchemy_inspect(self)
     return state.session
コード例 #12
0
    def create_query(session, model, search_params, _ignore_order_by=False):
        """Builds an SQLAlchemy query instance based on the search parameters
        present in ``search_params``, an instance of :class:`SearchParameters`.

        This method returns a SQLAlchemy query in which all matched instances
        meet the requirements specified in ``search_params``.

        `model` is SQLAlchemy declarative model on which to create a query.

        `search_params` is an instance of :class:`SearchParameters` which
        specify the filters, order, limit, offset, etc. of the query.

        If `_ignore_order_by` is ``True``, no ``order_by`` method will be
        called on the query, regardless of whether the search parameters
        indicate that there should be an ``order_by``. (This is used internally
        by Flask-Restless to work around a limitation in SQLAlchemy.)

        Field names of the form ``relation__field`` refer to ``field`` on the
        model reached through the ``relation`` attribute of `model`. They are
        accepted for grouping and ordering, and the related model is joined
        in. When ``search_params.group_by`` is set, the selected entities are
        replaced by ``count()`` followed by the group-by columns.

        Building the query proceeds in this order:
        1. filtering
        2. ordering
        3. grouping
        4. limiting
        5. offsetting

        Raises one of :exc:`AttributeError`, :exc:`KeyError`, or
        :exc:`TypeError` if there is a problem creating the query. See the
        documentation for :func:`_create_operation` for more information.

        """
        # TODO: Couldn't this be done further down, together with the group by?
        # Related models already joined, shared by the group-by and order-by
        # sections so each related model is joined at most once.
        joined_models = set()
        query = session.query(model)

        if search_params.group_by:
            select_fields = [func.count()]
            for groupby in search_params.group_by:
                field_name = groupby.field
                if '__' in field_name:
                    field_name, field_name_in_relation = field_name.split('__')
                    relation = getattr(model, field_name)
                    relation_model = relation.mapper.class_
                    field = getattr(relation_model, field_name_in_relation)
                    if relation_model not in joined_models:
                        # NOTE(review): User is joined with an explicit ON clause,
                        # presumably because the join path to User is ambiguous
                        # without it — confirm against the models.
                        if relation_model == User:
                            query = query.join(
                                relation_model,
                                model.creator_id == relation_model.id)
                        else:
                            query = query.join(relation_model)
                    joined_models.add(relation_model)
                    select_fields.append(field)
                else:
                    select_fields.append(getattr(model, groupby.field))
                # NOTE(review): with_entities is re-applied on every iteration;
                # only the final call (with the complete select_fields) survives.
                # It is interleaved with the joins above, so confirm the joins are
                # unaffected before hoisting it out of the loop.
                query = query.with_entities(*select_fields)

        # This function call may raise an exception.
        # Collect the attribute names usable in filters: instrumented
        # attributes, hybrid properties and relationships. str() of an
        # instrumented attribute / relationship is "Model.attr", so the part
        # after the dot is kept.
        valid_model_fields = []
        for orm_descriptor in sqlalchemy_inspect(model).all_orm_descriptors:
            if isinstance(orm_descriptor, InstrumentedAttribute):
                valid_model_fields.append(str(orm_descriptor).split('.')[1])
            if isinstance(orm_descriptor, hybrid_property):
                valid_model_fields.append(orm_descriptor.__name__)
        valid_model_fields += [
            str(algo).split('.')[1]
            for algo in sqlalchemy_inspect(model).relationships
        ]

        # Build one filter expression per search filter; filters for which the
        # factory returned None are dropped below.
        filters_generator = map(  # pylint: disable=W1636
            QueryBuilder.create_filters_func(model, valid_model_fields),
            search_params.filters)

        filters = [filt for filt in filters_generator if filt is not None]

        # Multiple filter criteria at the top level of the provided search
        # parameters are interpreted as a conjunction (AND).
        query = query.filter(*filters)

        # Order the search. If no order field is specified in the search
        # parameters, order by primary key.
        if not _ignore_order_by:
            if search_params.order_by:
                for val in search_params.order_by:
                    field_name = val.field
                    if '__' in field_name:
                        field_name, field_name_in_relation = field_name.split(
                            '__')
                        relation = getattr(model, field_name)
                        relation_model = relation.mapper.class_
                        field = getattr(relation_model, field_name_in_relation)
                        direction = getattr(field, val.direction)
                        # Outer join so that sorting on a related field does not
                        # drop rows that have no related object.
                        if relation_model not in joined_models:
                            query = query.join(relation_model, isouter=True)
                        joined_models.add(relation_model)
                        query = query.order_by(direction())
                    else:
                        field = getattr(model, val.field)
                        direction = getattr(field, val.direction)
                        query = query.order_by(direction())
            else:
                # Grouped queries return aggregate rows, so the default
                # primary-key ordering is not applied to them.
                if not search_params.group_by:
                    pks = primary_key_names(model)
                    pk_order = (getattr(model, field).asc() for field in pks)
                    query = query.order_by(*pk_order)

        # Group the query.
        if search_params.group_by:
            for groupby in search_params.group_by:
                field_name = groupby.field
                # Related fields were already joined in the group-by section above.
                if '__' in field_name:
                    field_name, field_name_in_relation = field_name.split('__')
                    relation = getattr(model, field_name)
                    relation_model = relation.mapper.class_
                    field = getattr(relation_model, field_name_in_relation)
                else:
                    field = getattr(model, groupby.field)
                query = query.group_by(field)

        # Apply limit and offset to the query.
        if search_params.limit:
            query = query.limit(search_params.limit)
        if search_params.offset:
            query = query.offset(search_params.offset)

        return query