def test_valid_indicator(self, i):
    """Check that every field in REQUIRED_FIELDS has a truthy value on ``i``.

    :param i: an ``Indicator`` instance or a mapping. An ``Indicator`` is first
        converted by calling its ``__dict__()`` method.
    :raises InvalidIndicator: for the first required field whose value is
        missing or falsy; the message includes the (converted) indicator.
    """
    # NOTE(review): ``__dict__`` is normally an attribute, not a callable; this
    # only works if ``Indicator`` defines a ``__dict__()`` method -- confirm.
    record = i.__dict__() if isinstance(i, Indicator) else i

    for field in REQUIRED_FIELDS:
        if record.get(field):
            continue
        raise InvalidIndicator("Missing required field: {} for \n{}".format(field, record))
def upsert(self, token, data, **kwargs):
    """Insert new indicators or refresh existing ones, committing once at the end.

    Each indicator dict is validated (group present, token allowed to write to
    that group, required fields present) and then matched against stored rows
    with the same provider, itype and indicator value. The match is further
    narrowed by ``rdata``, by the address/mask, fqdn, url or hash child row,
    and by the first tag, when those are present. The most recent match (by
    ``lasttime``) is used:

    * match found and the submitted ``lasttime`` is newer: bump ``count``,
      ``lasttime`` and ``reporttime`` and attach any message;
    * match found otherwise: skip;
    * no match: create an ``Indicator`` plus its itype-specific child row,
      tags and message, unless the same (indicator, lasttime) pair was already
      inserted earlier in this call.

    :param token: token mapping; its ``'groups'`` entry lists the groups it may
        write to (checked by ``_check_token_groups``).
    :param data: one indicator dict or a list of them. The dicts are modified
        in place (``group``, ``tags``, time fields, ``message``).
    :param kwargs: accepted for interface compatibility; not used.
    :returns: number of indicators inserted or refreshed, or 0 if the final
        commit failed and was rolled back.
    :raises InvalidIndicator: on a missing group or a missing required field.
    :raises AuthError: if the token may not write to an indicator's group.
    """
    # isinstance (not type() ==) so dict subclasses such as OrderedDict are
    # treated as a single indicator rather than iterated by key.
    if isinstance(data, dict):
        data = [data]

    s = self.handle()

    n = 0
    # indicator value -> set of lasttimes already inserted during this call
    tmp_added = {}

    for d in data:
        logger.debug(d)

        if not d.get('group'):
            raise InvalidIndicator('missing group')

        if isinstance(d['group'], list):
            d['group'] = d['group'][0]

        # raises AuthError if invalid group
        self._check_token_groups(token, d)

        if PYVERSION == 2:
            if isinstance(d['indicator'], str):
                d['indicator'] = unicode(d['indicator'])

        self.test_valid_indicator(d)

        tags = d.get("tags", [])
        if len(tags) > 0:
            if isinstance(tags, basestring):
                tags = tags.split(',')

            # tags are stored as Tag rows, not passed to Indicator(**d)
            del d['tags']

        q = s.query(Indicator).options(lazyload('*')).filter_by(
            provider=d['provider'],
            itype=d['itype'],
            indicator=d['indicator'],
        ).order_by(Indicator.lasttime.desc())

        if d.get('rdata'):
            q = q.filter_by(rdata=d['rdata'])

        # raw strings: the pattern is unchanged, but no invalid-escape warnings
        if d['itype'] == 'ipv4':
            match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
            if match:
                q = q.join(Ipv4).filter(Ipv4.ipv4 == match.group(1), Ipv4.mask == match.group(2))
            else:
                q = q.join(Ipv4).filter(Ipv4.ipv4 == d['indicator'])

        if d['itype'] == 'ipv6':
            match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
            if match:
                q = q.join(Ipv6).filter(Ipv6.ip == match.group(1), Ipv6.mask == match.group(2))
            else:
                q = q.join(Ipv6).filter(Ipv6.ip == d['indicator'])

        if d['itype'] == 'fqdn':
            q = q.join(Fqdn).filter(Fqdn.fqdn == d['indicator'])

        if d['itype'] == 'url':
            q = q.join(Url).filter(Url.url == d['indicator'])

        if d['itype'] in HASH_TYPES:
            q = q.join(Hash).filter(Hash.hash == d['indicator'])

        if len(tags):
            q = q.join(Tag).filter(Tag.tag == tags[0])

        r = q.first()

        if r:
            # Existing row: refresh it only when the submission is newer.
            if d.get('lasttime') and arrow.get(d['lasttime']).datetime > arrow.get(r.lasttime).datetime:
                logger.debug('{} {}'.format(arrow.get(r.lasttime).datetime,
                                            arrow.get(d['lasttime']).datetime))
                logger.debug('upserting: %s', d['indicator'])
                r.count += 1
                r.lasttime = arrow.get(d['lasttime']).datetime.replace(tzinfo=None)

                if not d.get('reporttime'):
                    d['reporttime'] = arrow.utcnow().datetime

                r.reporttime = arrow.get(d['reporttime']).datetime.replace(tzinfo=None)

                if d.get('message'):
                    try:
                        d['message'] = b64decode(d['message'])
                    except (TypeError, ValueError):
                        # not valid base64 (binascii.Error is a ValueError):
                        # keep the message exactly as submitted
                        pass
                    s.add(Message(message=d['message'], indicator=r))

                n += 1
            else:
                logger.debug('skipping: %s', d['indicator'])
        else:
            # New row. Skip (indicator, lasttime) pairs already inserted in this
            # call, since they are not yet visible to the query above.
            # NOTE(review): the membership test uses the submitted lasttime,
            # while the set stores the normalised value -- confirm this matches
            # intent when lasttime is absent or PYVERSION == 2.
            if tmp_added.get(d['indicator']):
                if d.get('lasttime') in tmp_added[d['indicator']]:
                    logger.debug('skipping: %s', d['indicator'])
                    continue
            else:
                tmp_added[d['indicator']] = set()

            if not d.get('lasttime'):
                d['lasttime'] = arrow.utcnow().datetime.replace(tzinfo=None)

            if not d.get('reporttime'):
                d['reporttime'] = arrow.utcnow().datetime.replace(tzinfo=None)

            if PYVERSION == 2:
                d['lasttime'] = arrow.get(d['lasttime']).datetime.replace(tzinfo=None)
                d['reporttime'] = arrow.get(d['reporttime']).datetime.replace(tzinfo=None)

            if not d.get('firsttime'):
                d['firsttime'] = d['lasttime']

            ii = Indicator(**d)
            s.add(ii)

            # NOTE(review): the child row is chosen from resolve_itype(), while
            # the lookup above uses the submitted d['itype'].
            itype = resolve_itype(d['indicator'])

            # == (not `is`): identity of equal strings is not guaranteed
            if itype == 'ipv4':
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
                if match:
                    ipv4 = Ipv4(ipv4=match.group(1), mask=match.group(2), indicator=ii)
                else:
                    ipv4 = Ipv4(ipv4=d['indicator'], indicator=ii)

                s.add(ipv4)

            elif itype == 'ipv6':
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
                if match:
                    ip = Ipv6(ip=match.group(1), mask=match.group(2), indicator=ii)
                else:
                    ip = Ipv6(ip=d['indicator'], indicator=ii)

                s.add(ip)

            if itype == 'fqdn':
                s.add(Fqdn(fqdn=d['indicator'], indicator=ii))

            if itype == 'url':
                s.add(Url(url=d['indicator'], indicator=ii))

            if itype in HASH_TYPES:
                s.add(Hash(hash=d['indicator'], indicator=ii))

            for tag in tags:
                s.add(Tag(tag=tag, indicator=ii))

            if d.get('message'):
                try:
                    d['message'] = b64decode(d['message'])
                except (TypeError, ValueError):
                    # not valid base64: keep the message exactly as submitted
                    pass
                s.add(Message(message=d['message'], indicator=ii))

            n += 1
            tmp_added[d['indicator']].add(d['lasttime'])

        # if we're in testing mode, this needs re-attaching since we've manipulated the dict for Indicator()
        # see test_store_sqlite
        d['tags'] = ','.join(tags)

    logger.debug('committing')
    start = time.time()
    try:
        s.commit()
    except Exception as e:
        # All-or-nothing: nothing from this call was written.
        n = 0
        logger.error(e)
        logger.debug('rolling back transaction..')
        s.rollback()

    logger.debug('done: %0.2f', time.time() - start)
    return n
