Esempio n. 1
0
    def test_valid_indicator(self, i):
        """Check that every field named in REQUIRED_FIELDS has a truthy value.

        :param i: an ``Indicator`` instance (converted via its ``__dict__()``
            method) or a mapping that supports ``.get``.
        :raises InvalidIndicator: for the first required field, in
            REQUIRED_FIELDS order, that is absent or falsy. The message names
            the field and includes the whole record.
        """
        # NOTE(review): calling __dict__() assumes this Indicator class defines
        # __dict__ as a method returning a dict -- confirm against its definition.
        record = i.__dict__() if isinstance(i, Indicator) else i

        for field in REQUIRED_FIELDS:
            if record.get(field):
                continue
            raise InvalidIndicator("Missing required field: {} for \n{}".format(field, record))
Esempio n. 2
0
    def upsert(self, token, data, **kwargs):
        """Insert new indicators or refresh matching ones, committing per item.

        Each item is validated and then matched against existing rows. A match
        uses the same provider, itype and indicator value, plus rdata /
        address / fqdn / url / hash / first tag when applicable. A matching
        row is updated only when the submitted ``lasttime`` is newer than the
        stored one. Otherwise a new row is inserted together with its child
        rows (Ipv4/Ipv6/Fqdn/Url/Hash, Tag, Message).

        :param token: token dict; ``token['groups']`` lists the groups this
            token may write to.
        :param data: one indicator dict or a list of them. Dicts are mutated
            in place:

            - ``group`` is collapsed to a single value.
            - Timestamps are filled in.
            - ``message`` may be base64-decoded.
            - ``tags`` is re-attached as a comma-joined string.
        :param kwargs: accepted for interface compatibility; not used here.
        :returns: number of items inserted or updated whose commit succeeded.
        :raises InvalidIndicator: an item has no group or fails
            ``test_valid_indicator``.
        :raises AuthError: the token may not write to an item's group.
        """
        if isinstance(data, dict):
            data = [data]

        s = self.handle()

        n = 0
        # indicator value -> set of lasttime values inserted earlier in this
        # call; lets duplicates within the same call be skipped.
        tmp_added = {}
        # Started before the loop so the final log line also works (instead of
        # raising UnboundLocalError) when ``data`` is empty.
        start = time.time()

        for d in data:
            logger.debug(d)
            # Count before this item; restored if this item's commit fails.
            n_before = n

            if not d.get('group'):
                raise InvalidIndicator('missing group')

            if isinstance(d['group'], list):
                d['group'] = d['group'][0]

            # raises AuthError if invalid group
            self._check_token_groups(token, d)

            if PYVERSION == 2:
                if isinstance(d['indicator'], str):
                    d['indicator'] = unicode(d['indicator'])

            self.test_valid_indicator(d)

            tags = d.get("tags", [])
            if len(tags) > 0:
                if isinstance(tags, basestring):
                    tags = tags.split(',')

                # removed so that Indicator(**d) below does not receive it
                del d['tags']

            i = s.query(Indicator).options(lazyload('*')).filter_by(
                provider=d['provider'],
                itype=d['itype'],
                indicator=d['indicator'],
            ).order_by(Indicator.lasttime.desc())

            if d.get('rdata'):
                i = i.filter_by(rdata=d['rdata'])

            if d['itype'] == 'ipv4':
                # TODO -- use ipaddress
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                if match:
                    i = i.join(Ipv4).filter(Ipv4.ipv4 == match.group(1),
                                            Ipv4.mask == match.group(2))
                else:
                    i = i.join(Ipv4).filter(Ipv4.ipv4 == d['indicator'])

            if d['itype'] == 'ipv6':
                # TODO -- use ipaddress
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                if match:
                    i = i.join(Ipv6).filter(Ipv6.ip == match.group(1),
                                            Ipv6.mask == match.group(2))
                else:
                    i = i.join(Ipv6).filter(Ipv6.ip == d['indicator'])

            if d['itype'] == 'fqdn':
                i = i.join(Fqdn).filter(Fqdn.fqdn == d['indicator'])

            if d['itype'] == 'url':
                i = i.join(Url).filter(Url.url == d['indicator'])

            if d['itype'] in HASH_TYPES:
                i = i.join(Hash).filter(Hash.hash == d['indicator'])

            if len(tags):
                # only the first tag takes part in matching
                i = i.join(Tag).filter(Tag.tag == tags[0])

            r = i.first()

            if r:
                # Existing row: refresh it only when the submitted lasttime is newer.
                if d.get('lasttime') and arrow.get(
                        d['lasttime']).datetime > arrow.get(
                            r.lasttime).datetime:
                    logger.debug('{} {}'.format(
                        arrow.get(r.lasttime).datetime,
                        arrow.get(d['lasttime']).datetime))
                    logger.debug('upserting: %s' % d['indicator'])

                    r.count += 1
                    r.lasttime = arrow.get(
                        d['lasttime']).datetime.replace(tzinfo=None)

                    if not d.get('reporttime'):
                        d['reporttime'] = arrow.utcnow().datetime

                    r.reporttime = arrow.get(
                        d['reporttime']).datetime.replace(tzinfo=None)

                    if d.get('message'):
                        try:
                            d['message'] = b64decode(d['message'])
                        except Exception:
                            # Best-effort: keep the raw message if it is not
                            # valid base64.
                            # NOTE(review): plain text that happens to be valid
                            # base64 will still be decoded.
                            pass
                        m = Message(message=d['message'], indicator=r)
                        s.add(m)

                    n += 1
                else:
                    logger.debug('skipping: %s' % d['indicator'])
            else:
                if tmp_added.get(d['indicator']):
                    if d.get('lasttime') in tmp_added[d['indicator']]:
                        logger.debug('skipping: %s' % d['indicator'])
                        # This path leaves early, so re-attach the tags that
                        # were removed from the dict above.
                        d['tags'] = ','.join(tags)
                        continue
                else:
                    tmp_added[d['indicator']] = set()

                if not d.get('lasttime'):
                    d['lasttime'] = arrow.utcnow().datetime.replace(
                        tzinfo=None)

                if not d.get('reporttime'):
                    d['reporttime'] = arrow.utcnow().datetime.replace(
                        tzinfo=None)

                if PYVERSION == 2:
                    d['lasttime'] = arrow.get(
                        d['lasttime']).datetime.replace(tzinfo=None)
                    d['reporttime'] = arrow.get(
                        d['reporttime']).datetime.replace(tzinfo=None)

                if not d.get('firsttime'):
                    d['firsttime'] = d['lasttime']

                ii = Indicator(**d)
                s.add(ii)

                itype = resolve_itype(d['indicator'])

                # Compare with ``==``: ``is`` on strings depends on interning
                # and is not a reliable equality test.
                if itype == 'ipv4':
                    # TODO -- use ipaddress
                    match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                    if match:
                        ipv4 = Ipv4(ipv4=match.group(1),
                                    mask=match.group(2),
                                    indicator=ii)
                    else:
                        ipv4 = Ipv4(ipv4=d['indicator'], indicator=ii)

                    s.add(ipv4)

                elif itype == 'ipv6':
                    # TODO -- use ipaddress
                    match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                    if match:
                        ip = Ipv6(ip=match.group(1),
                                  mask=match.group(2),
                                  indicator=ii)
                    else:
                        ip = Ipv6(ip=d['indicator'], indicator=ii)

                    s.add(ip)

                if itype == 'fqdn':
                    fqdn = Fqdn(fqdn=d['indicator'], indicator=ii)
                    s.add(fqdn)

                if itype == 'url':
                    url = Url(url=d['indicator'], indicator=ii)
                    s.add(url)

                if itype in HASH_TYPES:
                    h = Hash(hash=d['indicator'], indicator=ii)
                    s.add(h)

                for t in tags:
                    t = Tag(tag=t, indicator=ii)
                    s.add(t)

                if d.get('message'):
                    try:
                        d['message'] = b64decode(d['message'])
                    except Exception:
                        # Best-effort: keep the raw message if it is not valid
                        # base64.
                        pass

                    m = Message(message=d['message'], indicator=ii)
                    s.add(m)

                n += 1
                tmp_added[d['indicator']].add(d['lasttime'])

            # if we're in testing mode, this needs re-attaching since we've manipulated the dict for Indicator()
            # see test_store_sqlite
            d['tags'] = ','.join(tags)

            logger.debug('committing')
            try:
                s.commit()
            except Exception as e:
                # Earlier items are already committed; only this item's work
                # was pending, so drop just its contribution to the count
                # (previously the whole count was reset to 0).
                n = n_before
                logger.error(e)
                logger.debug('rolling back transaction..')
                s.rollback()

        # elapsed time for the whole call
        logger.debug('done: %0.2f' % (time.time() - start))
        return n
    def _check_token_groups(self, t, i):
        if not i.get('group'):
            raise InvalidIndicator('missing group')

        if i['group'] not in t['groups']:
            raise AuthError('unable to write to %s' % i['group'])
