예제 #1
0
파일: view.py 프로젝트: Andrew59-boop/blog
    def create_model(self, form):
        """
            Create a new model instance from ``form`` and commit it.

            :param form:
                Form instance whose data is copied onto the new model
            :return:
                The committed model on success; ``False`` when anything
                failed (the session is rolled back in that case).
        """
        try:
            # TODO: We need a better way to create model instances and stay compatible with
            # SQLAlchemy __init__() behavior
            model = self._manager.new_instance()
            # Fire SQLAlchemy's ``init`` event for the new instance (no args/kwargs)
            self._manager.dispatch.init(instance_state(model), [], {})

            form.populate_obj(model)
            self.session.add(model)
            self._on_model_change(form, model, True)
            self.session.commit()
        except Exception as ex:
            handled = self.handle_view_exception(ex)
            if not handled:
                message = gettext('Failed to create record. %(error)s',
                                  error=str(ex))
                flash(message, 'error')
                log.exception('Failed to create record.')

            self.session.rollback()
            return False

        # Success path: runs only when the try block did not raise
        self.after_model_change(form, model, True)
        return model
def _deleted_by_this_query(host):
    """Return True when ``host`` is not marked as expired in its instance state.

    This check relies on inspecting the session after it has been updated by
    commit(), which marks the deleted hosts as expired. After that change, a
    host deleted by a different process triggers ObjectDeletedError when it is
    touched by a new query, and its delete event is not emitted.
    """
    is_expired = instance_state(host).expired
    return not is_expired
예제 #3
0
    def create_model(self, form):
        """
            Build a model from ``form``, add it to the session and commit.

            :param form:
                Form instance
            :return:
                The new model, or ``False`` when creation failed and the
                session was rolled back.
        """
        try:
            # TODO: We need a better way to create model instances and stay compatible with
            # SQLAlchemy __init__() behavior
            manager = self._manager
            model = manager.new_instance()
            manager.dispatch.init(instance_state(model), [], {})

            form.populate_obj(model)
            session = self.session
            session.add(model)
            self._on_model_change(form, model, True)
            session.commit()
        except Exception as ex:
            if not self.handle_view_exception(ex):
                flash(gettext('Failed to create record. %(error)s', error=str(ex)), 'error')
                log.exception('Failed to create record.')
            self.session.rollback()
            return False
        else:
            self.after_model_change(form, model, True)
            return model
예제 #4
0
    def __init__(self, instance: object, *, copy: bool = False):
        """ Take a lightweight snapshot of an instance's historical values.

        Call this before flush(): flush() erases all in-memory changes, after which
        the original values are lost.

        Args:
            instance: The instance to get the historical values for.
            copy: Deep-copy every collected value.
                Useful for embedded dictionaries, but a bit more expensive, so disabled by default.
        """
        # Model info
        self.__model_info = sa_model_info(type(instance), types=AttributeType.ALL)

        # Start from the current values, then let the committed (DB) values
        # override the ones that were modified
        self.__state: InstanceState = instance_state(instance)
        snapshot = dict(self.__state.dict)
        snapshot.update(self.__state.committed_state)

        # A deep copy protects embedded dictionaries from later in-place mutation
        self.__historical = deepcopy(snapshot) if copy else snapshot
