async def test_EVENT_CACHE_NEW(system):
    """Caching a result fires EVENT_CACHE_NEW with the cache key and the request.

    Caching the same result a second time fires the event again; duplicate
    cache notifications are expected.
    """
    first_listener = TestEventHandler()
    await system.register_event_handler(EVENT_CACHE_NEW, first_listener)

    module_type = AnalysisModuleType(name="test", description="", cache_ttl=600)
    root = system.new_root()
    target = root.add_observable("test", "test")
    request = target.create_analysis_request(module_type)
    request.initialize_result()
    request.modified_observable.add_analysis(type=module_type)

    async def _cache_and_verify(listener):
        # cache the result, then confirm the listener saw the matching event
        assert await system.cache_analysis_result(request) is not None
        await listener.wait()
        event = listener.event
        assert event.name == EVENT_CACHE_NEW
        assert event.args[0] == generate_cache_key(target, module_type)
        assert AnalysisRequest.from_dict(event.args[1], system) == request

    await _cache_and_verify(first_listener)

    # we can potentially see duplicate cache hits
    second_listener = TestEventHandler()
    await system.register_event_handler(EVENT_CACHE_NEW, second_listener)
    await _cache_and_verify(second_listener)
async def test_analysis_request_serialization(system):
    """An AnalysisRequest survives dict and JSON round trips with every field intact."""
    root = system.new_root()
    target = root.add_observable("test", "1.2.3.4")
    request = target.create_analysis_request(amt)

    assert request == AnalysisRequest.from_dict(request.to_dict(), system)
    assert request == AnalysisRequest.from_json(request.to_json(), system)

    compared_fields = (
        "id",
        "observable",
        "type",
        "status",
        "owner",
        "original_root",
        "modified_root",
    )
    round_trips = (
        lambda: AnalysisRequest.from_dict(request.to_dict(), system),
        lambda: AnalysisRequest.from_json(request.to_json(), system),
    )
    for rebuild in round_trips:
        copy = rebuild()
        for field_name in compared_fields:
            assert getattr(request, field_name) == getattr(copy, field_name)
async def test_EVENT_WORK_ADD(system):
    """Queueing an observable request fires EVENT_WORK_ADD with (module type name, request)."""
    listener = TestEventHandler()
    await system.register_event_handler(EVENT_WORK_ADD, listener)

    module_type = AnalysisModuleType("test", "")
    await system.register_analysis_module_type(module_type)

    root = system.new_root()
    target = root.add_observable("test", "test")
    work_request = AnalysisRequest(system, root, target, module_type)
    await system.queue_analysis_request(work_request)

    await listener.wait()
    event = listener.event
    assert event.name == EVENT_WORK_ADD
    assert event.args[0] == module_type.name
    assert AnalysisRequest.from_dict(event.args[1], system) == work_request
async def test_EVENT_CACHE_HIT(system):
    """A second root with an already-analyzed observable fires EVENT_CACHE_HIT."""
    listener = TestEventHandler()
    await system.register_event_handler(EVENT_CACHE_HIT, listener)

    module_type = AnalysisModuleType("test", "", cache_ttl=60)
    await system.register_analysis_module_type(module_type)

    # first pass: analyze the observable for real so the result is cached
    first_root = system.new_root()
    first_root.add_observable("test", "test")
    await system.process_analysis_request(first_root.create_analysis_request())
    work = await system.get_next_analysis_request("owner", module_type, 0)
    work.initialize_result()
    work.modified_observable.add_analysis(type=module_type, details={"test": "test"})
    await system.process_analysis_request(work)

    # nothing has been served from the cache yet
    assert listener.event is None

    # second pass: the same observable in a new root
    second_root = system.new_root()
    repeat = second_root.add_observable("test", "test")
    await system.process_analysis_request(second_root.create_analysis_request())

    await listener.wait()
    event = listener.event
    assert event.name == EVENT_CACHE_HIT
    event_root = RootAnalysis.from_dict(event.args[0], system)
    assert event_root.uuid == second_root.uuid and event_root.version is not None
    assert event.args[1]["type"] == repeat.type
    assert event.args[1]["value"] == repeat.value
    assert isinstance(AnalysisRequest.from_dict(event.args[2], system), AnalysisRequest)
