예제 #1
0
    def __init__(self, db_connection, create_tables=False):

        self.db = SQLAlchemyDB(
            db_connection, Base, session_options={
                'autocommit': False, 'autoflush': True,
            })
        if create_tables:
            self.db.create_all_tables()
예제 #2
0
 def __init__(self, db_connection, create_tables=False, secret=None):
     self.db = SQLAlchemyDB(db_connection,
                            Base,
                            session_options={
                                'autocommit': False,
                                'autoflush': True
                            })
     if create_tables:
         self.db.create_all_tables()
     if not secret:
         raise ValueError('Secret is not defined for %s.%s' %
                          (self.__module__, self.__class__.__name__))
     self.secret = secret
예제 #3
0
    def __init__(
            self,
            db_connection,
            create_tables=False,
            secret=None,
            token_ttl_secs=None,
            **engine_parameters):

        self.db = SQLAlchemyDB(
            db_connection, Base, session_options={
                'autocommit': False, 'autoflush': True},
            **engine_parameters)
        if create_tables:
            self.db.create_all_tables()
        if not secret:
            raise ValueError('Secret is not defined for %s.%s' % (
                self.__module__, self.__class__.__name__))
        self.secret = secret
        self.token_ttl_secs = token_ttl_secs or 60 * 60  # 60min
예제 #4
0
파일: api.py 프로젝트: xze/OpenTAXII
    def __init__(self, db_connection, create_tables=False, secret=None):

        self.db = SQLAlchemyDB(
            db_connection, Base, session_options={
                'autocommit': False, 'autoflush': True,
            })
        if create_tables:
            self.db.create_all_tables()

        if not secret:
            raise ValueError('Secret is not defined for %s.%s' % (
                self.__module__, self.__class__.__name__))

        self.secret = secret