예제 #5
0
    def tojson(self, request, instance, in_list=False, exclude=None,
               exclude_related=None, safe=False, **kw):
        """Serialise ``instance`` into a JSON-compatible dictionary of fields.

        If the underlying ORM object is detached, it is added to a session and
        serialisation is retried inside that session's ``with`` scope.

        :param request: the current request; provides ``app.models``,
            ``logger`` and the session factory argument
        :param instance: the value wrapped by ``self.instance()``
        :param in_list: forwarded to ``_related_model`` for relationship fields
        :param exclude: field names to leave out (passed to ``info.exclude``)
        :param exclude_related: when truthy, relationship fields are skipped
        :param safe: when truthy, conversion failures are not logged
        :return: the result of ``self.instance_urls`` for the collected fields
        :raises ModelNotAvailable: when reading a field raises
            ``ObjectDeletedError``
        """
        instance = self.instance(instance)
        obj = instance.obj
        if instance_state(obj).detached:
            # Re-attach the object and serialise it within the session scope
            with self.session(request) as session:
                session.add(obj)
                return self.tojson(request, instance, in_list=in_list,
                                   exclude_related=exclude_related, safe=safe,
                                   exclude=exclude, **kw)
        info = self._fields
        exclude = info.exclude(exclude, exclude_urls=True)
        load_only = instance.fields

        fields = {}
        for field in self.fields().values():
            name = field.name
            if name in exclude or (load_only and name not in load_only):
                continue
            try:
                data = self.get_instance_value(instance, name)
                if isinstance(data, date):
                    # Naive datetimes are assumed to be UTC
                    if isinstance(data, datetime) and not data.tzinfo:
                        data = pytz.utc.localize(data)
                    data = data.isoformat()
                elif isinstance(data, Enum):
                    data = data.name
                elif is_rel_field(field):
                    if exclude_related:
                        continue
                    model = request.app.models.get(field.model)
                    if model:
                        data = self._related_model(request, model, data,
                                                   in_list)
                    else:
                        data = None
                        request.logger.error(
                            'Could not find model %s', field.model)
                else:
                    # Test Json: values json cannot encode fall back to their
                    # string form. Only the TypeError of this probe is handled
                    # here; a TypeError raised while *reading* the value goes
                    # to the generic handler below instead of stringifying a
                    # stale ``data`` left over from the previous field.
                    try:
                        json.dumps(data)
                    except TypeError:
                        try:
                            data = str(data)
                        except Exception:
                            continue
            except ObjectDeletedError:
                raise ModelNotAvailable from None
            except Exception:
                if not safe:
                    request.logger.exception(
                        'Exception while converting attribute "%s" in model '
                        '%s to JSON', name, self)
                continue
            if data is not None:
                # List values are published under a "name[]" key
                if isinstance(data, list):
                    name = '%s[]' % name
                fields[name] = data
        return self.instance_urls(request, instance, fields)
예제 #6
0
def sa_modified_names(instance: object) -> Set[str]:
    """ Get the names of the attributes modified on an instance

    Note: only meaningful before flush(): flush() persists all changes,
    after which the information about them is lost.
    """
    # `dict` holds the current values; when an attribute is modified, its old value
    # goes into `committed_state` under the attribute's name.
    # So the keys of `committed_state` are exactly the modified attribute names.
    committed = instance_state(instance).committed_state
    return {name for name in committed}
예제 #7
0
def sa_set_committed_state(obj: object, **committed_values):
    """ Put values into an SqlAlchemy instance as if they were committed to the DB

    Returns the very same `obj`, for convenience.
    """
    # Give it a (dummy) DB identity so that SA thinks there's something it can load
    state: InstanceState = instance_state(obj)
    state.key = object()

    # Store each value as committed, i.e. as if that's how the row looks in the DB
    for attr_name in committed_values:
        set_committed_value(obj, attr_name, committed_values[attr_name])

    return obj
예제 #8
0
 def _get_instance_states_with_unloaded(
         session: Session, mapper: Mapper,
         attr_name: str) -> Iterable[InstanceState]:
     """ Yield the states of `mapper.class_` instances in `session` that have `attr_name` unloaded

     Only persistent instances (those that have a PK in the DB) are yielded.
     """
     for obj in session:
         if not isinstance(obj, mapper.class_):
             continue

         state: InstanceState = instance_state(obj)
         # Skip objects that are not persistent, or whose attribute is already loaded
         if not state.persistent or attr_name not in state.unloaded:
             continue

         yield state
예제 #9
0
    def build_new_instance(self):
        """
            Build a new instance of the model.

            Useful to override the Flask-Admin behavior when the model has a
            custom __init__ method.

            :return: the new model instance, after SQLAlchemy's ``init`` event
                has been dispatched for it
        """
        # TODO: We need a better way to create model instances and stay compatible with
        # SQLAlchemy __init__() behavior
        manager = self._manager
        instance = manager.new_instance()
        manager.dispatch.init(instance_state(instance), [], {})
        return instance