async def queue_analysis_request(self, ar: AnalysisRequest):
    """Submits the given AnalysisRequest to the appropriate queue for analysis.

    The request is reset to no owner with status TRACKING_STATUS_QUEUED, then
    passed to unlock_analysis_request and track_analysis_request. Root analysis
    requests and observable analysis results are processed immediately (there
    is no inbound queue for them), and the result of process_analysis_request
    is returned. Every other request is handed to put_work keyed by its
    analysis module type, and None is returned.
    """
    assert isinstance(ar, AnalysisRequest)
    assert isinstance(ar.root, RootAnalysis)

    ar.owner = None
    ar.status = TRACKING_STATUS_QUEUED
    await self.unlock_analysis_request(ar)
    await self.track_analysis_request(ar)

    processed_inline = ar.is_root_analysis_request or ar.is_observable_analysis_result
    if not processed_inline:
        # route to the work queue for this request's analysis module type
        await self.put_work(ar.type, ar)
        return None

    return await self.process_analysis_request(ar)
async def test_get_next_analysis_request_expired(system):
    """A request whose module timeout has elapsed is handed out again and requeued."""
    module_type = AnalysisModuleType(
        name="test",
        description="test",
        version="1.0.0",
        timeout=0,  # immediately expire
        cache_ttl=600,
    )
    await system.register_analysis_module_type(module_type)

    root = system.new_root()
    target = root.add_observable("test", TEST_1)
    original = AnalysisRequest(system, root, target, module_type)
    await system.queue_analysis_request(original)

    claimed = await system.get_next_analysis_request(TEST_OWNER, module_type, 0)
    assert claimed == original
    assert claimed.status == TRACKING_STATUS_ANALYZING
    assert claimed.owner == TEST_OWNER

    # this next call should trigger the move of the expired analysis request,
    # and since it expires right away we should see the same request again
    claimed = await system.get_next_analysis_request(TEST_OWNER, module_type, 0)
    assert claimed == original

    # run the expiration sweep explicitly
    await system.process_expired_analysis_requests(module_type)

    # should be back in the queue with no owner
    requeued = await system.get_analysis_request_by_request_id(original.id)
    assert requeued.status == TRACKING_STATUS_QUEUED
    assert requeued.owner is None

    # and then we should get it again
    claimed = await system.get_next_analysis_request(TEST_OWNER, module_type, 0)
    assert claimed == requeued
async def get_next_analysis_request(
    self,
    owner_uuid: str,
    amt: Union[AnalysisModuleType, str],
    timeout: Optional[int] = 0,
    version: Optional[str] = None,
    extended_version: Optional[dict[str, str]] = None,
) -> Union[AnalysisRequest, None]:
    """Ask the remote API (POST /work_queue) for the next AnalysisRequest to work on.

    :param owner_uuid: identifier of the caller claiming the request
    :param amt: an AnalysisModuleType or just its name. When an
        AnalysisModuleType is given, its own ``version`` and
        ``extended_version`` replace the ``version``/``extended_version``
        arguments.
    :param timeout: sent to the server in the query's ``timeout`` field
    :param version: module version, used only when *amt* is a name
    :param extended_version: extended version, used only when *amt* is a name.
        When omitted (None), an empty list is sent, which matches the previous
        default value.
    :returns: the AnalysisRequest from the response body, or None when the
        server answers 204 No Content
    :raises: whatever ``_raise_exception_on_error`` raises for an error response
    """
    # Avoid a shared mutable default argument; keep sending [] when omitted.
    if extended_version is None:
        extended_version = []

    if isinstance(amt, AnalysisModuleType):
        version = amt.version
        extended_version = amt.extended_version
        amt = amt.name

    async with self.get_client() as client:
        response = await client.post(
            "/work_queue",
            json=AnalysisRequestQueryModel(
                owner=owner_uuid,
                amt=amt,
                timeout=timeout,
                version=version,
                extended_version=extended_version,
            ).dict(),
        )

    _raise_exception_on_error(response)
    if response.status_code == 204:
        return None

    return AnalysisRequest.from_dict(response.json(), self.system)
async def i_get_analysis_requests_by_root(
        self, key: str) -> list[AnalysisRequest]:
    """Return every tracked AnalysisRequest whose root_uuid equals *key*.

    Each request is rebuilt from the JSON stored in its tracking row.
    """
    async with self.get_db() as db:
        query = select(AnalysisRequestTracking).where(
            AnalysisRequestTracking.root_uuid == key)
        rows = (await db.execute(query)).all()
        requests = []
        for row in rows:
            tracking = row[0]
            requests.append(AnalysisRequest.from_dict(json.loads(tracking.json_data), self))
        return requests
