def _handle_testssl_push(self, worker_id, job_id, testssl_data) -> bool: data_type = testssl_data["type"] handler = { "tls_versions": self._handle_testssl_tls_versions_push, "cipherlists": self._handle_testssl_cipherlists_push, "server_cipher_order": self._handle_testssl_server_cipher_order_push, "cipher_info": self._handle_testssl_cipher_info_push, "certificate": self._handle_testssl_certificate_push, }.get(data_type, None) if handler is None: raise RuntimeError( "unhandled testssl push data type: {!r}".format(data_type)) with model.session_scope(self._sessionmaker) as session: try: task = session.query(model.PendingScanTask).filter( model.PendingScanTask.id_ == job_id, ).one() except sqlalchemy.orm.exc.NoResultFound: return False if task.assigned_worker != worker_id: # late worker? return False if not handler(session, task.scan_id, testssl_data): # allow immediate rescheduling of the task task.heartbeat = None return False task.heartbeat = datetime.utcnow() session.commit() return True
def _collect_tasks(self):
    """Queue every DISCOVER_ENDPOINTS task currently stored in the database.

    Each matching ``PendingScanTask`` id is pushed onto ``self._task_queue``
    together with ``self._processor._discover_endpoints`` as the callable.
    """
    with model.session_scope(self._sessionmaker) as session:
        pending = session.query(model.PendingScanTask.id_).filter(
            model.PendingScanTask.type_ == model.TaskType.DISCOVER_ENDPOINTS,
        )
        for (task_id,) in pending:
            # NOTE(review): _handle_message pushes self._discover_endpoints
            # instead; confirm that self._processor is intended here.
            self._task_queue.push(
                self._processor._discover_endpoints,
                task_id,
            )
async def _discover_tlsa(self, task_id):
    """Consume a TLSA-discovery task.

    Currently no lookup is performed. The pending task row with id
    ``task_id`` is deleted if it still exists; otherwise nothing happens.
    """
    with model.session_scope(self._sessionmaker) as session:
        matching = session.query(model.PendingScanTask).filter(
            model.PendingScanTask.id_ == task_id)
        try:
            matching.one()
        except sqlalchemy.orm.exc.NoResultFound:
            # Another worker already took care of this task.
            return
        matching.delete()
        session.commit()
def _handle_testssl_result(self, worker_id, job_id, result):
    """Store a worker's complete testssl result and retire the pending job.

    Does nothing if the job no longer exists or is assigned to a different
    worker. Otherwise ``result`` is fed through the individual push handlers
    in a single session, the pending task is deleted, and the session is
    committed. The handlers' return values are not inspected here.

    :param result: Dict providing at least ``"ciphers"`` (an iterable of
        per-cipher entries) and ``"cipherlists"``. The whole dict is also
        passed to the TLS version, certificate and server cipher order
        handlers.
    """
    with model.session_scope(self._sessionmaker) as session:
        try:
            task = session.query(model.PendingScanTask).filter(
                model.PendingScanTask.id_ == job_id,
            ).one()
        except sqlalchemy.orm.exc.NoResultFound:
            return

        # A different worker owns the job now; drop this late result.
        if task.assigned_worker != worker_id:
            return

        scan_id = task.scan_id

        for handler_name in ("_handle_testssl_tls_versions_push",
                             "_handle_testssl_certificate_push",
                             "_handle_testssl_server_cipher_order_push"):
            getattr(self, handler_name)(session, scan_id, result)

        for cipher in result["ciphers"]:
            # wrap each entry like an individual "cipher_info" push payload
            self._handle_testssl_cipher_info_push(
                session, scan_id, {"cipher": cipher},
            )

        self._handle_testssl_cipherlists_complete(
            session, scan_id, result["cipherlists"],
        )

        session.delete(task)
        session.commit()
