Ejemplo n.º 1
0
    def handle_indicators_search(self, token, data, **kwargs):
        """Resolve the token, log the search, then run it against the store.

        Returns the results, materializing a generator into a list.
        Raises InvalidSearch when the query cannot be logged or executed.
        """
        token_rec = self.store.tokens.read(token)

        try:
            self._log_search(token_rec, data)
        except TypeError:
            raise InvalidSearch('invalid search')
        except Exception as e:
            # best-effort logging; the search itself still proceeds
            logger.error(e)

        try:
            results = self.store.indicators.search(token_rec, data)
        except Exception as e:
            logger.error(e)
            if logger.getEffectiveLevel() == logging.DEBUG:
                import traceback
                traceback.print_exc()
            raise InvalidSearch('invalid search')

        # materialize lazy results so they can be serialized downstream
        return list(results) if isinstance(results, GeneratorType) else results
Ejemplo n.º 2
0
    def _filter_indicator(self, filters, s):
        """Narrow the query *s* by the 'indicator' filter, if present.

        Unknown filter keys are dropped from *filters* in place. Depending
        on the resolved type the query is constrained by exact match (fqdn,
        email, url) or address range (ipv4/ipv6); unresolvable indicators
        fall back to a substring match against the message table.
        """
        # prune unsupported filter keys in place
        for key in list(filters):
            if key not in VALID_FILTERS:
                del filters[key]

        if not filters.get('indicator'):
            return s

        indicator = filters.pop('indicator')
        if PYVERSION == 2:
            if isinstance(indicator, str):
                indicator = unicode(indicator)

        try:
            itype = resolve_itype(indicator)
        except InvalidIndicator as e:
            logger.error(e)
            # not a recognizable indicator -- search message bodies instead
            return s.join(Message).filter(Indicator.Message.like(
                '%{}%'.format(indicator)))

        if itype in ['fqdn', 'email', 'url']:
            return s.filter(Indicator.indicator == indicator)

        if itype == 'ipv4':
            net = ipaddress.IPv4Network(indicator)
            if net.prefixlen < 8:
                raise InvalidSearch('prefix needs to be >= 8')

            lo, hi = str(net.network_address), str(net.broadcast_address)
            logger.debug('{} - {}'.format(lo, hi))

            return s.join(Ipv4).filter(Ipv4.ipv4 >= lo).filter(Ipv4.ipv4 <= hi)

        if itype == 'ipv6':
            net = ipaddress.IPv6Network(indicator)
            if net.prefixlen < 32:
                raise InvalidSearch('prefix needs to be >= 32')

            lo, hi = str(net.network_address), str(net.broadcast_address)
            logger.debug('{} - {}'.format(lo, hi))

            return s.join(Ipv6).filter(Ipv6.ip >= lo).filter(Ipv6.ip <= hi)

        raise InvalidIndicator
Ejemplo n.º 3
0
    def handle_indicators_search(self, token, data, **kwargs):
        """Log and execute an indicator search for the given token.

        Raises InvalidSearch when the backend search fails.
        """
        t = self.store.tokens.read(token)

        if data.get('indicator'):
            # python2 compatibility: coerce byte strings to unicode.
            # Only NameError is expected here (``unicode`` does not exist
            # on python3); the original bare ``except`` would also hide
            # genuine errors.
            try:
                if isinstance(data['indicator'], str):
                    data['indicator'] = unicode(data['indicator'])
            except NameError:
                pass

        self._log_search(t, data)

        try:
            x = self.store.indicators.search(t, data)
        except Exception as e:
            logger.error(e)

            if logger.getEffectiveLevel() == logging.DEBUG:
                import traceback
                traceback.print_exc()

            raise InvalidSearch('invalid search')

        return x
    def _recv(self, decode=True):
        """Receive one message from the socket and normalize it.

        When *decode* is False the raw payload is returned untouched.
        Otherwise the JSON envelope is inspected: well-known error messages
        are mapped onto client exceptions, a non-success status raises
        RuntimeError, and a zlib-compressed payload is transparently
        decompressed before the 'data' field is returned.
        """
        mtype, data = Msg().recv(self.socket)

        if not decode:
            return data

        data = json.loads(data)

        # map well-known server-side error messages onto client exceptions
        if data.get('message') == 'unauthorized':
            raise AuthError()

        if data.get('message') == 'busy':
            raise CIFBusy()

        if data.get('message') == 'invalid search':
            raise InvalidSearch()

        if data.get('status') != 'success':
            raise RuntimeError(data.get('message'))

        # payload may or may not be compressed; leave it alone if not
        try:
            data['data'] = zlib.decompress(data['data'])
        except (zlib.error, TypeError):
            pass

        return data.get('data')
Ejemplo n.º 5
0
    def _get(self, uri, params=None):
        """Issue a GET against the remote API and return the decoded body.

        *uri* may be absolute or relative to ``self.remote``. Raises
        RuntimeError on a malformed response and InvalidSearch when the
        server reports a failed search.
        """
        # avoid the shared mutable-default-argument pitfall
        if params is None:
            params = {}

        if not uri.startswith('http'):
            uri = self.remote + uri

        resp = self.session.get(uri,
                                params=params,
                                verify=self.verify_ssl,
                                timeout=self.timeout)

        self._check_status(resp, expect=200)

        data = resp.content

        s = (int(resp.headers['Content-Length']) / 1024 / 1024)
        logger.info('processing %.2f megs' % s)

        msgs = json.loads(data.decode('utf-8'))

        if not msgs.get('status') and not msgs.get('message') == 'success':
            raise RuntimeError(msgs)

        if msgs.get('status') and msgs['status'] == 'failure':
            raise InvalidSearch(msgs['message'])

        if isinstance(msgs.get('data'), list):
            for m in msgs['data']:
                if m.get('message'):
                    try:
                        m['message'] = b64decode(m['message'])
                    except Exception:
                        # message may not be base64; keep it as-is
                        pass
        return msgs