async def i_get_expired_analysis_requests(self) -> list[AnalysisRequest]:
    """Return every tracked AnalysisRequest whose expiration_date lies in the past.

    The cutoff is the naive server-side ``datetime.datetime.now()``.
    """
    async with self.get_db() as db:
        expired_filter = datetime.datetime.now() > AnalysisRequestTracking.expiration_date
        rows = (await db.execute(
            select(AnalysisRequestTracking).where(expired_filter))).all()

    expired = []
    for row in rows:
        expired.append(AnalysisRequest.from_dict(json.loads(row[0].json_data), self))
    return expired
async def test_cached_analysis_result(system):
    """A cached result is reused for the same observable in a later root, with no new work queued."""
    module_type = AnalysisModuleType(ANALYSIS_TYPE_TEST, "blah", cache_ttl=60)
    assert await system.register_analysis_module_type(module_type) == module_type

    # --- first root: the observable must actually be analyzed ---
    first_root = system.new_root()
    first_root.add_observable("test", "test")
    first_root_request = first_root.create_analysis_request()
    await system.process_analysis_request(first_root_request)

    # we should have a single work entry in the work queue
    assert await system.get_queue_size(module_type) == 1
    # the root request should be deleted
    assert await system.get_analysis_request(first_root_request.id) is None

    work = await system.get_next_analysis_request(OWNER_UUID, module_type, 0)
    work.initialize_result()
    work.modified_observable.add_analysis(type=module_type, details={"Hello": "World"})
    # submit a serialized copy of the result
    await system.process_analysis_request(AnalysisRequest.from_dict(work.to_dict(), system))

    # the analysis result for this observable should be cached now
    assert await system.get_cached_analysis_result(work.observable, work.type) is not None
    # the work request should be deleted
    assert await system.get_analysis_request(work.id) is None
    # and the work queue should be empty
    assert await system.get_queue_size(module_type) == 0

    # --- second root: same observable, should be satisfied from the cache ---
    second_root = system.new_root()
    second_root.add_observable("test", "test")
    second_root_request = second_root.create_analysis_request()
    await system.process_analysis_request(second_root_request)

    assert await system.get_analysis_request(second_root_request.id) is None
    # the work queue stays empty because the result was pulled from the cache
    assert await system.get_queue_size(module_type) == 0

    # the stored root should already carry the analysis
    stored_root = await system.get_root_analysis(second_root.uuid)
    assert stored_root is not None
    stored_observable = stored_root.get_observable(work.observable)
    assert stored_observable is not None
    stored_analysis = stored_observable.get_analysis(work.type)
    assert stored_analysis is not None
    assert stored_analysis.root == stored_root
    assert stored_analysis.observable == work.observable

    actual_details = await stored_analysis.get_details()
    expected_details = await work.modified_observable.get_analysis(module_type).get_details()
    assert actual_details == expected_details
async def api_process_analysis_request(request: AnalysisRequestModel):
    """API handler: process the posted analysis request.

    Returns an empty 200 response on success. If processing raises ACEError,
    returns a 400 JSON response whose body is an ErrorModel holding the error
    code and message.
    """
    try:
        system = app.state.system
        submitted = AnalysisRequest.from_dict(request.dict(), system)
        await system.process_analysis_request(submitted)
    except ACEError as error:
        payload = ErrorModel(code=error.code, details=str(error)).dict()
        return JSONResponse(status_code=400, content=payload)

    return Response(status_code=200)
async def test_EVENT_AR_NEW(system):
    """Tracking a request fires EVENT_AR_NEW; re-tracking it fires the event again."""
    first_listener = TestEventHandler()
    await system.register_event_handler(EVENT_AR_NEW, first_listener)

    request = system.new_root().create_analysis_request()

    async def _track_and_check(listener):
        await system.track_analysis_request(request)
        await listener.wait()
        assert listener.event.name == EVENT_AR_NEW
        assert AnalysisRequest.from_dict(listener.event.args, system) == request

    await _track_and_check(first_listener)

    # you can re-track a request without harm
    second_listener = TestEventHandler()
    await system.register_event_handler(EVENT_AR_NEW, second_listener)
    await _track_and_check(second_listener)