예제 #10
0
    def __init__(self, obj: object):
        """ Wrap `obj`, remembering which attributes can be read without triggering a DB query """
        super().__init__(obj)

        # Anything not already loaded (and not a safely loadable property) would lead to an unwanted DB query
        state: InstanceState = instance_state(self._obj)
        self._loaded = loaded_attribute_names(state)
        self._safe_properties = get_all_safely_loadable_properties(type(obj))

        # Because some unloaded attributes are going to be ignored, BaseModel.__fields_set__ has to be set.
        # There is no BaseModel at hand here, so the ignored field names are collected and stashed
        # inside the instance itself: InstanceState.info is a perfect place for that.
        # SALoadedModel then picks them up and sets `__fields_set__`.
        excluded = set()
        self._excluded = excluded
        state.info[SALoadedGetterDict] = excluded
예제 #11
0
def sa_dependencies(instance: Union[SAInstanceT, Iterable[SAInstanceT]], map: PluckMap, _seen: set = None) -> List[PrimaryKey]:
    """ Automatically collect PrimaryKey dependencies from SqlAlchemy instances

    Usage: when you have an output from a query, use sa_dependencies() on it to get a list of dependencies
    that you can provide to the MatroskaCache.put() function.
    NOTE: it will only pick primary key dependencies! Other dependencies can only be provided manually!

    Args:
        instance: the instance to get dependencies from, or a list/set/tuple of them.
            `None` (e.g. a many-to-one relationship that is not set) yields no dependencies.
        map: inclusion map: {attribute: 1, relationship: {key: 1, ...}}
            Use `1` to include an attribute, dict() to include a relationship, `0` to exclude something
            NOTE: `map` must be valid!
        _seen: internal: instances already included, so that each one is listed only once

    Returns:
        A list of PrimaryKey objects, one per distinct instance reached
    """
    if _seen is None:
        _seen = set()

    # Nothing to depend on: e.g. an unset many-to-one relationship.
    # Without this, instance_state(None) would fail.
    if instance is None:
        return []

    # Lists
    if isinstance(instance, (list, set, tuple)):
        return list(itertools.chain.from_iterable(
            sa_dependencies(item, map, _seen) for item in instance
        ))

    # Instances
    state: InstanceState = instance_state(instance)
    relationships: Mapping[str, RelationshipProperty] = state.mapper.relationships

    # Include self, but only once
    ret = []
    if instance not in _seen:
        ret.append(PrimaryKey.from_instance(instance))
        _seen.add(instance)

    # If there's anything left to iterate, do it
    if isinstance(map, dict):
        for key, include in map.items():
            # Skip excluded elements
            if not include:
                continue

            # Skip non-relationships
            if key not in relationships:
                # TODO: implement dependencies on individual attributes?
                continue

            # Descend into the relationship (it may be None, a list, or a single instance)
            ret.extend(sa_dependencies(getattr(instance, key), include, _seen))

    return ret
예제 #12
0
def get_history_proxy_for_instance(instance: SAInstanceT,
                                   copy: bool = False) -> SAInstanceT:
    """ Get a permanent InstanceHistoryProxy for an instance.

    The proxy is created on the first call and cached in the instance's `InstanceState.info`.
    Every later call, even after flush, returns that very same InstanceHistoryProxy,
    so the `copy` argument only matters on the first call.
    Be careful with long-living instances: they will remember their original values the whole time.
    """
    info = instance_state(instance).info

    # Reuse the proxy made earlier
    if InstanceHistoryProxy in info:
        return info[InstanceHistoryProxy]

    # First call: make one and remember it
    proxy = InstanceHistoryProxy(instance, copy=copy)
    info[InstanceHistoryProxy] = proxy
    return proxy
예제 #13
0
파일: book.py 프로젝트: e2thenegpii/piecash
 def track_dirty(session, flush_context, instances):
     """
     Record in session._all_changes the objects that are dirty, new or deleted before each flush

     Entries are keyed by id(obj). Each one holds the list of state changes ("STATE_CHANGES"),
     the object itself ("OBJECT") and, per modified attribute, its first recorded old value.
     """
     categories = (("dirty", session.dirty),
                   ("new", session.new),
                   ("deleted", session.deleted))
     for change, objs in categories:
         for obj in objs:
             # dictionary of changes recorded for this object (created on first sight)
             attrs = session._all_changes.setdefault(id(obj), {})
             # append this state change to the object's history
             attrs.setdefault("STATE_CHANGES", []).append(change)
             attrs.setdefault("OBJECT", obj)
             # keep only the first "old value" of an attribute changed several times
             for key, old_value in instance_state(obj).committed_state.items():
                 if key not in attrs:
                     attrs[key] = old_value