async def _handle_message(self, msg):
    """Dispatch one coordinator API request on ``msg["type"]``.

    Every branch returns a response built with
    ``coordinator_api.mkv1response``:

    * PING -- echo ``msg["payload"]`` back as PONG.
    * SCAN_DOMAIN -- enforce the unprivileged rate limit, create a Scan plus
      its DISCOVER_ENDPOINTS task, queue the task and answer SCAN_QUEUED.
    * GET_TESTSSL_JOB -- hand out the TLS_SCAN task whose heartbeat is
      missing or older than ``HEARTBEAT_THRESOHLD``, or NO_TASKS.
    * TESTSSL_RESULT_PUSH -- apply a partial result; answer JOB_CONFIRMATION.
    * TESTSSL_COMPLETE -- store the final result; answer OK.
    * anything else -- INTERNAL_ERROR.
    """
    kind = msg["type"]
    requests = coordinator_api.RequestType
    responses = coordinator_api.ResponseType

    if kind == requests.PING.value:
        return coordinator_api.mkv1response(responses.PONG, msg["payload"])

    if kind == requests.SCAN_DOMAIN.value:
        limit = self._scan_ratelimit_unprivileged
        now = datetime.utcnow()
        window_start = now - timedelta(seconds=limit.interval)
        with model.session_scope(self._sessionmaker) as session:
            # LIMIT caps the count at the burst size; reaching it means the
            # rate limit is exhausted.
            recent_scans = session.query(model.Scan.created_at).filter(
                model.Scan.created_at >= window_start,
            ).limit(limit.burst).count()
            if recent_scans >= limit.burst:
                return coordinator_api.mkv1response(
                    responses.ERROR,
                    common_api.mkerror(
                        common_api.ErrorCode.TOO_MANY_REQUESTS,
                        "unprivileged rate limit hit",
                    ))

            payload = msg["payload"]
            scan = model.Scan()
            # TODO: IDNA and stuff
            scan.domain = payload["domain"].encode("utf-8")
            scan.created_at = now
            scan.protocol = model.ScanType(payload["protocol"])
            scan.state = model.ScanState.IN_PROGRESS
            scan.privileged = False
            session.add(scan)

            discovery = model.PendingScanTask()
            discovery.scan = scan
            discovery.type_ = model.TaskType.DISCOVER_ENDPOINTS
            discovery.parameters = "{}".encode("utf-8")
            session.add(discovery)
            session.commit()

            self._task_queue.push(self._discover_endpoints, discovery.id_)

            return coordinator_api.mkv1response(
                responses.SCAN_QUEUED,
                {"scan_id": scan.id_},
            )

    if kind == requests.GET_TESTSSL_JOB.value:
        stale_before = datetime.utcnow() - HEARTBEAT_THRESOHLD
        worker_id = msg["payload"]["worker_id"]
        with model.session_scope(self._sessionmaker) as session:
            task = session.query(model.PendingScanTask).filter(
                model.PendingScanTask.type_ == model.TaskType.TLS_SCAN,
                sqlalchemy.or_(
                    model.PendingScanTask.heartbeat == None,  # NOQA
                    model.PendingScanTask.heartbeat < stale_before,
                ),
            ).order_by(
                model.PendingScanTask.heartbeat.asc(),
            ).limit(1).one_or_none()

            if task is None:
                return coordinator_api.mkv1response(
                    responses.NO_TASKS,
                    {"ask_again_after": random.randint(1, 3)},
                )

            scan = task.scan
            job = {
                "job_id": str(task.id_),
                "domain": scan.domain.decode("utf-8"),
                "hostname": scan.primary_host.decode("ascii"),
                "port": scan.primary_port,
                "protocol": scan.protocol.value,
                "tls_mode": scan.primary_tls_mode.value,
            }

            # claim the job for this worker and start its heartbeat clock
            task.assigned_worker = worker_id
            task.heartbeat = datetime.utcnow()
            session.commit()
            return coordinator_api.mkv1response(
                responses.GET_TESTSSL_JOB, job,
            )

    if kind == requests.TESTSSL_RESULT_PUSH.value:
        payload = msg["payload"]
        job_id = int(payload["job_id"])
        worker_id = payload["worker_id"]
        data = payload["testssl_data"]
        keep_going = self._handle_testssl_push(worker_id, job_id, data)
        return coordinator_api.mkv1response(
            responses.JOB_CONFIRMATION,
            {"continue": keep_going},
        )

    if kind == requests.TESTSSL_COMPLETE.value:
        payload = msg["payload"]
        job_id = int(payload["job_id"])
        worker_id = payload["worker_id"]
        result = payload["testssl_result"]
        self._handle_testssl_result(worker_id, job_id, result)
        return coordinator_api.mkv1response(responses.OK, {})

    return coordinator_api.mkv1response(
        responses.ERROR,
        common_api.mkerror(
            # not BAD_REQUEST here, because the type was validated
            # earlier
            common_api.ErrorCode.INTERNAL_ERROR,
            "unhandled type",
        ))