async def test_EVENT_PROCESSING(system):
    """Root, observable and result processing each fire their own EVENT_PROCESSING_* event."""

    async def _expect(listener, event_name, expected_request):
        # wait for the listener, then check the event name and its request payload
        await listener.wait()
        assert listener.event.name == event_name
        assert AnalysisRequest.from_dict(listener.event.args, system) == expected_request

    root_listener = TestEventHandler()
    await system.register_event_handler(EVENT_PROCESSING_REQUEST_ROOT, root_listener)
    observable_listener = TestEventHandler()
    await system.register_event_handler(EVENT_PROCESSING_REQUEST_OBSERVABLE, observable_listener)

    module_type = AnalysisModuleType("test", "", cache_ttl=60)
    await system.register_analysis_module_type(module_type)

    root = system.new_root()
    root.add_observable("test", "test")
    root_request = root.create_analysis_request()
    await system.process_analysis_request(root_request)
    await _expect(root_listener, EVENT_PROCESSING_REQUEST_ROOT, root_request)

    work = await system.get_next_analysis_request("owner", module_type, 0)
    await _expect(observable_listener, EVENT_PROCESSING_REQUEST_OBSERVABLE, work)

    result_listener = TestEventHandler()
    await system.register_event_handler(EVENT_PROCESSING_REQUEST_RESULT, result_listener)
    work.initialize_result()
    work.modified_observable.add_analysis(type=module_type, details={"test": "test"})
    await system.process_analysis_request(work)
    await _expect(result_listener, EVENT_PROCESSING_REQUEST_RESULT, work)
async def test_get_next_analysis_request_deleted(system):
    """A request deleted after being queued is never handed out."""
    module_type = AnalysisModuleType("test", "")
    await system.register_analysis_module_type(module_type)

    root = system.new_root()
    doomed = AnalysisRequest(system, root, root.add_observable("test", TEST_1), module_type)
    await system.queue_analysis_request(doomed)

    await system.delete_analysis_request(doomed)
    assert await system.get_analysis_request_by_request_id(doomed.id) is None

    # nothing should be available because the request was deleted
    assert await system.get_next_analysis_request("owner", module_type, 0) is None
async def i_get_analysis_request_by_request_id(
        self, key: str) -> Union[AnalysisRequest, None]:
    """Return the tracked AnalysisRequest whose id is *key*, or None if none is tracked."""
    async with self.get_db() as db:
        query = select(AnalysisRequestTracking).where(AnalysisRequestTracking.id == key)
        row = (await db.execute(query)).one_or_none()
        if row is not None:
            return AnalysisRequest.from_dict(json.loads(row[0].json_data), self)

    return None
async def test_get_next_analysis_request_by_name(system):
    """Work can be requested by module type name plus an explicit version."""
    await system.register_analysis_module_type(amt_1)

    root = system.new_root()
    target = root.add_observable("test", TEST_1)
    queued = AnalysisRequest(system, root, target, amt_1)
    await system.queue_analysis_request(queued)

    async def _claim():
        return await system.get_next_analysis_request(TEST_OWNER, "test", 0, version="1.0.0")

    claimed = await _claim()
    assert claimed == queued
    assert claimed.status == TRACKING_STATUS_ANALYZING
    assert claimed.owner == TEST_OWNER

    # the only queued request was claimed, so nothing is left
    assert await _claim() is None
async def i_get_work(self, amt: str, timeout: float) -> Union[AnalysisRequest, None]:
    """Pop the next queued AnalysisRequest for the module type named *amt*.

    Raises UnknownAnalysisModuleTypeError if *amt* has no entry in
    KEY_WORK_QUEUES. A *timeout* of 0 uses the non-blocking LPOP; any other
    value uses BLPOP with that timeout. Returns None when nothing was popped.
    """
    async with self.get_redis_connection() as rc:
        if not await rc.hexists(KEY_WORK_QUEUES, amt):
            raise UnknownAnalysisModuleTypeError()

        queue = get_queue_name(amt)
        if timeout == 0:
            # LPOP yields a single item (or None)
            payload = await rc.lpop(queue)
            if payload is None:
                return None
        else:
            popped = await rc.blpop(queue, timeout=timeout)
            if popped is None:
                return None
            # BLPOP replies with a (key, item) pair
            _key, payload = popped

        return AnalysisRequest.from_json(payload.decode(), system=self)