Ejemplo n.º 6
0
    def _check_data(self, msgs):
        if isinstance(msgs, list):
            return msgs

        if msgs.get('status', False) not in ['success', 'failure']:
            raise RuntimeError(msgs)

        if msgs.get('status') == 'failure':
            raise InvalidSearch(msgs['message'])

        if msgs['data'] == '{}':
            msgs['data'] = []

        # check to see if it's straight elasticsearch json
        if isinstance(msgs['data'], basestring) and msgs['data'].startswith(
                '{"hits":{"hits":[{"_source":'):
            msgs['data'] = json.loads(msgs['data'])
            msgs['data'] = [r['_source'] for r in msgs['data']['hits']['hits']]

        if not isinstance(msgs['data'], list):
            msgs['data'] = [msgs['data']]

        for m in msgs['data']:
            if not isinstance(m, dict):
                continue

            if not m.get('message'):
                continue

            try:
                m['message'] = b64decode(m['message'])
            except Exception as e:
                pass

        return msgs
Ejemplo n.º 7
0
    def _check_status(self, resp, expect=200):
        if resp.status_code == 400:
            r = json.loads(resp.text)
            raise InvalidSearch(r['message'])

        if resp.status_code == 401:
            raise AuthError('unauthorized')

        if resp.status_code == 404:
            raise NotFound('not found')

        if resp.status_code == 408:
            raise TimeoutError('timeout')

        if resp.status_code == 422:
            msg = json.loads(resp.text)
            raise SubmissionFailed(msg['message'])

        if resp.status_code == 429:
            raise CIFBusy('RateLimit exceeded')

        if resp.status_code in [500, 501, 502, 503, 504]:
            raise CIFBusy('system seems busy..')

        if resp.status_code != expect:
            msg = 'unknown: %s' % resp.content
            raise RuntimeError(msg)
Ejemplo n.º 8
0
    def handle_indicators_search(self, token, data, **kwargs):
        """Authenticate, run an indicator search, base64-encode messages.

        Raises AuthError for an invalid token and InvalidSearch when the
        backend query fails.
        """
        t = self.store.token_read(token)
        if not t:
            raise AuthError('invalid token')

        if PYVERSION == 2:
            if data.get('indicator'):
                if isinstance(data['indicator'], str):
                    data['indicator'] = unicode(data['indicator'])

        try:
            x = self.store.indicators_search(data)
        except Exception as e:
            logger.error(e)
            if logger.getEffectiveLevel() == logging.DEBUG:
                import traceback
                # format_exc() returns the traceback text;
                # print_exc() returns None and would log 'None'
                logger.error(traceback.format_exc())
            raise InvalidSearch('invalid search')

        t = self.store.tokens_search({'token': token})
        self._log_search(t, data)

        if isinstance(x, GeneratorType):
            x = list(x)

        for xx in x:
            if xx.get('message'):
                # b64encode returns bytes on python3; decode (not encode)
                # to produce a JSON-serializable text string
                xx['message'] = b64encode(xx['message']).decode('utf-8')

        return x
Ejemplo n.º 9
0
    def _filter_terms(self, filters, s):
        """Apply the remaining query-string filters to the SQL query *s*.

        Bookkeeping keys (nolog/days/hours/groups/limit) are skipped;
        unknown keys raise InvalidSearch.
        """
        for k in filters:
            if k in ['nolog', 'days', 'hours', 'groups', 'limit']:
                continue

            if k == 'reporttime':
                if ',' in filters[k]:
                    start, end = filters[k].split(',')
                    s = s.filter(
                        Indicator.reporttime >= arrow.get(start).datetime)
                    # fixed: was arrow.get(end).datettime (typo), which
                    # raised AttributeError on any ranged reporttime query
                    s = s.filter(
                        Indicator.reporttime <= arrow.get(end).datetime)
                else:
                    s = s.filter(
                        Indicator.reporttime >= arrow.get(filters[k]).datetime)

            elif k == 'reporttimeend':
                s = s.filter(Indicator.reporttime <= filters[k])

            elif k == 'tags':
                s = s.outerjoin(Tag).filter(Tag.tag == filters[k])

            elif k == 'confidence':
                if ',' in str(filters[k]):
                    start, end = str(filters[k]).split(',')
                    s = s.filter(Indicator.confidence >= float(start))
                    s = s.filter(Indicator.confidence <= float(end))
                else:
                    s = s.filter(Indicator.confidence >= float(filters[k]))

            elif k == 'itype':
                s = s.filter(Indicator.itype == filters[k])

            elif k == 'provider':
                s = s.filter(Indicator.provider == filters[k])

            elif k == 'asn':
                s = s.filter(Indicator.asn == filters[k])

            elif k == 'asn_desc':
                s = s.filter(Indicator.asn_desc.like('%{}%'.format(
                    filters[k])))

            elif k == 'cc':
                s = s.filter(Indicator.cc == filters[k])

            elif k == 'rdata':
                s = s.filter(Indicator.rdata == filters[k])

            elif k == 'region':
                s = s.filter(Indicator.region == filters[k])

            else:
                raise InvalidSearch('invalid filter: %s' % k)

        return s