async def _discover_endpoints(self, task_id):
    """Pick a scan's primary XMPP endpoint and queue its TLS scan.

    Reads the scan referenced by the pending task ``task_id`` and collects
    SRV records for the scan's protocol via ``gather_srv_records``. It then
    chooses a primary endpoint. Finally, in a fresh session, it deletes the
    task, stores the endpoint on the scan, replaces the scan's stored SRV
    records and adds a TLS_SCAN task.

    Endpoint selection: the SRV record with the lowest priority wins, with
    ties broken by the highest weight. ``xmpps-*`` services imply direct
    TLS; all others imply STARTTLS. If no SRV record is found, the scan
    domain itself is used on port 5222 (C2S) or 5269 (S2S) with STARTTLS.

    Returns without changes if the task no longer exists at either the
    start or the commit stage.
    """
    with model.session_scope(self._sessionmaker) as session:
        try:
            task = session.query(model.PendingScanTask).filter(
                model.PendingScanTask.id_ == task_id).one()
        except sqlalchemy.orm.exc.NoResultFound:
            # task has been done by another worker already; handled the same
            # way as in the commit stage below instead of crashing.
            return
        scan = task.scan
        scan_id = scan.id_
        scan_domain = scan.domain
        scan_protocol = scan.protocol

    # No database session is held while the DNS lookups are awaited.
    srv_services = {
        model.ScanType.C2S: ["xmpp-client", "xmpps-client"],
        model.ScanType.S2S: ["xmpp-server", "xmpps-server"],
    }[scan_protocol]

    db_records = []
    async for service, record in gather_srv_records(
            scan_domain, srv_services):
        db_record = model.SRVRecord()
        db_record.scan_id = scan_id
        db_record.service = service
        db_record.protocol = "tcp"
        db_record.weight = record.weight
        db_record.port = record.port
        db_record.priority = record.priority
        db_record.host = record.target.to_text().encode("ascii")
        db_records.append(db_record)

    if db_records:
        db_records.sort(key=lambda rec: (rec.priority, -rec.weight))
        primary_record = db_records[0]
        if primary_record.service in ["xmpps-server", "xmpps-client"]:
            primary_tls_mode = model.TLSMode.DIRECT
        else:
            primary_tls_mode = model.TLSMode.STARTTLS
        primary_host = primary_record.host
        primary_port = primary_record.port
    else:
        # fallback to A/AAAA for endpoint selection
        primary_host = scan_domain
        primary_port = {
            model.ScanType.C2S: 5222,
            model.ScanType.S2S: 5269,
        }[scan_protocol]
        primary_tls_mode = model.TLSMode.STARTTLS

    with model.session_scope(self._sessionmaker) as session:
        taskq = session.query(model.PendingScanTask).filter(
            model.PendingScanTask.id_ == task_id)
        try:
            taskq.one()
        except sqlalchemy.orm.exc.NoResultFound:
            # task has been done by another worker already.
            return
        taskq.delete()

        # scan_id was read from this same task in the first session, so there
        # is no need to touch the just bulk-deleted task object again.
        scan = session.query(model.Scan).filter(
            model.Scan.id_ == scan_id,
        ).one()
        scan.primary_host = primary_host
        scan.primary_port = primary_port
        scan.primary_tls_mode = primary_tls_mode

        # replace any SRV records stored for this scan by a previous run
        session.query(model.SRVRecord).filter(
            model.SRVRecord.scan_id == scan_id,
        ).delete()
        for db_record in db_records:
            session.add(db_record)

        tls_task = model.PendingScanTask()
        tls_task.scan_id = scan_id
        tls_task.type_ = model.TaskType.TLS_SCAN
        # bytes, consistent with the DISCOVER_ENDPOINTS task created in
        # _handle_message ("{}".encode("utf-8")); a str here would be
        # rejected by a binary column.
        tls_task.parameters = "{}".encode("utf-8")
        session.add(tls_task)
        session.commit()