async def test_EVENT_WORK_ASSIGNED(system):
    """Claiming work with get_next_analysis_request fires EVENT_WORK_ASSIGNED."""
    listener = TestEventHandler()
    await system.register_event_handler(EVENT_WORK_ASSIGNED, listener)

    module_type = AnalysisModuleType("test", "", cache_ttl=60)
    await system.register_analysis_module_type(module_type)

    root = system.new_root()
    root.add_observable("test", "test")
    await system.process_analysis_request(root.create_analysis_request())

    assigned = await system.get_next_analysis_request("owner", module_type, 0)
    await listener.wait()
    event = listener.event
    assert event.name == EVENT_WORK_ASSIGNED
    assert AnalysisRequest.from_dict(event.args, system) == assigned
async def test_EVENT_WORK_REMOVE(system):
    """Popping work fires EVENT_WORK_REMOVE; popping from an empty queue does not."""
    listener = TestEventHandler()
    await system.register_event_handler(EVENT_WORK_REMOVE, listener)

    module_type = AnalysisModuleType("test", "")
    await system.register_analysis_module_type(module_type)

    root = system.new_root()
    target = root.add_observable("test", "test")
    await system.queue_analysis_request(AnalysisRequest(system, root, target, module_type))

    removed = await system.get_work(module_type, 0)
    await listener.wait()
    assert listener.event.name == EVENT_WORK_REMOVE
    assert listener.event.args[0] == module_type.name
    assert AnalysisRequest.from_dict(listener.event.args[1], system) == removed

    # can't get fired if you ain't got no work
    idle_listener = TestEventHandler()
    await system.register_event_handler(EVENT_WORK_REMOVE, idle_listener)
    await system.get_work(module_type, 0)
    assert idle_listener.event is None
async def i_get_cached_analysis_result(
        self, cache_key: str) -> Union[AnalysisRequest, None]:
    """Return the cached AnalysisRequest stored under *cache_key*.

    Returns None when there is no entry, or when the entry has an
    expiration_date that ``utc_now()`` has already passed.
    """
    async with self.get_db() as db:
        query = select(AnalysisResultCache).where(AnalysisResultCache.cache_key == cache_key)
        row = (await db.execute(query)).one_or_none()
        if row is None:
            return None

        entry = row[0]
        has_expired = entry.expiration_date is not None and utc_now() > entry.expiration_date
        if has_expired:
            return None

        return AnalysisRequest.from_json(entry.json_data, system=self)
async def i_process_expired_analysis_requests(
        self, amt: AnalysisModuleType) -> int:
    """Requeue every tracked request for *amt* whose expiration date has passed.

    For each expired request, EVENT_AR_EXPIRED is fired and the request is
    resubmitted through queue_analysis_request. If that raises
    UnknownAnalysisModuleTypeError, the request is deleted instead.

    :param amt: the analysis module type whose expired requests are processed
    :returns: the number of expired requests found (and requeued or deleted)
    """
    assert isinstance(amt, AnalysisModuleType)

    # XXX compares against naive server-side time (datetime.datetime.now())
    async with self.get_db() as db:
        rows = (await db.execute(
            select(AnalysisRequestTracking).where(
                and_(
                    AnalysisRequestTracking.analysis_module_type == amt.name,
                    datetime.datetime.now() > AnalysisRequestTracking.expiration_date,
                )))).all()
        payloads = [row[0].json_data for row in rows]

    # Requeueing re-tracks each request (and deleting removes it), which
    # presumably opens other DB sessions; do it after this query session is
    # closed instead of while it is still open.
    for json_data in payloads:
        request = AnalysisRequest.from_json(json_data, self)
        await self.fire_event(EVENT_AR_EXPIRED, request)
        try:
            await self.queue_analysis_request(request)
        except UnknownAnalysisModuleTypeError:
            # must be awaited: an un-awaited coroutine never runs, so the
            # request would otherwise never be deleted
            await self.delete_analysis_request(request)

    return len(payloads)