Ejemplo n.º 10
0
    def indicators_search(self, token, filters):
        """Search indicators in elasticsearch via elasticsearch-dsl.

        See http://elasticsearch-dsl.readthedocs.org/en/latest/search_dsl.html
        Returns a list of hit sources, or [] on a query error.
        """
        limit = filters.get('limit')
        if limit:
            del filters['limit']
        else:
            limit = LIMIT

        nolog = filters.get('nolog')
        if nolog:
            del filters['nolog']

        s = Indicator.search().params(size=limit, timeout=TIMEOUT)

        # keep only the filters we support
        q_filters = {f: filters[f] for f in VALID_FILTERS if filters.get(f)}

        indicator = q_filters.get('indicator')
        if indicator:
            itype = resolve_itype(indicator)

            if itype == 'ipv4':
                net = ipaddress.IPv4Network(indicator)
                if net.prefixlen < 8:
                    raise InvalidSearch(
                        'prefix needs to be greater than or equal to 8')

                s = s.filter('range', indicator_ipv4={
                    'gte': str(net.network_address),
                    'lte': str(net.broadcast_address),
                })
                del q_filters['indicator']

        # everything else becomes a term filter
        for field in q_filters:
            s = s.filter('term', **{field: q_filters[field]})

        try:
            rv = s.execute()
        except elasticsearch.exceptions.RequestError as e:
            self.logger.error(e)
            return []

        try:
            return [x['_source'] for x in rv.hits.hits]
        except KeyError:
            return []
Ejemplo n.º 11
0
    def _filter_indicator(self, q_filters, s):
        """Constrain the elasticsearch query *s* by the 'indicator' filter.

        fqdn/email/url match exactly; ipv4/ipv6 match by address range
        (ipv6 bounds are hex-encoded packed addresses); anything
        unresolvable falls back to a full-text match on 'message'.
        """
        if not q_filters.get('indicator'):
            return s

        i = q_filters.pop('indicator')

        try:
            itype = resolve_itype(i)
        except InvalidIndicator:
            s = s.query("match", message=i)
            return s

        if itype in ('email', 'url', 'fqdn'):
            s = s.filter('term', indicator=i)
            return s

        # fixed: was `itype is 'ipv4'` -- identity comparison of strings
        # relies on CPython interning and is not guaranteed to be True for
        # runtime-built strings; equality is what was intended
        if itype == 'ipv4':
            ip = ipaddress.IPv4Network(i)
            mask = ip.prefixlen
            if mask < 8:
                raise InvalidSearch(
                    'prefix needs to be greater than or equal to 8')

            start = str(ip.network_address)
            end = str(ip.broadcast_address)

            s = s.filter('range', indicator_ipv4={'gte': start, 'lte': end})
            return s

        if itype == 'ipv6':
            ip = ipaddress.IPv6Network(i)
            mask = ip.prefixlen
            if mask < 32:
                raise InvalidSearch(
                    'prefix needs to be greater than or equal to 32')

            start = binascii.b2a_hex(
                socket.inet_pton(socket.AF_INET6,
                                 str(ip.network_address))).decode('utf-8')
            end = binascii.b2a_hex(
                socket.inet_pton(socket.AF_INET6,
                                 str(ip.broadcast_address))).decode('utf-8')

            s = s.filter('range', indicator_ipv6={'gte': start, 'lte': end})
            return s
Ejemplo n.º 12
0
def _filter_ipv4(s, i):
    ip = ipaddress.IPv4Network(i)
    mask = ip.prefixlen
    if mask < 8:
        raise InvalidSearch('prefix needs to be greater than or equal to 8')

    start = str(ip.network_address)
    end = str(ip.broadcast_address)

    s = s.filter('range', indicator_ipv4={'gte': start, 'lte': end})
    return s
Ejemplo n.º 13
0
 def handle_indicators_search(self, token, data):
     """Run an indicator search after validating the token.

     Raises AuthError for an unknown token and InvalidSearch when the
     underlying store query fails.
     """
     if self.store.token_read(token):
         self.logger.debug('searching')
         try:
             x = self.store.indicators_search(token, data)
         except Exception as e:
             self.logger.error(e)
             raise InvalidSearch('invalid search')
         else:
             # only reached when the search raised nothing
             return x
     else:
         raise AuthError('invalid token')
Ejemplo n.º 14
0
    def handle_stats_search(self, token, data, **kwargs):
        """Run a stats search for the given token.

        Raises InvalidSearch when the backend query fails.
        """
        token_rec = self.store.tokens.read(token)

        try:
            return self.store.indicators.stats_search(token_rec, data)
        except Exception as e:
            logger.error(e)

            if logger.getEffectiveLevel() == logging.DEBUG:
                traceback.print_exc()

            raise InvalidSearch('invalid search')
Ejemplo n.º 15
0
def _filter_ipv6(s, i):
    ip = ipaddress.IPv6Network(i)
    mask = ip.prefixlen
    if mask < 32:
        raise InvalidSearch('prefix needs to be greater than or equal to 32')

    start = binascii.b2a_hex(
        socket.inet_pton(socket.AF_INET6,
                         str(ip.network_address))).decode('utf-8')
    end = binascii.b2a_hex(
        socket.inet_pton(socket.AF_INET6,
                         str(ip.broadcast_address))).decode('utf-8')

    s = s.filter('range', indicator_ipv6={'gte': start, 'lte': end})
    return s