예제 #14
0
 def track_dirty(session, flush_context, instances):
     """
     Record in session._all_changes the objects that have been modified before each flush

     For each object (keyed by id(obj)) the record holds "STATE_CHANGES" (every change kind seen),
     "OBJECT" (the object) and the first old value of each modified attribute.
     """
     changes_by_kind = {"dirty": session.dirty,
                        "new": session.new,
                        "deleted": session.deleted}
     for change in changes_by_kind:
         for obj in changes_by_kind[change]:
             record = session._all_changes.setdefault(id(obj), {})
             record.setdefault("STATE_CHANGES", []).append(change)
             record.setdefault("OBJECT", obj)
             # setdefault keeps the first "old value" of an attribute modified several times
             for attr_name, old_value in instance_state(obj).committed_state.items():
                 record.setdefault(attr_name, old_value)
예제 #15
0
def prevent_model_recursion(obj: SAModelT, marker_key: Hashable) -> Optional[SAModelT]:
    """ Mark an instance as "being processed at the moment" and yield it. In case of recursion, yield None

    This is a generator that yields exactly once (presumably meant to be wrapped with
    `contextlib.contextmanager` — confirm at the definition site):
    * if `obj` is already marked with `marker_key`, it yields None and leaves the mark in place;
    * otherwise it marks `obj`, yields it, and removes the mark when the generator finishes,
      including when an exception is thrown into it or it is closed.

    NOTE(review): the return annotation describes the yielded value, not the generator itself.

    Args:
        obj: the SqlAlchemy instance being processed
        marker_key: key stored in the instance's `InstanceState.info` to mark it
    """
    # The mark lives in the instance's InstanceState.info dict
    state: InstanceState = instance_state(obj)

    # Is already being parsed? (recursion)
    if marker_key in state.info:
        yield None
        return

    # Mark the instance as "being processed at the moment"
    state.info[marker_key] = True

    # Parse it
    try:
        yield obj
    finally:
        # Unmark it, even if processing failed
        del state.info[marker_key]
예제 #16
0
def delete_by_id(host_id_list):
    """Delete the hosts matching `host_id_list` for the current identity's account.

    Responds with 404 when the query finds no host. After the commit, a delete event
    is emitted for each host, except hosts whose state is already expired.
    """
    payload_tracker = get_payload_tracker(
        account=current_identity.account_number,
        payload_id=threadctx.request_id)

    with PayloadTrackerContext(payload_tracker,
                               received_status_message="delete operation"):
        query = _get_host_list_by_id_list(current_identity.account_number, host_id_list)
        hosts_to_delete = query.all()
        if not hosts_to_delete:
            return flask.abort(status.HTTP_404_NOT_FOUND)

        with metrics.delete_host_processing_time.time():
            query.delete(synchronize_session="fetch")
        db.session.commit()

        metrics.delete_host_count.inc(len(hosts_to_delete))

        # This process of checking for an already deleted host relies
        # on checking the session after it has been updated by the commit()
        # function and marked the deleted hosts as expired.  It is after this
        # change that the host is called by a new query and, if deleted by a
        # different process, triggers the ObjectDeletedError and is not emitted.
        for deleted_host in hosts_to_delete:
            # Checking the state first prevents ObjectDeletedError from being raised.
            if instance_state(deleted_host).expired:
                # Can’t log the Host ID. Accessing an attribute raises ObjectDeletedError.
                logger.info("Host already deleted. Delete event not emitted.")
                continue

            with PayloadTrackerProcessingContext(
                    payload_tracker,
                    processing_status_message="deleted host") as processing_ctx:
                logger.debug("Deleted host: %s", deleted_host)
                emit_event(events.delete(deleted_host))
                processing_ctx.inventory_id = deleted_host.id