예제 #5
0
class SQLDatabaseAPI(OpenTAXIIPersistenceAPI):
    """SQL database implementation of OpenTAXII Persistence API.

    Implementation will work with any DB supported by SQLAlchemy package.

    Note: this implementation ignores ``context.account`` and does not have
    any access rules.

    :param str db_connection: a string that indicates database dialect and
                          connection arguments that will be passed directly
                          to :func:`~sqlalchemy.engine.create_engine` method.

    :param bool create_tables=False: if True, tables will be created in the DB.
    """

    def __init__(self, db_connection, create_tables=False):

        self.db = SQLAlchemyDB(
            db_connection, Base, session_options={
                'autocommit': False, 'autoflush': True,
            })
        if create_tables:
            self.db.create_all_tables()

    def init_app(self, app):
        self.db.init_app(app)

    def get_services(self, collection_id=None):
        if collection_id:
            collection = DataCollection.query.get(collection_id)
            services = collection.services
        else:
            services = Service.query.all()
        return [conv.to_service_entity(s) for s in services]

    def get_service(self, service_id):
        return conv.to_service_entity(Service.get(service_id))

    def update_service(self, obj):

        service = Service.query.get(obj.id) if obj.id else None
        if service:
            service.type = obj.type
            service.properties = obj.properties
        else:
            service = Service(
                id=obj.id, type=obj.type,
                properties=obj.properties)

        service = self.db.session.add(service)
        self.db.session.commit()

        return conv.to_service_entity(service)

    def create_service(self, entity):
        return self.update_service(entity)

    def get_collections(self, service_id):
        service = Service.query.get(service_id)
        return [
            conv.to_collection_entity(c) for c in service.collections]

    def get_collection(self, name, service_id):

        collection = (
            DataCollection.query.join(Service.collections)
                                .filter(Service.id == service_id)
                                .filter(DataCollection.name == name)).first()

        if collection:
            return conv.to_collection_entity(collection)

    def _get_content_query(self, collection_id=None, start_time=None,
                           end_time=None, bindings=None, count=False):
        if count:
            query = self.db.session.query(func.count(ContentBlock.id))
        else:
            query = (ContentBlock
                     .query.order_by(ContentBlock.timestamp_label.asc()))

        if collection_id:
            query = (query.join(ContentBlock.collections)
                          .filter(DataCollection.id == collection_id))

        if start_time:
            query = query.filter(
                ContentBlock.timestamp_label > start_time)

        if end_time:
            query = query.filter(ContentBlock.timestamp_label <= end_time)

        if bindings:
            criteria = []
            for binding in bindings:
                if binding.subtypes:
                    criterion = and_(
                        ContentBlock.binding_id == binding.binding,
                        ContentBlock.binding_subtype.in_(binding.subtypes)
                    )
                else:
                    criterion = ContentBlock.binding_id == binding.binding
                criteria.append(criterion)

            query = query.filter(or_(*criteria))

        return query

    def get_content_blocks_count(self, collection_id=None, start_time=None,
                                 end_time=None, bindings=None):

        query = self._get_content_query(
            collection_id=collection_id,
            start_time=start_time,
            end_time=end_time,
            bindings=bindings,
            count=True)

        return query.scalar()

    def get_content_blocks(self, collection_id=None, start_time=None,
                           end_time=None, bindings=None, offset=0, limit=None):

        query = self._get_content_query(
            collection_id=collection_id,
            start_time=start_time,
            end_time=end_time,
            bindings=bindings)

        query = query.offset(offset)
        if limit:
            query = query.limit(limit)

        return [
            conv.to_block_entity(block)
            for block in query.yield_per(YIELD_PER_SIZE)]

    def create_collection(self, entity):

        _bindings = conv.serialize_content_bindings(entity.supported_content)

        collection = DataCollection(
            name=entity.name,
            type=entity.type,
            description=entity.description,
            available=entity.available,
            accept_all_content=entity.accept_all_content,
            bindings=_bindings
        )

        self.db.session.add(collection)
        self.db.session.commit()

        return conv.to_collection_entity(collection)

    def attach_collection_to_services(self, collection_id, service_ids):

        collection = DataCollection.query.get(collection_id)

        if not collection:
            raise ValueError("Collection with id {} does not exist"
                             .format(collection_id))

        services = Service.query.filter(Service.id.in_(service_ids))

        collection.services.extend(services)

        self.db.session.add(collection)
        self.db.session.commit()

        log.debug("collection.attached", collection=collection.id,
                  collection_name=collection.name, services=service_ids)

    def create_inbox_message(self, entity):

        if entity.destination_collections:
            names = json.dumps(entity.destination_collections)
        else:
            names = None

        begin = entity.exclusive_begin_timestamp_label
        end = entity.inclusive_end_timestamp_label

        message = InboxMessage(
            original_message=entity.original_message,
            content_block_count=entity.content_block_count,
            destination_collections=names,

            service_id=entity.service_id,

            result_id=entity.result_id,
            record_count=entity.record_count,
            partial_count=entity.partial_count,

            subscription_collection_name=entity.subscription_collection_name,
            subscription_id=entity.subscription_id,

            exclusive_begin_timestamp_label=begin,
            inclusive_end_timestamp_label=end
        )

        self.db.session.add(message)
        self.db.session.commit()

        return conv.to_inbox_message_entity(message)

    def create_content_block(self, entity, collection_ids=None,
                             service_id=None):

        if entity.content_binding:
            binding = entity.content_binding.binding
            subtype = (entity.content_binding.subtypes[0]
                       if entity.content_binding.subtypes else None)
        else:
            binding = None
            subtype = None

        content = ContentBlock(
            timestamp_label=entity.timestamp_label,
            inbox_message_id=entity.inbox_message_id,
            content=entity.content,
            binding_id=binding,
            binding_subtype=subtype
        )

        self.db.session.add(content)
        self.db.session.commit()

        if collection_ids:
            self._attach_content_to_collections(content, collection_ids)

        return conv.to_block_entity(content)

    def _attach_content_to_collections(self, content_block, collection_ids):

        if not collection_ids:
            return

        criteria = DataCollection.id.in_(collection_ids)
        new_collections = DataCollection.query.filter(criteria)

        content_block.collections.extend(new_collections)

        self.db.session.add(content_block)

        log.debug("Content block added to collections",
                  content_block=content_block.id,
                  collections=new_collections.count())

        self.db.session.commit()

    def create_result_set(self, entity):

        _bindings = conv.serialize_content_bindings(entity.content_bindings)

        result_set = ResultSet(
            id=entity.id,
            collection_id=entity.collection_id,
            bindings=_bindings,
            begin_time=entity.timeframe[0],
            end_time=entity.timeframe[1]
        )

        self.db.session.add(result_set)
        self.db.session.commit()

        return conv.to_result_set_entity(result_set)

    def get_result_set(self, result_set_id):
        result_set = ResultSet.query.get(result_set_id)
        return conv.to_result_set_entity(result_set)

    def get_subscription(self, subscription_id):
        s = Subscription.query.get(subscription_id)
        return conv.to_subscription_entity(s)

    def get_subscriptions(self, service_id):
        service = Service.query.get(service_id)
        return [
            conv.to_subscription_entity(s) for s in service.subscriptions]

    def update_subscription(self, entity):

        if entity.params:
            params = dict(
                response_type=entity.params.response_type,
                content_bindings=conv.serialize_content_bindings(
                    entity.params.content_bindings)
            )
        else:
            params = {}

        subscription = (
            Subscription.query.get(entity.subscription_id)
            if entity.subscription_id else None)

        if subscription:
            subscription.collection_id = entity.collection_id
            subscription.params = json.dumps(params)
            subscription.status = entity.status
            subscription.service_id = entity.service_id
        else:
            subscription = Subscription(
                id=entity.subscription_id,
                collection_id=entity.collection_id,
                params=json.dumps(params),
                status=entity.status,
                service_id=entity.service_id
            )

        self.db.session.add(subscription)
        self.db.session.commit()

        log.debug("subscription.updated",
                  subscription=subscription.id,
                  collection=subscription.collection_id,
                  status=subscription.status)

        return conv.to_subscription_entity(subscription)

    def create_subscription(self, entity):
        return self.update_subscription(entity)

    def delete_content_blocks(self, collection_name, start_time,
                              end_time=None):

        collection = DataCollection.query.filter_by(name=collection_name).one()

        if not collection:
            raise ValueError("Collection with name '{}' does not exist"
                             .format(collection_name))

        content_blocks_query = (
            self.db.session.query(ContentBlock.id)
                           .join(DataCollection.content_blocks)
                           .filter(DataCollection.id == collection.id)
                           .filter(ContentBlock.timestamp_label > start_time))

        if end_time:
            content_blocks_query = content_blocks_query.filter(
                ContentBlock.timestamp_label <= end_time)

        counter = (
            ContentBlock.query.filter(
                ContentBlock.id.in_(content_blocks_query.subquery()))
            .delete(synchronize_session=False))

        collection.volume = (
            self.db.session
            .query(func.count(ContentBlock.id))
            .join(ContentBlock.collections)
            .filter(DataCollection.id == collection.id)).scalar()

        self.db.session.commit()

        return counter