Ejemplo n.º 16
0
    def handle_indicators_search(self, token, data, **kwargs):
        """Normalize time filters, log and run an indicator search.

        'days'/'hours' are converted into a reporttime/reporttimeend window
        when no explicit reporttime is supplied. Raises InvalidSearch when
        the backend query fails.
        """
        t = self.store.tokens.read(token)

        if PYVERSION == 2:
            if data.get('indicator'):
                if isinstance(data['indicator'], str):
                    data['indicator'] = unicode(data['indicator'])

        if not data.get('reporttime'):
            if data.get('days'):
                now = arrow.utcnow()
                data['reporttimeend'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))
                # shift(), not replace(): arrow removed plural-unit
                # replace() support; shift() is the supported API and is
                # what the sibling handler in this file already uses
                now = now.shift(days=-int(data['days']))
                data['reporttime'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))

            if data.get('hours'):
                now = arrow.utcnow()
                data['reporttimeend'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))
                now = now.shift(hours=-int(data['hours']))
                data['reporttime'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))

        s = time.time()

        self._log_search(t, data)

        try:
            x = self.store.indicators.search(t, data)
            logger.debug('done')
        except Exception as e:
            logger.error(e)

            if logger.getEffectiveLevel() == logging.DEBUG:
                # format_exc() returns the traceback text;
                # print_exc() returns None and would log 'None'
                logger.error(traceback.format_exc())

            raise InvalidSearch('invalid search')

        logger.debug('%s' % (time.time() - s))

        return x
Ejemplo n.º 17
0
    def _recv(self, decode=True, close=True):
        """Receive one message from the socket and normalize the payload.

        When *close* is True the socket is closed right after the read.
        With *decode* False the raw payload is returned untouched;
        otherwise the JSON envelope is checked (well-known errors raise),
        raw elasticsearch JSON is unwrapped, and zlib-compressed data is
        transparently inflated.
        """
        mtype, data = Msg().recv(self.socket)
        if close:
            self.socket.close()

        if not decode:
            return data

        data = json.loads(data)

        # map well-known server-side error messages onto client exceptions
        if data.get('message') == 'unauthorized':
            raise AuthError()

        if data.get('message') == 'busy':
            raise CIFBusy()

        if data.get('message') == 'invalid search':
            raise InvalidSearch()

        if data.get('status') != 'success':
            raise RuntimeError(data.get('message'))

        if data.get('data') is None:
            raise RuntimeError('invalid response')

        # boolean payloads (acks) pass straight through
        if isinstance(data.get('data'), bool):
            return data['data']

        # is this a straight up elasticsearch string?
        if data['data'] == '{}':
            return []

        # NOTE(review): basestring implies a python2 compat alias defined
        # at module level -- confirm it exists when running on python3
        if isinstance(data['data'], basestring) and data['data'].startswith('{"hits":{"hits":[{"_source":'):
            data['data'] = json.loads(data['data'])
            data['data'] = [r['_source'] for r in data['data']['hits']['hits']]

        # payload may or may not be compressed; leave it alone if not
        try:
            data['data'] = zlib.decompress(data['data'])
        except (zlib.error, TypeError):
            pass

        return data.get('data')
Ejemplo n.º 18
0
    def _filter_terms(self, filters, s):
        """Translate the remaining filter keys into SQL query constraints.

        Bookkeeping keys (nolog/days/hours/groups/limit) are skipped;
        unknown keys raise InvalidSearch.
        """
        for key in filters:
            if key in ('nolog', 'days', 'hours', 'groups', 'limit'):
                continue

            value = filters[key]

            if key == 'reporttime':
                if ',' in value:
                    low, high = value.split(',')
                    s = s.filter(Indicator.reporttime >= low)
                    s = s.filter(Indicator.reporttime <= high)
                else:
                    s = s.filter(Indicator.reporttime >= value)

            elif key == 'reporttimeend':
                s = s.filter(Indicator.reporttime <= value)

            elif key == 'tags':
                s = s.join(Tag).filter(Tag.tag == value)

            elif key == 'confidence':
                if ',' in str(value):
                    low, high = str(value).split(',')
                    s = s.filter(Indicator.confidence >= float(low))
                    s = s.filter(Indicator.confidence <= float(high))
                else:
                    # note: the single-value branch does not float()-coerce
                    s = s.filter(Indicator.confidence >= value)

            elif key == 'itype':
                s = s.filter(Indicator.itype == value)

            elif key == 'provider':
                s = s.filter(Indicator.provider == value)

            else:
                raise InvalidSearch('invalid filter: %s' % key)

        return s
Ejemplo n.º 19
0
def filter_reporttime(s, filter):
    """Constrain *s* to the requested report-time window.

    Relative 'hours'/'days' lookbacks take precedence over an explicit
    'reporttime' value so elasticsearch can cache the rounded range.
    """
    if not filter.get('reporttime'):
        return s

    high = 'now/m'
    # if passed 'days' or 'hours', preferentially use that for ES filtering/caching
    if filter.get('days') or filter.get('hours'):
        if filter.get('hours'):
            amount, unit = filter.pop('hours'), 'h'
        else:
            amount, unit = filter.pop('days'), 'd'

        try:
            amount = int(amount)
        except Exception:
            raise InvalidSearch(
                'Lookback time filter {}{} is not a valid time'.format(
                    amount, unit))

        # no spaces in the relative date-math expression -- ES errors on them
        low = 'now/m-{}{}'.format(amount, unit)
    else:
        # no relative params: fall back to the explicit 'reporttime'
        c = filter.pop('reporttime')
        if PYVERSION == 2:
            if type(c) == unicode:
                c = str(c)

        if isinstance(c, basestring) and ',' in c:
            low, high = c.split(',')
        else:
            low = c

        low = arrow.get(low).datetime

    return s.filter('range', reporttime={'gte': low, 'lte': high})