Esempio n. 4
0
    def upsert_indicators(self, s, n, d, token, tmp_added, batch):
        """Insert or refresh a single indicator dict in session ``s``.

        A row matching provider, itype and indicator value (plus rdata /
        address / fqdn / url / hash / first tag when applicable) is updated
        only if the submitted ``lasttime`` is newer (a missing lasttime is
        treated as "now"). Without a match, a new row and its child rows are
        added.

        :param s: database session (query/add/commit/rollback are used).
        :param n: running count of upserted indicators.
        :param d: indicator dict. It is mutated in place:

            - ``group`` is collapsed to a single value.
            - Timestamps are filled in.
            - ``message`` may be base64-decoded.
            - ``tags`` is re-attached as a comma-joined string.
        :param token: token dict; ``token['groups']`` lists writable groups.
        :param tmp_added: dict shared by the caller across calls, mapping
            indicator value -> set of lasttime values already inserted, so
            duplicates not yet visible to the query are skipped.
        :param batch: when true, any exception is re-raised and the caller is
            responsible for committing; when false, this method commits (or
            rolls back) this one indicator and never raises.
        :returns: ``n + 1`` if the indicator was inserted or updated (and, in
            non-batch mode, committed successfully); otherwise ``n``.
        """
        # Value returned whenever this indicator ends up not being counted.
        n_start = n
        try:
            n += 1
            if not d.get('group'):
                raise InvalidIndicator('missing group')

            if isinstance(d['group'], list):
                d['group'] = d['group'][0]

            # raises AuthError if invalid group
            self._check_token_groups(token, d)

            if PYVERSION == 2:
                if isinstance(d['indicator'], str):
                    d['indicator'] = unicode(d['indicator'])

            self.test_valid_indicator(d)

            tags = d.get("tags", [])
            if len(tags) > 0:
                if isinstance(tags, basestring):
                    tags = tags.split(',')

                # removed so that Indicator(**d) below does not receive it
                del d['tags']

            i = s.query(Indicator).options(lazyload('*')).filter_by(
                provider=d['provider'],
                itype=d['itype'],
                indicator=d['indicator']).order_by(Indicator.lasttime.desc())

            if d.get('rdata'):
                i = i.filter_by(rdata=d['rdata'])

            if d['itype'] == 'ipv4':
                # TODO -- use ipaddress
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                if match:
                    i = i.join(Ipv4).filter(Ipv4.ipv4 == match.group(1),
                                            Ipv4.mask == match.group(2))
                else:
                    i = i.join(Ipv4).filter(Ipv4.ipv4 == d['indicator'])

            if d['itype'] == 'ipv6':
                # TODO -- use ipaddress
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                if match:
                    i = i.join(Ipv6).filter(Ipv6.ip == match.group(1),
                                            Ipv6.mask == match.group(2))
                else:
                    i = i.join(Ipv6).filter(Ipv6.ip == d['indicator'])

            if d['itype'] == 'fqdn':
                i = i.join(Fqdn).filter(Fqdn.fqdn == d['indicator'])

            if d['itype'] == 'url':
                i = i.join(Url).filter(Url.url == d['indicator'])

            if d['itype'] in HASH_TYPES:
                i = i.join(Hash).filter(Hash.hash == d['indicator'])

            if len(tags):
                # only the first tag takes part in matching
                i = i.join(Tag).filter(Tag.tag == tags[0])

            r = i.first()

            if r:
                if not d.get('lasttime'):
                    # If no lasttime submitted, presume a lasttime value of now
                    d['lasttime'] = arrow.utcnow().datetime

                if arrow.get(d['lasttime']).datetime > arrow.get(
                        r.lasttime).datetime:
                    logger.debug('{} {}'.format(
                        arrow.get(r.lasttime).datetime,
                        arrow.get(d['lasttime']).datetime))
                    logger.debug('upserting: %s' % d['indicator'])

                    r.count += 1
                    r.lasttime = arrow.get(
                        d['lasttime']).datetime.replace(tzinfo=None)

                    if not d.get('reporttime'):
                        d['reporttime'] = arrow.utcnow().datetime

                    r.reporttime = arrow.get(
                        d['reporttime']).datetime.replace(tzinfo=None)

                    if d.get('message'):
                        try:
                            d['message'] = b64decode(d['message'])
                        except Exception:
                            # Best-effort: keep the raw message if it is not
                            # valid base64.
                            # NOTE(review): plain text that happens to be valid
                            # base64 will still be decoded.
                            pass
                        m = Message(message=d['message'], indicator=r)
                        s.add(m)

                else:
                    logger.debug('skipping: %s' % d['indicator'])
                    n -= 1
            else:
                if tmp_added.get(d['indicator']):
                    if d.get('lasttime') in tmp_added[d['indicator']]:
                        logger.debug('skipping: %s' % d['indicator'])
                        # This path returns early, so re-attach the tags that
                        # were removed from the dict above.
                        d['tags'] = ','.join(tags)
                        return n_start
                else:
                    tmp_added[d['indicator']] = set()

                if not d.get('lasttime'):
                    d['lasttime'] = arrow.utcnow().datetime.replace(
                        tzinfo=None)

                if not d.get('reporttime'):
                    d['reporttime'] = arrow.utcnow().datetime.replace(
                        tzinfo=None)

                if PYVERSION == 2:
                    d['lasttime'] = arrow.get(
                        d['lasttime']).datetime.replace(tzinfo=None)
                    d['reporttime'] = arrow.get(
                        d['reporttime']).datetime.replace(tzinfo=None)

                if not d.get('firsttime'):
                    d['firsttime'] = d['lasttime']

                ii = Indicator(**d)
                logger.debug('inserting: %s' % d['indicator'])
                s.add(ii)

                itype = resolve_itype(d['indicator'])

                # Compare with ``==``: ``is`` on strings depends on interning
                # and is not a reliable equality test.
                if itype == 'ipv4':
                    # TODO -- use ipaddress
                    match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                    if match:
                        ipv4 = Ipv4(ipv4=match.group(1),
                                    mask=match.group(2),
                                    indicator=ii)
                    else:
                        ipv4 = Ipv4(ipv4=d['indicator'], indicator=ii)

                    s.add(ipv4)

                elif itype == 'ipv6':
                    # TODO -- use ipaddress
                    match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])
                    if match:
                        ip = Ipv6(ip=match.group(1),
                                  mask=match.group(2),
                                  indicator=ii)
                    else:
                        ip = Ipv6(ip=d['indicator'], indicator=ii)

                    s.add(ip)

                if itype == 'fqdn':
                    fqdn = Fqdn(fqdn=d['indicator'], indicator=ii)
                    s.add(fqdn)

                if itype == 'url':
                    url = Url(url=d['indicator'], indicator=ii)
                    s.add(url)

                if itype in HASH_TYPES:
                    h = Hash(hash=d['indicator'], indicator=ii)
                    s.add(h)

                for t in tags:
                    t = Tag(tag=t, indicator=ii)
                    s.add(t)

                if d.get('message'):
                    try:
                        d['message'] = b64decode(d['message'])
                    except Exception:
                        # Best-effort: keep the raw message if it is not valid
                        # base64.
                        pass

                    m = Message(message=d['message'], indicator=ii)
                    s.add(m)

                tmp_added[d['indicator']].add(d['lasttime'])

            # if we're in testing mode, this needs re-attaching since we've manipulated the dict for Indicator()
            # see test_store_sqlite
            d['tags'] = ','.join(tags)

        except Exception as e:
            logger.error(e)
            if batch:
                logger.debug(
                    'Failing batch - passing exception to upper layer')
                raise
            logger.debug('Rolling back individual transaction..')
            s.rollback()
            # Nothing is pending after the rollback, so skip the commit below.
            # This also prevents a failing commit from decrementing the count
            # a second time.
            return n_start

        # When this function is called in non-batch mode, we need to commit each individual indicator at this point.
        # For batches, the commit happens at a higher layer.

        if not batch:
            try:
                logger.debug('Committing individual indicator')
                s.commit()
            except Exception as e:
                # The indicator was not persisted, so it is not counted. Using
                # n_start (not n -= 1) avoids a double decrement when the item
                # was already skipped above.
                n = n_start
                logger.error(e)
                logger.debug('Rolling back individual transaction..')
                s.rollback()

        return n