async def i_track_analysis_request(self, request: AnalysisRequest):
    """Insert or update (via ``db.merge``) the tracking row for *request*.

    When the request's status is TRACKING_STATUS_ANALYZING, an expiration date
    of now + ``request.type.timeout`` seconds is stored. Otherwise
    ``expiration_date`` is None. The row also records the module type name
    (None when the request has no type), cache key, root uuid and the
    request's JSON.
    """
    # XXX we're using server-side time instead of database time
    expiration_date = None
    if request.status == TRACKING_STATUS_ANALYZING:
        # timedelta's first positional argument is *days*; the timeout is
        # meant as seconds, the same unit i_cache_analysis_result uses for
        # cache expirations. NOTE(review): confirm AnalysisModuleType.timeout
        # is documented in seconds.
        expiration_date = datetime.datetime.now() + datetime.timedelta(
            seconds=request.type.timeout)

    db_request = AnalysisRequestTracking(
        id=request.id,
        expiration_date=expiration_date,
        analysis_module_type=request.type.name if request.type else None,
        cache_key=request.cache_key,
        root_uuid=request.root.uuid,
        json_data=request.to_json(),
    )

    async with self.get_db() as db:
        await db.merge(db_request)
        await db.commit()
async def i_get_linked_analysis_requests(
        self, source: AnalysisRequest) -> list[AnalysisRequest]:
    """Return the AnalysisRequests linked to *source*.

    Each linked request is rebuilt from its stored JSON. Returns None (not an
    empty list) when *source* has no tracking row.
    """
    async with self.get_db() as db:
        # NOTE you cannot do lazy loading with async in sqlalchemy 1.4, so the
        # linked_requests relationship is loaded eagerly with selectinload
        query = (
            select(AnalysisRequestTracking)
            .options(selectinload(AnalysisRequestTracking.linked_requests))
            .where(AnalysisRequestTracking.id == source.id)
        )
        row = (await db.execute(query)).one_or_none()
        if row is None:
            return None

        # build the results while the session is still open
        linked = []
        for tracking in row[0].linked_requests:
            linked.append(AnalysisRequest.from_dict(json.loads(tracking.json_data), self))
        return linked
async def i_cache_analysis_result(self, cache_key: str, request: AnalysisRequest,
                                  expiration: Optional[int]) -> str:
    """Store *request* as the cached result under *cache_key* and return the key.

    *expiration*, when not None, is a number of seconds added to ``utc_now()``
    to form the entry's expiration_date. None stores no expiration date. The
    entry is upserted via ``db.merge``.
    """
    # XXX using system side time
    if expiration is None:
        expiration_date = None
    else:
        expiration_date = utc_now() + datetime.timedelta(seconds=expiration)

    entry = AnalysisResultCache(
        cache_key=cache_key,
        expiration_date=expiration_date,
        analysis_module_type=request.type.name,
        json_data=request.to_json(),
    )

    async with self.get_db() as db:
        await db.merge(entry)
        await db.commit()

    return cache_key
async def test_EVENT_AR_EXPIRED(system):
    """Sweeping an expired "analyzing" request fires EVENT_AR_EXPIRED with that request."""
    listener = TestEventHandler()
    await system.register_event_handler(EVENT_AR_EXPIRED, listener)

    module_type = AnalysisModuleType(
        name="test",
        description="test",
        version="1.0.0",
        timeout=0,
        cache_ttl=600,
    )
    await system.register_analysis_module_type(module_type)

    root = system.new_root()
    target = root.add_observable("test", "test")
    request = target.create_analysis_request(module_type)

    # track it as-is, then again as "analyzing" so an expiration date is recorded
    await system.track_analysis_request(request)
    request.status = TRACKING_STATUS_ANALYZING
    await system.track_analysis_request(request)

    await system.process_expired_analysis_requests(module_type)

    await listener.wait()
    assert listener.event.name == EVENT_AR_EXPIRED
    assert AnalysisRequest.from_dict(listener.event.args, system) == request
async def process_analysis_request(self, ar: AnalysisRequest):
    """POST *ar* (as a dict) to the remote /process_request endpoint.

    Any error response is turned into an exception by _raise_exception_on_error.
    """
    async with self.get_client() as client:
        payload = ar.to_dict()
        response = await client.post("/process_request", json=payload)
        _raise_exception_on_error(response)
async def i_put_work(self, amt: str, analysis_request: AnalysisRequest):
    """Append *analysis_request* (as JSON) to the tail of the work queue for *amt*.

    Raises UnknownAnalysisModuleTypeError when *amt* has no entry in KEY_WORK_QUEUES.
    """
    async with self.get_redis_connection() as rc:
        queue_registered = await rc.hexists(KEY_WORK_QUEUES, amt)
        if not queue_registered:
            raise UnknownAnalysisModuleTypeError()

        queue = get_queue_name(amt)
        await rc.rpush(queue, analysis_request.to_json())