Ejemplo n.º 20
0
def filter_build(s, filters, token=None):
    """Compose all supported filters onto the Search object *s*.

    Enforces the server-side window limit, restricts input to
    VALID_FILTERS, and applies group-visibility rules based on *token*.
    """
    limit = filters.get('limit')
    if limit and int(limit) > WINDOW_LIMIT:
        raise InvalidSearch(
            'Request limit should be <= server threshold of {} but was set to {}'
            .format(WINDOW_LIMIT, limit))

    # keep only the filters we support
    q_filters = {f: filters[f] for f in VALID_FILTERS if filters.get(f)}

    s = filter_provider(s, q_filters)
    s = filter_confidence(s, q_filters)
    s = filter_id(s, q_filters)
    s = filter_rdata(s, q_filters)

    # treat indicator as special, transform into Search
    s = filter_indicator(s, q_filters)
    s = filter_reporttime(s, q_filters)

    # transform all other filters into term=
    s = filter_terms(s, q_filters)

    if q_filters.get('groups'):
        s = filter_groups(s, q_filters)
    elif token and (not token.get('admin') or token.get('admin') == ''):
        # non-admin with no explicit groups: restrict to the token's groups
        s = filter_groups(s, {}, token=token)

    if q_filters.get('tags'):
        s = filter_tags(s, q_filters)

    return s
Ejemplo n.º 21
0
    def handle_indicators_search(self, token, data, **kwargs):
        """Validate ACLs/groups, normalize time filters and run the search.

        Raises AuthError when the token's ACL forbids the requested itype
        and InvalidSearch when the backend query fails.
        """
        if PYVERSION == 2:
            if data.get('indicator'):
                if isinstance(data['indicator'], str):
                    data['indicator'] = unicode(data['indicator'])

        # token acl check
        if token.get('acl') and token.get('acl') != ['']:
            if data.get('itype') and data.get('itype') not in token['acl']:
                raise AuthError('unauthorized to access itype {}'.format(
                    data['itype']))

            if not data.get('itype'):
                data['itype'] = token['acl']

        # verify group filter matches token permissions
        if data.get('groups') and (not token.get('admin')
                                   or token.get('admin') == ''):
            # NOTE(review): if 'groups' is neither a string nor a list,
            # q_groups is unbound below -- assumes those are the only types
            if isinstance(data['groups'], basestring):
                q_groups = [g.strip() for g in data['groups'].split(',')]
            elif isinstance(data['groups'], list):
                q_groups = data['groups']

            gg = []
            for g in q_groups:
                if AUTH_ENABLED:
                    if g in token['groups']:
                        gg.append(g)
                else:
                    gg.append(g)

            if gg:
                data['groups'] = gg
            else:
                # no permitted groups -> force a non-matching filter
                data['groups'] = '{}'

        if not data.get('reporttime'):
            if data.get('days'):
                now = arrow.utcnow()
                data['reporttimeend'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))
                now = now.shift(days=-int(data['days']))
                data['reporttime'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))

            if data.get('hours'):
                now = arrow.utcnow()
                data['reporttimeend'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))
                now = now.shift(hours=-int(data['hours']))
                data['reporttime'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))

        s = time.time()

        self._log_search(token, data)

        try:
            x = self.store.indicators.search(token, data)
            logger.debug('done')
        except Exception as e:
            logger.error(e)

            if logger.getEffectiveLevel() == logging.DEBUG:
                # format_exc() returns the traceback text; print_exc()
                # returns None and would log 'None'
                logger.error(traceback.format_exc())

            raise InvalidSearch(': {}'.format(e))

        logger.debug('%s' % (time.time() - s))

        return x