예제 #6
0
파일: api.py 프로젝트: xze/OpenTAXII
class SQLDatabaseAPI(OpenTAXIIAuthAPI):
    """Naive SQL database implementation of OpenTAXII Auth API.

    Implementation will work with any DB supported by SQLAlchemy package.

    :param str db_connection: a string that indicates database dialect and
                          connection arguments that will be passed directly
                          to :func:`~sqlalchemy.engine.create_engine` method.

    :param bool create_tables=False: if True, tables will be created in the DB.
    """

    def __init__(self, db_connection, create_tables=False, secret=None):

        self.db = SQLAlchemyDB(
            db_connection, Base, session_options={
                'autocommit': False, 'autoflush': True,
            })
        if create_tables:
            self.db.create_all_tables()

        if not secret:
            raise ValueError('Secret is not defined for %s.%s' % (
                self.__module__, self.__class__.__name__))

        self.secret = secret

    def init_app(self, app):
        self.db.init_app(app)

    def authenticate(self, username, password):
        try:
            account = Account.query.filter_by(username=username).one()
        except exc.NoResultFound:
            return

        if not account.is_password_valid(password):
            return

        return self._generate_token(account.id, ttl=TOKEN_TTL)

    def get_account(self, token):

        account_id = self._get_account_id(token)

        if not account_id:
            return

        account = Account.query.get(account_id)

        if not account:
            return

        return AccountEntity(id=account.id, username=account.username)

    def create_account(self, username, password):

        account = Account(username=username)
        account.set_password(password)

        self.db.session.add(account)
        self.db.session.commit()

        return AccountEntity(id=account.id, username=username)

    def _generate_token(self, account_id, ttl=TOKEN_TTL):
        exp = datetime.utcnow() + timedelta(minutes=ttl)
        return jwt.encode(
            {'account_id': account_id, 'exp': exp},
            self.secret)

    def _get_account_id(self, token):
        try:
            payload = jwt.decode(token, self.secret)
        except jwt.ExpiredSignatureError:
            log.warning('Invalid token used', token=token)
            return
        except jwt.DecodeError:
            log.warning('Can not decode a token', token=token)
            return

        return payload.get('account_id')