예제 #17
0
def row2dict(row):
    """Convert an object (usually an ORM row) into a dict.

    The loaded attributes in ``row.__dict__`` and any hybrid properties are
    included. Names starting with ``_`` and ``metadata`` are skipped. Nested
    ``Base`` instances and ``InstrumentedList`` collections are converted
    recursively. ``Decimal`` becomes ``int``, and dates/datetimes become
    formatted strings.

    NOTE: as a side effect, the expired-attribute set of ``row``'s instance
    state is cleared.
    """
    record = {}
    # Clear the expired state, forcibly skipping state._load_expired(state, passive).
    # If a field is genuinely needed but missing, either give it a default value
    # or use refresh() to fetch its server_default value from the database.
    state = instance_state(row)
    state.expired_attributes.clear()

    # Only the attribute *names* are needed (a dict keeps them unique and ordered).
    # Copying the values, e.g. with deepcopy, would needlessly duplicate the
    # SQLAlchemy state and any loaded related objects.
    attributes, cls = dict.fromkeys(row.__dict__), row.__class__
    for c in dir(row):
        if hasattr(cls, c):
            a = getattr(cls, c)
            # hybrid_property: queryable, but not an instrumented column/relationship
            if isinstance(a, QueryableAttribute) and not isinstance(
                    a, InstrumentedAttribute):
                attributes[c] = 1  # only the attribute name matters here

    for c in attributes:
        if c.startswith('_') or c == 'metadata':
            continue
        try:
            v = row.__getattribute__(c)
        except KeyError as e:
            # https://github.com/zzzeek/sqlalchemy/blob/master/lib/sqlalchemy/orm/attributes.py#L579
            # this function may raise a KeyError
            logging.exception(e)
            v = datetime.now() if c in ['created', 'modified'] else None
        if isinstance(v, Base):
            v = row2dict(v)
        if isinstance(v, Decimal):
            # NOTE(review): int() truncates any fractional part — confirm this is intended
            v = int(v)
        # Special handling for birthdays and for start/end dates.
        # A missing (None) value, e.g. an unset end date, stays None
        # instead of crashing on None.strftime().
        if v is not None:
            if c in ['start', 'end'] and row.__tablename__ in ['work', 'education']:
                v = v.strftime('%Y.%m')
            if c == 'birthday' and row.__tablename__ in ['user']:
                v = v.strftime('%Y.%m.%d')
        if isinstance(v, datetime):
            v = v.strftime('%Y.%m.%d %H:%M:%S')
        if isinstance(v, InstrumentedList):
            v = [row2dict(i) for i in v]
        record[c] = v

    return record
예제 #18
0
    def validate(self):
        """Validate the transaction against its committed (pre-change) state.

        Raises:
            GncValidationError: if a split uses a placeholder account, or if a
                previously set currency is being changed.
            GncImbalanceError: if the splits were changed and the transaction is
                not balanced on its value.

        When the splits were changed, quantity imbalances exist and the book uses
        trading accounts, the trading accounts are normalized.
        """
        old = instance_state(self).committed_state

        # No account used by a split may be a placeholder (= frozen)
        for split in self.splits:
            account = split.account
            if account.placeholder != 0:
                raise GncValidationError("Account '{}' used in the transaction is a placeholder".format(account))

        # The currency may not change once it has been set
        old_currency = old.get("currency")
        if old_currency is not None:
            raise GncValidationError("You cannot change the currency of a transaction once it has been set")

        # The remaining checks only apply when the splits changed
        if "splits" not in old:
            return

        value_imbalance, quantity_imbalances = self.calculate_imbalances()
        if value_imbalance:
            # raise instead of creating an imbalance entry, as it is probably an error
            # (the gnucash GUI decides otherwise because it needs to "save unfinished transactions")
            raise GncImbalanceError("The transaction {} is not balanced on its value".format(self))

        if any(quantity_imbalances.values()) and self.book.use_trading_accounts:
            self.normalize_trading_accounts()