Ejemplo n.º 22
0
    def indicators_search(self, filters, limit=500):
        """Search stored indicators with the supplied filters.

        'indicator' is handled specially (exact match, substring fallback,
        or an ipv4/ipv6 range join); everything else is translated into an
        SQL WHERE clause. Returns a list of dicts, newest first, capped at
        *limit*.

        WARNING: the non-indicator filters are interpolated directly into
        SQL text below -- callers must ensure values come from a
        trusted/validated source, or this is injectable.
        """
        self.logger.debug('running search')

        limit = filters.pop('limit', limit)
        nolog = filters.pop('nolog', False)

        q_filters = {}
        for f in VALID_FILTERS:
            if filters.get(f):
                q_filters[f] = filters[f]

        s = self.handle().query(Indicator)

        filters = q_filters

        sql = []

        if filters.get('indicator'):
            try:
                itype = resolve_itype(filters['indicator'])
                self.logger.debug('itype %s' % itype)
                if itype == 'ipv4':
                    # fixed: was `PYVERSION < 3 and (filters['indicator'],
                    # str)` -- a tuple literal that is always truthy;
                    # isinstance() was clearly intended
                    if PYVERSION < 3 and isinstance(filters['indicator'],
                                                    str):
                        filters['indicator'] = filters['indicator'].decode(
                            'utf-8')

                    ip = ipaddress.IPv4Network(filters['indicator'])
                    mask = ip.prefixlen

                    if mask < 8:
                        raise InvalidSearch('prefix needs to be >= 8')

                    start = str(ip.network_address)
                    end = str(ip.broadcast_address)

                    self.logger.debug('{} - {}'.format(start, end))

                    s = s.join(Ipv4).filter(Ipv4.ipv4 >= start)
                    s = s.filter(Ipv4.ipv4 <= end)

                elif itype == 'ipv6':
                    if PYVERSION < 3 and isinstance(filters['indicator'],
                                                    str):
                        filters['indicator'] = filters['indicator'].decode(
                            'utf-8')

                    ip = ipaddress.IPv6Network(filters['indicator'])
                    mask = ip.prefixlen

                    if mask < 32:
                        raise InvalidSearch('prefix needs to be >= 32')

                    start = str(ip.network_address)
                    end = str(ip.broadcast_address)

                    self.logger.debug('{} - {}'.format(start, end))

                    s = s.join(Ipv6).filter(Ipv6.ip >= start)
                    s = s.filter(Ipv6.ip <= end)

                elif itype in ('fqdn', 'email', 'url'):
                    sql.append("indicator = '{}'".format(filters['indicator']))

            except InvalidIndicator as e:
                self.logger.error(e)
                sql.append("message LIKE '%{}%'".format(filters['indicator']))
                s = s.join(Message)

            del filters['indicator']

        for k in filters:
            if k == 'reporttime':
                sql.append("{} >= '{}'".format('reporttime', filters[k]))
            elif k == 'reporttimeend':
                sql.append("{} <= '{}'".format('reporttime', filters[k]))
            elif k == 'tags':
                # NOTE(review): '==' is SQLite-specific; standard SQL is '='
                sql.append("tags.tag == '{}'".format(filters[k]))
            elif k == 'confidence':
                sql.append("{} >= '{}'".format(k, filters[k]))
            else:
                sql.append("{} = '{}'".format(k, filters[k]))

        sql = ' AND '.join(sql)

        if sql:
            self.logger.debug(sql)

        if filters.get('tags'):
            s = s.join(Tag)

        rv = s.order_by(desc(Indicator.reporttime)).filter(
            text(sql)).limit(limit)

        return [self._as_dict(x) for x in rv]
Ejemplo n.º 23
0
    def _get(self, uri, params=None, retry=True):
        """GET with retries on 429/5xx, returning the decoded envelope.

        Retries up to RETRIES times with a fixed delay; raises CIFBusy when
        the retry budget is exhausted, InvalidSearch on a failed search and
        RuntimeError on a malformed response.
        """
        # avoid the shared mutable-default-argument pitfall
        if params is None:
            params = {}

        if not uri.startswith('http'):
            uri = self.remote + uri

        resp = self.session.get(uri,
                                params=params,
                                verify=self.verify_ssl,
                                timeout=self.timeout)
        n = RETRIES
        try:
            self._check_status(resp, expect=200)
            n = 0
        except Exception as e:
            # only retry conditions that can clear up on their own
            if resp.status_code == 429 or resp.status_code in [
                    500, 501, 502, 503, 504
            ]:
                logger.error(e)
            else:
                raise e

        while n != 0:
            logger.warning(
                'setting random retry interval to spread out the load')
            logger.warning('retrying in %.00fs' % RETRIES_DELAY)
            sleep(RETRIES_DELAY)

            resp = self.session.get(uri,
                                    params=params,
                                    verify=self.verify_ssl,
                                    timeout=self.timeout)
            if resp.status_code == 200:
                break

            # fixed: n was never decremented, so the loop either spun
            # forever or never reached the CIFBusy branch below
            n -= 1
            if n == 0:
                raise CIFBusy('system seems busy.. try again later')

        data = resp.content

        s = (int(resp.headers['Content-Length']) / 1024 / 1024)
        logger.info('processing %.2f megs' % s)

        msgs = json.loads(data.decode('utf-8'))

        if msgs.get('data') and msgs['data'] == '{}':
            msgs['data'] = []

        if msgs.get('data') and isinstance(
                msgs['data'], basestring) and msgs['data'].startswith(
                    '{"hits":{"hits":[{"_source":'):
            msgs['data'] = json.loads(msgs['data'])
            msgs['data'] = [r['_source'] for r in msgs['data']['hits']['hits']]

        if not msgs.get('status') and not msgs.get('message') == 'success':
            raise RuntimeError(msgs)

        # NOTE(review): sibling client code checks status == 'failure';
        # confirm which status string the server actually emits here
        if msgs.get('status') and msgs['status'] == 'failed':
            raise InvalidSearch(msgs['message'])

        if isinstance(msgs.get('data'), list):
            for m in msgs['data']:
                if m.get('message'):
                    try:
                        m['message'] = b64decode(m['message'])
                    except Exception:
                        # message may not be base64; keep it as-is
                        pass
        return msgs