def _check_token_groups(self, t, i):
    """Verify that token ``t`` may write to the group of indicator ``i``.

    :param t: token mapping; ``t['groups']`` holds the groups it may write to.
    :param i: indicator mapping that must carry a truthy ``'group'``.
    :raises InvalidIndicator: if ``i`` has no group.
    :raises AuthError: if ``i``'s group is not in ``t['groups']``.
    """
    group = i.get('group')
    if not group:
        raise InvalidIndicator('missing group')

    permitted = t['groups']
    if group in permitted:
        return
    raise AuthError('unable to write to %s' % group)
def upsert_indicators(self, s, n, d, token, tmp_added, batch):
    """Insert or refresh a single indicator on session ``s``; return the new count.

    ``n`` is incremented on entry and decremented again if the indicator is
    skipped, fails, or (in non-batch mode) fails to commit.

    Matching and update rules: the indicator is matched against stored rows with
    the same provider, itype and indicator value. The match is narrowed by
    ``rdata``, by the address/mask, fqdn, url or hash child row, and by the
    first tag, when present.

    * If a row exists and the submitted ``lasttime`` (defaulting to now) is
      newer, its count, lasttime and reporttime are bumped and any message is
      attached.
    * If a row exists but the submission is not newer, it is skipped.
    * If no row exists, a new ``Indicator`` with its child row, tags and message
      is staged, unless the (indicator, lasttime) pair is already recorded in
      ``tmp_added``.

    :param s: database session the changes are staged on.
    :param n: running count of indicators written so far.
    :param d: indicator dict; modified in place.
    :param token: token mapping; ``token['groups']`` lists writable groups.
    :param tmp_added: maps indicator value -> set of lasttimes already inserted
        during the current run; updated in place.
    :param batch: if true, any error is re-raised and the caller is expected to
        commit; if false, errors are logged and rolled back, and a successfully
        staged indicator is committed here.
    :returns: the updated count.
    :raises Exception: in batch mode, whatever was raised while staging (e.g.
        InvalidIndicator, or AuthError from the group check).
    """
    n += 1
    try:
        if not d.get('group'):
            raise InvalidIndicator('missing group')

        if isinstance(d['group'], list):
            d['group'] = d['group'][0]

        # raises AuthError if invalid group
        self._check_token_groups(token, d)

        if PYVERSION == 2:
            if isinstance(d['indicator'], str):
                d['indicator'] = unicode(d['indicator'])

        self.test_valid_indicator(d)

        tags = d.get("tags", [])
        if len(tags) > 0:
            if isinstance(tags, basestring):
                tags = tags.split(',')

            # tags are stored as Tag rows, not passed to Indicator(**d)
            del d['tags']

        q = s.query(Indicator).options(lazyload('*')).filter_by(
            provider=d['provider'],
            itype=d['itype'],
            indicator=d['indicator'],
        ).order_by(Indicator.lasttime.desc())

        if d.get('rdata'):
            q = q.filter_by(rdata=d['rdata'])

        # raw strings: the pattern is unchanged, but no invalid-escape warnings
        if d['itype'] == 'ipv4':
            match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
            if match:
                q = q.join(Ipv4).filter(Ipv4.ipv4 == match.group(1), Ipv4.mask == match.group(2))
            else:
                q = q.join(Ipv4).filter(Ipv4.ipv4 == d['indicator'])

        if d['itype'] == 'ipv6':
            match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
            if match:
                q = q.join(Ipv6).filter(Ipv6.ip == match.group(1), Ipv6.mask == match.group(2))
            else:
                q = q.join(Ipv6).filter(Ipv6.ip == d['indicator'])

        if d['itype'] == 'fqdn':
            q = q.join(Fqdn).filter(Fqdn.fqdn == d['indicator'])

        if d['itype'] == 'url':
            q = q.join(Url).filter(Url.url == d['indicator'])

        if d['itype'] in HASH_TYPES:
            q = q.join(Hash).filter(Hash.hash == d['indicator'])

        if len(tags):
            q = q.join(Tag).filter(Tag.tag == tags[0])

        r = q.first()

        if r:
            # If no lasttime submitted, presume a lasttime value of now
            if not d.get('lasttime'):
                d['lasttime'] = arrow.utcnow().datetime

            if arrow.get(d['lasttime']).datetime > arrow.get(r.lasttime).datetime:
                logger.debug('{} {}'.format(arrow.get(r.lasttime).datetime,
                                            arrow.get(d['lasttime']).datetime))
                logger.debug('upserting: %s', d['indicator'])
                r.count += 1
                r.lasttime = arrow.get(d['lasttime']).datetime.replace(tzinfo=None)

                if not d.get('reporttime'):
                    d['reporttime'] = arrow.utcnow().datetime

                r.reporttime = arrow.get(d['reporttime']).datetime.replace(tzinfo=None)

                if d.get('message'):
                    try:
                        d['message'] = b64decode(d['message'])
                    except (TypeError, ValueError):
                        # not valid base64 (binascii.Error is a ValueError):
                        # keep the message exactly as submitted
                        pass
                    s.add(Message(message=d['message'], indicator=r))
            else:
                logger.debug('skipping: %s', d['indicator'])
                n -= 1
        else:
            # New row. Skip (indicator, lasttime) pairs already inserted in this
            # run, since they may not yet be visible to the query above.
            if tmp_added.get(d['indicator']):
                if d.get('lasttime') in tmp_added[d['indicator']]:
                    logger.debug('skipping: %s', d['indicator'])
                    n -= 1
                    return n
            else:
                tmp_added[d['indicator']] = set()

            if not d.get('lasttime'):
                d['lasttime'] = arrow.utcnow().datetime.replace(tzinfo=None)

            if not d.get('reporttime'):
                d['reporttime'] = arrow.utcnow().datetime.replace(tzinfo=None)

            if PYVERSION == 2:
                d['lasttime'] = arrow.get(d['lasttime']).datetime.replace(tzinfo=None)
                d['reporttime'] = arrow.get(d['reporttime']).datetime.replace(tzinfo=None)

            if not d.get('firsttime'):
                d['firsttime'] = d['lasttime']

            ii = Indicator(**d)
            logger.debug('inserting: %s', d['indicator'])
            s.add(ii)

            # NOTE(review): the child row is chosen from resolve_itype(), while
            # the lookup above uses the submitted d['itype'].
            itype = resolve_itype(d['indicator'])

            # == (not `is`): identity of equal strings is not guaranteed
            if itype == 'ipv4':
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
                if match:
                    ipv4 = Ipv4(ipv4=match.group(1), mask=match.group(2), indicator=ii)
                else:
                    ipv4 = Ipv4(ipv4=d['indicator'], indicator=ii)

                s.add(ipv4)

            elif itype == 'ipv6':
                match = re.search(r'^(\S+)\/(\d+)$', d['indicator'])  # TODO -- use ipaddress
                if match:
                    ip = Ipv6(ip=match.group(1), mask=match.group(2), indicator=ii)
                else:
                    ip = Ipv6(ip=d['indicator'], indicator=ii)

                s.add(ip)

            if itype == 'fqdn':
                s.add(Fqdn(fqdn=d['indicator'], indicator=ii))

            if itype == 'url':
                s.add(Url(url=d['indicator'], indicator=ii))

            if itype in HASH_TYPES:
                s.add(Hash(hash=d['indicator'], indicator=ii))

            for tag in tags:
                s.add(Tag(tag=tag, indicator=ii))

            if d.get('message'):
                try:
                    d['message'] = b64decode(d['message'])
                except (TypeError, ValueError):
                    # not valid base64: keep the message exactly as submitted
                    pass
                s.add(Message(message=d['message'], indicator=ii))

            tmp_added[d['indicator']].add(d['lasttime'])

        # if we're in testing mode, this needs re-attaching since we've manipulated the dict for Indicator()
        # see test_store_sqlite
        d['tags'] = ','.join(tags)
    except Exception as e:
        logger.error(e)
        if batch:
            logger.debug('Failing batch - passing exception to upper layer')
            raise

        n -= 1
        logger.debug('Rolling back individual transaction..')
        s.rollback()
        # Nothing is left to commit after the rollback; returning here also
        # avoids decrementing n a second time if a follow-up commit failed.
        return n

    # When this function is called in non-batch mode, we need to commit each individual indicator at this point.
    # For batches, the commit happens at a higher layer.
    if not batch:
        logger.debug('Committing individual indicator')
        try:
            s.commit()
        except Exception as e:
            n -= 1
            logger.error(e)
            logger.debug('Rolling back individual transaction..')
            s.rollback()

    return n