예제 #19
0
    def tojson(self,
               request,
               instance,
               in_list=False,
               exclude=None,
               exclude_related=None,
               safe=False,
               **kw):
        """Serialise ``instance`` into a JSON-compatible dictionary of fields.

        If the underlying ORM object is detached, it is added to a session and
        serialisation is retried inside that session's ``with`` scope.

        :param request: the current request; provides ``app.models``,
            ``logger`` and the session factory argument
        :param instance: the value wrapped by ``self.instance()``
        :param in_list: forwarded to ``_related_model`` for relationship fields
        :param exclude: field names to leave out (passed to ``info.exclude``)
        :param exclude_related: when truthy, relationship fields are skipped
        :param safe: when truthy, conversion failures are not logged
        :return: the result of ``self.instance_urls`` for the collected fields
        :raises ModelNotAvailable: when reading a field raises
            ``ObjectDeletedError``
        """
        instance = self.instance(instance)
        obj = instance.obj
        if instance_state(obj).detached:
            # Re-attach the object and serialise it within the session scope
            with self.session(request) as session:
                session.add(obj)
                return self.tojson(request,
                                   instance,
                                   in_list=in_list,
                                   exclude_related=exclude_related,
                                   safe=safe,
                                   exclude=exclude,
                                   **kw)
        info = self._fields
        exclude = info.exclude(exclude, exclude_urls=True)
        load_only = instance.fields

        fields = {}
        for field in self.fields().values():
            name = field.name
            if name in exclude or (load_only and name not in load_only):
                continue
            try:
                data = self.get_instance_value(instance, name)
                if isinstance(data, date):
                    # Naive datetimes are assumed to be UTC
                    if isinstance(data, datetime) and not data.tzinfo:
                        data = pytz.utc.localize(data)
                    data = data.isoformat()
                elif isinstance(data, Enum):
                    data = data.name
                elif is_rel_field(field):
                    if exclude_related:
                        continue
                    model = request.app.models.get(field.model)
                    if model:
                        data = self._related_model(request, model, data,
                                                   in_list)
                    else:
                        data = None
                        request.logger.error('Could not find model %s',
                                             field.model)
                else:
                    # Test Json: values json cannot encode fall back to their
                    # string form. Only the TypeError of this probe is handled
                    # here; a TypeError raised while *reading* the value goes
                    # to the generic handler below instead of stringifying a
                    # stale ``data`` left over from the previous field.
                    try:
                        json.dumps(data)
                    except TypeError:
                        try:
                            data = str(data)
                        except Exception:
                            continue
            except ObjectDeletedError:
                raise ModelNotAvailable from None
            except Exception:
                if not safe:
                    request.logger.exception(
                        'Exception while converting attribute "%s" in model '
                        '%s to JSON', name, self)
                continue
            if data is not None:
                # List values are published under a "name[]" key
                if isinstance(data, list):
                    name = '%s[]' % name
                fields[name] = data
        return self.instance_urls(request, instance, fields)
예제 #20
0
 def build_new_instance(self):
     """Create a model instance, call its __init__() explicitly, then fire SQLAlchemy's `init` event."""
     manager = self._manager
     model = manager.new_instance()
     # Call SAFRSBase.__init__()
     model.__init__()
     manager.dispatch.init(instance_state(model), [], {})
     return model
예제 #21
0
    def get_names_excluded_from(cls, obj: object) -> Set[str]:
        """ Get the set of attribute names that SALoadedGetterDict has excluded for `obj`

        The set is read from the instance's `InstanceState.info`, under the SALoadedGetterDict key.
        Presumably this raises KeyError if SALoadedGetterDict never processed `obj`.

        See SALoadedGetterDict._excluded
        """
        info = instance_state(obj).info
        return info[SALoadedGetterDict]
예제 #22
0
 def object_beforechange(self):
     """Return the instance state's `committed_state`: the old values of this object's modified attributes."""
     state = instance_state(self)
     return state.committed_state
예제 #23
0
 def from_instance(cls, instance: object):
     """ Build one from an SqlAlchemy instance: its class name plus its identity converted to a string """
     state: InstanceState = instance_state(instance)
     model_name = state.class_.__name__
     identity = cls._instance_identity_to_str(state.identity)
     return cls(model_name, identity)