Ejemplo n.º 24
0
    def _filter_indicator(self, filters, s):
        """Constrain the SQL query *s* by the 'indicator' filter.

        Drops unsupported filter keys in place. Depending on the indicator
        type, joins the matching detail table (email/fqdn also match
        subdomain suffixes, ipv4/ipv6 match by range, url/hash exactly);
        values that fail type resolution fall back to a message substring
        match.
        """
        # prune unsupported filter keys in place
        for key in list(filters):
            if key not in VALID_FILTERS:
                del filters[key]

        if not filters.get('indicator'):
            return s

        indicator = filters.pop('indicator')

        try:
            itype = resolve_itype(indicator)
        except TypeError as e:
            logger.error(e)
            return s.join(Message).filter(
                Indicator.Message.like('%{}%'.format(indicator)))

        if itype == 'email':
            # match the address itself or any subdomain suffix
            return s.join(Email).filter(or_(
                Email.email.like('%.{}'.format(indicator)),
                Email.email == indicator))

        if itype == 'ipv4':
            net = ipaddress.IPv4Network(indicator)
            if net.prefixlen < 8:
                raise InvalidSearch('prefix needs to be >= 8')

            lo, hi = str(net.network_address), str(net.broadcast_address)
            logger.debug('{} - {}'.format(lo, hi))

            return s.join(Ipv4).filter(Ipv4.ipv4 >= lo).filter(Ipv4.ipv4 <= hi)

        if itype == 'ipv6':
            net = ipaddress.IPv6Network(indicator)
            if net.prefixlen < 32:
                raise InvalidSearch('prefix needs to be >= 32')

            lo, hi = str(net.network_address), str(net.broadcast_address)
            logger.debug('{} - {}'.format(lo, hi))

            return s.join(Ipv6).filter(Ipv6.ip >= lo).filter(Ipv6.ip <= hi)

        if itype == 'fqdn':
            return s.join(Fqdn).filter(or_(
                Fqdn.fqdn.like('%.{}'.format(indicator)),
                Fqdn.fqdn == indicator))

        if itype == 'url':
            return s.join(Url).filter(Url.url == indicator)

        if itype in HASH_TYPES:
            return s.join(Hash).filter(Hash.hash == str(indicator))

        raise ValueError
Ejemplo n.º 25
0
    def _filter_terms(self, filters, s):
        """Apply the non-indicator filter terms to a SQL query.

        :param filters: dict of term filters
        :param s: SQLAlchemy query to refine
        :return: refined SQLAlchemy query
        :raises InvalidSearch: on an unrecognized filter key
        """
        for k, v in filters.items():
            # handled elsewhere (or intentionally ignored at this layer)
            if k in ['nolog', 'days', 'hours', 'groups', 'limit', 'feed']:
                continue

            if k == 'reported_at':
                # 'start,end' bounds a window; a single value is a lower bound
                if ',' in v:
                    start, end = v.split(',')
                    s = s.filter(
                        Indicator.reported_at >= arrow.get(start).datetime)
                    s = s.filter(
                        Indicator.reported_at <= arrow.get(end).datetime)
                else:
                    s = s.filter(
                        Indicator.reported_at >= arrow.get(v).datetime)

            elif k == 'tags':
                t = v.split(',')
                s = s.outerjoin(Tag)
                # or_ takes clauses as positional args — unpack the list
                # (passing a bare generator hands or_ a single non-clause arg)
                s = s.filter(or_(*[Tag.tag == tt for tt in t]))

            elif k == 'confidence':
                if ',' in str(v):
                    start, end = str(v).split(',')
                    s = s.filter(Indicator.confidence >= float(start))
                    s = s.filter(Indicator.confidence <= float(end))
                else:
                    s = s.filter(Indicator.confidence >= float(v))

            elif k == 'probability':
                if ',' in str(v):
                    start, end = str(v).split(',')
                    # compare numerically: `start` is a string from split(),
                    # so `start == 0` was always False and the NULL branch
                    # (include rows with no probability) was unreachable
                    if float(start) == 0:
                        s = s.filter(
                            or_(Indicator.probability >= float(start),
                                Indicator.probability == None))
                        s = s.filter(Indicator.probability <= float(end))
                    else:
                        s = s.filter(Indicator.probability >= float(start))
                        s = s.filter(Indicator.probability <= float(end))
                else:
                    if float(v) == 0:
                        # 0 means "include rows with no probability at all"
                        s = s.filter(
                            or_(Indicator.probability == None,
                                Indicator.probability >= float(v)))
                    else:
                        s = s.filter(Indicator.probability >= float(v))

            elif k == 'itype':
                s = s.filter(Indicator.itype == v)

            elif k == 'provider':
                s = s.filter(Indicator.provider == v)

            elif k == 'asn':
                s = s.filter(Indicator.asn == v)

            elif k == 'asn_desc':
                s = s.filter(Indicator.asn_desc.like('%{}%'.format(v)))

            elif k == 'cc':
                s = s.filter(Indicator.cc == v)

            elif k == 'rdata':
                s = s.filter(Indicator.rdata == v)

            elif k == 'region':
                s = s.filter(Indicator.region == v)

            elif k == 'related':
                s = s.filter(Indicator.related == v)

            elif k == 'uuid':
                s = s.filter(Indicator.uuid == v)

            else:
                raise InvalidSearch('invalid filter: %s' % k)

        return s
    def handle_indicators_search(self, token, data, **kwargs):
        """Authorize and run an indicator search for a token.

        Normalizes the query first: restricts 'groups' to groups the token
        is permitted to see, and converts relative 'days'/'hours' windows
        into absolute report-time bounds.

        :param token: API token identifying the caller
        :param data: dict of search filters (mutated in place)
        :return: search results from the store backend
        :raises InvalidSearch: if the backend rejects the query
        """
        t = self.store.tokens.read(token)

        # py2: the store layer expects unicode indicators
        if PYVERSION == 2:
            if data.get('indicator'):
                if isinstance(data['indicator'], str):
                    data['indicator'] = unicode(data['indicator'])

        # verify group filter matches token permissions
        if data.get('groups'):
            q_groups = [g.strip() for g in data['groups'].split(',')]

            gg = [g for g in q_groups if g in t['groups']]

            if gg:
                data['groups'] = gg
            else:
                # no overlap with the token's groups: force a non-matching
                # group so nothing is leaked
                data['groups'] = '{}'

        # translate relative windows into absolute reporttime bounds
        if not data.get('reporttime'):
            if data.get('days'):
                now = arrow.utcnow()
                data['reporttimeend'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))
                now = now.shift(days=-int(data['days']))
                data['reporttime'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))

            if data.get('hours'):
                now = arrow.utcnow()
                data['reporttimeend'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))
                now = now.shift(hours=-int(data['hours']))
                data['reporttime'] = '{0}Z'.format(
                    now.format('YYYY-MM-DDTHH:mm:ss'))

        s = time.time()

        self._log_search(t, data)

        try:
            x = self.store.indicators.search(t, data)
            logger.debug('done')
        except Exception as e:
            logger.error(e)

            if logger.getEffectiveLevel() == logging.DEBUG:
                # format_exc returns the traceback as a string; print_exc
                # returns None and would have logged the literal 'None'
                logger.error(traceback.format_exc())

            raise InvalidSearch(': {}'.format(e))

        logger.debug('%s' % (time.time() - s))

        return x