예제 #7
0
class SQLDatabaseAPI(OpenTAXIIAuthAPI):
    """Naive SQL database implementation of OpenTAXII Auth API.

    Implementation will work with any DB supported by SQLAlchemy package.

    :param str db_connection: a string that indicates database dialect and
                          connection arguments that will be passed directly
                          to :func:`~sqlalchemy.engine.create_engine` method.

    :param bool create_tables=False: if True, tables will be created in the DB.
    """
    def __init__(self, db_connection, create_tables=False, secret=None):
        self.db = SQLAlchemyDB(db_connection,
                               Base,
                               session_options={
                                   'autocommit': False,
                                   'autoflush': True
                               })
        if create_tables:
            self.db.create_all_tables()
        if not secret:
            raise ValueError('Secret is not defined for %s.%s' %
                             (self.__module__, self.__class__.__name__))
        self.secret = secret

    def init_app(self, app):
        self.db.init_app(app)

    def authenticate(self, username, password):
        try:
            account = Account.query.filter_by(username=username).one()
        except exc.NoResultFound:
            return
        if not account.is_password_valid(password):
            return
        return self._generate_token(account.id, ttl=TOKEN_TTL)

    def get_account(self, token):
        account_id = self._get_account_id(token)
        if not account_id:
            return
        account = Account.query.get(account_id)
        if not account:
            return
        return account_to_account_entity(account)

    def delete_account(self, username):
        account = Account.query.filter_by(username=username).one_or_none()
        if account:
            self.db.session.delete(account)
            self.db.session.commit()

    def get_accounts(self):
        return [
            account_to_account_entity(account)
            for account in Account.query.all()
        ]

    def update_account(self, obj, password):
        account = Account.query.filter_by(username=obj.username).one_or_none()
        if not account:
            account = Account(username=obj.username)
            self.db.session.add(account)
        account.set_password(password)
        account.permissions = obj.permissions
        account.is_admin = obj.is_admin
        self.db.session.commit()
        return account_to_account_entity(account)

    def _generate_token(self, account_id, ttl=TOKEN_TTL):
        exp = datetime.utcnow() + timedelta(minutes=ttl)
        return jwt.encode({'account_id': account_id, 'exp': exp}, self.secret)

    def _get_account_id(self, token):
        try:
            payload = jwt.decode(token, self.secret)
        except jwt.ExpiredSignatureError:
            log.warning('Invalid token used', token=token)
            return
        except jwt.DecodeError:
            log.warning('Can not decode a token', token=token)
            return

        return payload.get('account_id')
예제 #8
0
파일: api.py 프로젝트: EclecticIQ/OpenTAXII
class SQLDatabaseAPI(OpenTAXIIAuthAPI):
    """Naive SQL database implementation of OpenTAXII Auth API.

    Implementation will work with any DB supported by SQLAlchemy package.

    :param str db_connection: a string that indicates database dialect and
                          connection arguments that will be passed directly
                          to :func:`~sqlalchemy.engine.create_engine` method.
    :param bool create_tables=False: if True, tables will be created in the DB.
    :param str secret: secret string used for token generation
    :param int token_ttl_secs: TTL for JWT token, in seconds.
    :param engine_parameters=None: if defined, these arguments would be passed to sqlalchemy.create_engine
    """
    def __init__(
            self,
            db_connection,
            create_tables=False,
            secret=None,
            token_ttl_secs=None,
            **engine_parameters):

        self.db = SQLAlchemyDB(
            db_connection, Base, session_options={
                'autocommit': False, 'autoflush': True},
            **engine_parameters)
        if create_tables:
            self.db.create_all_tables()
        if not secret:
            raise ValueError('Secret is not defined for %s.%s' % (
                self.__module__, self.__class__.__name__))
        self.secret = secret
        self.token_ttl_secs = token_ttl_secs or 60 * 60  # 60min

    def init_app(self, app):
        self.db.init_app(app)

    def authenticate(self, username, password):
        try:
            account = Account.query.filter_by(username=username).one()
        except exc.NoResultFound:
            return
        if not account.is_password_valid(password):
            return
        return self._generate_token(account.id, ttl=self.token_ttl_secs)

    def get_account(self, token):
        account_id = self._get_account_id(token)
        if not account_id:
            return
        account = Account.query.get(account_id)
        if not account:
            return
        return account_to_account_entity(account)

    def delete_account(self, username):
        account = Account.query.filter_by(username=username).one_or_none()
        if account:
            self.db.session.delete(account)
            self.db.session.commit()

    def get_accounts(self):
        return [
            account_to_account_entity(account)
            for account in Account.query.all()]

    def update_account(self, obj, password):
        account = Account.query.filter_by(username=obj.username).one_or_none()
        if not account:
            account = Account(username=obj.username)
            self.db.session.add(account)
        account.set_password(password)
        account.permissions = obj.permissions
        account.is_admin = obj.is_admin
        self.db.session.commit()
        return account_to_account_entity(account)

    def _generate_token(self, account_id, ttl=None):
        ttl = ttl or self.token_ttl_secs
        exp = datetime.utcnow() + timedelta(minutes=ttl)
        return jwt.encode(
            {'account_id': account_id, 'exp': exp},
            self.secret)

    def _get_account_id(self, token):
        try:
            payload = jwt.decode(token, self.secret)
        except jwt.ExpiredSignatureError:
            log.warning('Invalid token used', token=token)
            return
        except jwt.DecodeError:
            log.warning('Can not decode a token', token=token)
            return

        return payload.get('account_id')