Ejemplo n.º 27
0
    def indicators_search(self, filters, sort=None, raw=False):
        """Search the elasticsearch indicators-* indices.

        Builds the query with elasticsearch-dsl:
        http://elasticsearch-dsl.readthedocs.org/en/latest/search_dsl.html

        :param filters: dict of search filters (mutated in place)
        :param sort: optional sort expression applied to the query
        :param raw: if True, return raw ES hits instead of _source dicts
        :return: list of results ([] on a backend request error)
        :raises InvalidSearch: when a CIDR prefix is too broad
        """
        limit = filters.get('limit')
        if limit:
            del filters['limit']
        else:
            limit = LIMIT

        nolog = filters.get('nolog')
        if nolog:
            del filters['nolog']

        timeout = TIMEOUT

        s = Indicator.search(index='indicators-*')
        s = s.params(size=limit, timeout=timeout)
        if sort:
            s = s.sort(sort)

        # keep only recognized filters
        q_filters = {}
        for f in VALID_FILTERS:
            if filters.get(f):
                q_filters[f] = filters[f]

        if q_filters.get('indicator'):
            try:
                itype = resolve_itype(q_filters['indicator'])

                if itype == 'ipv4':
                    if PYVERSION == 2:
                        q_filters['indicator'] = unicode(
                            q_filters['indicator'])

                    ip = ipaddress.IPv4Network(q_filters['indicator'])
                    mask = ip.prefixlen
                    # refuse overly-broad sweeps
                    if mask < 8:
                        raise InvalidSearch(
                            'prefix needs to be greater than or equal to 8')
                    start = str(ip.network_address)
                    end = str(ip.broadcast_address)

                    s = s.filter('range',
                                 indicator_ipv4={
                                     'gte': start,
                                     'lte': end
                                 })

                # was `itype is 'ipv6'`: identity comparison against a str
                # literal depends on interning and is not reliable — use ==
                elif itype == 'ipv6':
                    if PYVERSION == 2:
                        q_filters['indicator'] = unicode(
                            q_filters['indicator'])

                    ip = ipaddress.IPv6Network(q_filters['indicator'])
                    mask = ip.prefixlen
                    if mask < 32:
                        raise InvalidSearch(
                            'prefix needs to be greater than or equal to 32')

                    # ipv6 is indexed as hex-packed strings; convert the
                    # range bounds the same way
                    start = binascii.b2a_hex(
                        socket.inet_pton(
                            socket.AF_INET6,
                            str(ip.network_address))).decode('utf-8')
                    end = binascii.b2a_hex(
                        socket.inet_pton(
                            socket.AF_INET6,
                            str(ip.broadcast_address))).decode('utf-8')

                    s = s.filter('range',
                                 indicator_ipv6={
                                     'gte': start,
                                     'lte': end
                                 })

                elif itype in ('email', 'url', 'fqdn'):
                    s = s.filter('term', indicator=q_filters['indicator'])

            except InvalidIndicator:
                # not a recognizable indicator: fuzzy-match the message field
                s = s.query("match", message=q_filters['indicator'])

            del q_filters['indicator']

        # remaining filters are exact term matches
        for f in q_filters:
            kwargs = {f: q_filters[f]}
            s = s.filter('term', **kwargs)

        try:
            rv = s.execute()
        except elasticsearch.exceptions.RequestError as e:
            self.logger.error(e)
            return []

        if raw:
            try:
                return rv.hits.hits
            except KeyError:
                return []
        else:
            try:
                data = []
                for x in rv.hits.hits:
                    # message is stored as text; callers expect base64 bytes
                    if x['_source'].get('message'):
                        x['_source']['message'] = b64encode(
                            x['_source']['message'].encode('utf-8'))
                    data.append(x['_source'])
                return data
            except KeyError as e:
                self.logger.error(e)
                return []