async def _set_job(self, command, socket):
    ''' Create or update job metadata. '''
    if command.HasField('tag_list'):
        tags = [t.strip() for t in command.tag_list.tags]
    else:
        tags = None
    if command.HasField('job_id'):
        # Update state of existing job.
        job_id = str(UUID(bytes=command.job_id))
        name = command.name if command.HasField('name') else None
        await self._crawl_manager.update_job(job_id, name, tags)
        if command.HasField('run_state'):
            run_state = command.run_state
            if run_state == protobuf.shared_pb2.CANCELLED:
                await self._crawl_manager.cancel_job(job_id)
            elif run_state == protobuf.shared_pb2.PAUSED:
                await self._crawl_manager.pause_job(job_id)
            elif run_state == protobuf.shared_pb2.RUNNING:
                await self._crawl_manager.resume_job(job_id)
            else:
                raise Exception(
                    'Not allowed to set job run state: {}'.format(run_state))
        response = Response()
    else:
        # Create new job.
        name = command.name
        policy_id = str(UUID(bytes=command.policy_id))
        seeds = [s.strip() for s in command.seeds]
        tags = tags or []
        if name.strip() == '':
            url = urlparse(seeds[0])
            name = url.hostname
            if len(seeds) > 1:
                name += ' & {} more'.format(len(seeds) - 1)
        if command.run_state != protobuf.shared_pb2.RUNNING:
            raise InvalidRequestException(
                'New job state must be set to RUNNING')
        job_id = await self._crawl_manager.start_job(seeds, policy_id, name,
            tags)
        response = Response()
        response.new_job.job_id = UUID(job_id).bytes
    return response

async def _get_policy(self, command, socket):
    ''' Get a single policy. '''
    policy_id = str(UUID(bytes=command.policy_id))
    policy_doc = await self._policy_manager.get_policy(policy_id)
    response = Response()
    Policy.convert_doc_to_pb(policy_doc, response.policy)
    return response

async def _get_job(self, command, socket):
    ''' Get status for a single job. '''
    job_id = str(UUID(bytes=command.job_id))
    job_doc = await self._crawl_manager.get_job(job_id)
    response = Response()
    if job_doc is None:
        response.is_success = False
        response.error_message = f'No job exists with ID={job_id}'
    else:
        job = response.job
        job.job_id = UUID(job_doc['id']).bytes
        for seed in job_doc['seeds']:
            job.seeds.append(seed)
        for tag in job_doc['tags']:
            job.tag_list.tags.append(tag)
        Policy.convert_doc_to_pb(job_doc['policy'], job.policy)
        job.name = job_doc['name']
        job.item_count = job_doc['item_count']
        job.http_success_count = job_doc['http_success_count']
        job.http_error_count = job_doc['http_error_count']
        job.exception_count = job_doc['exception_count']
        job.started_at = job_doc['started_at'].isoformat()
        if job_doc['completed_at'] is not None:
            job.completed_at = job_doc['completed_at'].isoformat()
        run_state = job_doc['run_state'].upper()
        job.run_state = protobuf.shared_pb2.JobRunState.Value(run_state)
        http_status_counts = job_doc['http_status_counts']
        for status_code, count in http_status_counts.items():
            job.http_status_counts[int(status_code)] = count
    return response

async def _get_domain_login(self, command, socket):
    ''' Get a domain login. '''
    if not command.HasField('domain'):
        raise InvalidRequestException('domain is required.')
    domain = command.domain
    async with self._db_pool.connection() as conn:
        domain_login = await (
            r.table('domain_login').get(domain).run(conn))
    if domain_login is None:
        raise InvalidRequestException('No domain credentials found for'
            ' domain={}'.format(domain))
    response = Response()
    response.domain_login.domain = domain_login['domain']
    response.domain_login.login_url = domain_login['login_url']
    if domain_login['login_test'] is not None:
        response.domain_login.login_test = domain_login['login_test']
    response.domain_login.auth_count = len(domain_login['users'])
    for user in domain_login['users']:
        dl_user = response.domain_login.users.add()
        dl_user.username = user['username']
        dl_user.password = user['password']
        dl_user.working = user['working']
    return response

async def _subscribe_job_status(self, command, socket):
    ''' Handle the subscribe crawl status command. '''
    subscription = JobStatusSubscription(self._tracker, socket,
        command.min_interval)
    self._subscription_manager.add(subscription)
    response = Response()
    response.new_subscription.subscription_id = subscription.get_id()
    return response

async def _set_job_schedule(self, command, socket):
    ''' Create or update job schedule metadata. '''
    doc = Scheduler.pb_to_doc(command.job_schedule)
    schedule_id = await self._scheduler.set_job_schedule(doc)
    response = Response()
    if schedule_id is not None:
        response.new_job_schedule.schedule_id = UUID(schedule_id).bytes
    return response

async def _subscribe_task_monitor(self, command, socket):
    ''' Handle the subscribe task monitor command. '''
    subscription = TaskMonitorSubscription(socket, command.period,
        command.top_n)
    self._subscription_manager.add(subscription)
    response = Response()
    response.new_subscription.subscription_id = subscription.get_id()
    return response

async def _ping(self, command, socket):
    '''
    A client may ping the server to prevent connection timeout.

    This sends back whatever string was sent.
    '''
    response = Response()
    response.ping.pong = command.pong
    return response

async def _subscribe_resource_monitor(self, command, socket):
    ''' Handle the subscribe resource monitor command. '''
    subscription = ResourceMonitorSubscription(socket,
        self._resource_monitor, command.history)
    self._subscription_manager.add(subscription)
    response = Response()
    response.new_subscription.subscription_id = subscription.get_id()
    return response

async def _set_rate_limit(self, command, socket):
    ''' Set a rate limit. '''
    rate_limit = command.rate_limit
    delay = rate_limit.delay if rate_limit.HasField('delay') else None
    if rate_limit.HasField('domain'):
        await self._rate_limiter.set_domain_limit(rate_limit.domain, delay)
    else:
        await self._rate_limiter.set_global_limit(delay)
    return Response()

async def _delete_domain_login(self, command, socket):
    ''' Delete a domain login and all of its users. '''
    if command.HasField('domain'):
        domain = command.domain
    else:
        raise InvalidRequestException('domain is required.')
    async with self._db_pool.connection() as conn:
        await (
            r.table('domain_login').get(domain).delete().run(conn))
    return Response()

async def _get_job_schedule(self, command, socket):
    ''' Get metadata for a job schedule. '''
    schedule_id = str(UUID(bytes=command.schedule_id))
    doc = await self._scheduler.get_job_schedule(schedule_id)
    response = Response()
    if doc is None:
        response.is_success = False
        response.error_message = f'No schedule exists with ID={schedule_id}'
    else:
        pb = response.job_schedule
        Scheduler.doc_to_pb(doc, pb)
    return response

async def _handle_request(self, client_ip, websocket, request_data):
    ''' Handle a single request/response pair. '''
    request = Request.FromString(request_data)
    start = time()
    try:
        command_name = request.WhichOneof('Command')
        if command_name is None:
            raise InvalidRequestException('No command specified')
        command = getattr(request, command_name)
        try:
            handler = self._request_handlers[command_name]
        except KeyError:
            raise InvalidRequestException(
                'Invalid command name: {}'.format(command_name))
        response = await handler(command, websocket)
        response.request_id = request.request_id
        response.is_success = True
        elapsed = time() - start
        logger.info('Request OK %s %s %0.3fs', command_name, client_ip,
            elapsed)
    except asyncio.CancelledError:
        raise
    except Exception as e:
        if isinstance(e, InvalidRequestException):
            elapsed = time() - start
            logger.error('Request ERROR %s %s %0.3fs', command_name,
                client_ip, elapsed)
        else:
            logger.exception('Exception while handling request:\n%r',
                request)
        response = Response()
        response.is_success = False
        response.error_message = str(e)
        try:
            response.request_id = request.request_id
        except Exception:
            # A parsing failure could lead to request or request_id not
            # being defined. There's nothing we can do to fix this.
            pass
    if response.IsInitialized():
        message = ServerMessage()
        message.response.MergeFrom(response)
        message_data = message.SerializeToString()
        await websocket.send(message_data)
    else:
        # This could happen, e.g. if the request_id is not set.
        logger.error('Cannot send uninitialized response:\n%r', response)

async def _list_job_schedules(self, command, socket):
    ''' Return a list of job schedules. '''
    limit = command.page.limit
    offset = command.page.offset
    count, schedules = await self._scheduler.list_job_schedules(limit,
        offset)
    response = Response()
    response.list_job_schedules.total = count
    for doc in schedules:
        pb = response.list_job_schedules.job_schedules.add()
        Scheduler.doc_to_pb(doc, pb)
    return response

async def _set_policy(self, command, socket):
    '''
    Create or update a single policy.

    If the policy ID is set, then update the corresponding policy.
    Otherwise, create a new policy.
    '''
    policy_doc = Policy.convert_pb_to_doc(command.policy)
    policy_id = await self._policy_manager.set_policy(policy_doc)
    response = Response()
    if policy_id is not None:
        response.new_policy.policy_id = UUID(policy_id).bytes
    return response

async def _get_captcha_solver(self, command, socket):
    ''' Get a CAPTCHA solver. '''
    if not command.HasField('solver_id'):
        raise InvalidRequestException('solver_id is required.')
    solver_id = str(UUID(bytes=command.solver_id))
    async with self._db_pool.connection() as conn:
        doc = await r.table('captcha_solver').get(solver_id).run(conn)
    if doc is None:
        raise InvalidRequestException(
            'No CAPTCHA solver found for that ID')
    response = Response()
    response.solver.CopyFrom(captcha_doc_to_pb(doc))
    return response

async def _set_captcha_solver(self, command, socket):
    ''' Create or update CAPTCHA solver. '''
    now = datetime.now(tzlocal())
    doc = captcha_pb_to_doc(command.solver)
    doc['updated_at'] = now
    response = Response()
    async with self._db_pool.connection() as conn:
        if command.solver.HasField('solver_id'):
            await r.table('captcha_solver').update(doc).run(conn)
        else:
            doc['created_at'] = now
            result = await r.table('captcha_solver').insert(doc).run(conn)
            solver_id = result['generated_keys'][0]
            response.new_solver.solver_id = UUID(solver_id).bytes
    return response

async def _get_rate_limits(self, command, socket):
    ''' Get a page of rate limits. '''
    limit = command.page.limit
    offset = command.page.offset
    count, rate_limits = await self._rate_limiter.get_limits(limit, offset)
    response = Response()
    response.list_rate_limits.total = count
    for rate_limit in rate_limits:
        rl = response.list_rate_limits.rate_limits.add()
        rl.name = rate_limit['name']
        rl.delay = rate_limit['delay']
        if rate_limit['type'] == 'domain':
            rl.domain = rate_limit['domain']
    return response

async def _list_captcha_solvers(self, command, socket):
    ''' Return a list of CAPTCHA solvers. '''
    limit = command.page.limit
    offset = command.page.offset
    response = Response()
    async with self._db_pool.connection() as conn:
        count = await r.table('captcha_solver').count().run(conn)
        docs = await (r.table('captcha_solver').order_by('name')
            .skip(offset).limit(limit).run(conn))
        for doc in docs:
            solver = response.list_captcha_solvers.solvers.add()
            solver.CopyFrom(captcha_doc_to_pb(doc))
    response.list_captcha_solvers.total = count
    return response

async def _list_policies(self, command, socket):
    ''' Get a list of policies. '''
    limit = command.page.limit
    offset = command.page.offset
    count, policies = await self._policy_manager.list_policies(limit,
        offset)
    response = Response()
    response.list_policies.total = count
    for policy_doc in policies:
        policy = response.list_policies.policies.add()
        policy.policy_id = UUID(policy_doc['id']).bytes
        policy.name = policy_doc['name']
        policy.created_at = policy_doc['created_at'].isoformat()
        policy.updated_at = policy_doc['updated_at'].isoformat()
    return response

async def _subscribe_crawl_sync(self, command, socket):
    ''' Handle the subscribe crawl items command. '''
    job_id = str(UUID(bytes=command.job_id))
    compression_ok = command.compression_ok
    if command.HasField('sync_token'):
        sync_token = command.sync_token
    else:
        sync_token = None
    subscription = CrawlSyncSubscription(self._tracker, self._db_pool,
        socket, job_id, compression_ok, sync_token)
    self._subscription_manager.add(subscription)
    response = Response()
    response.new_subscription.subscription_id = subscription.get_id()
    return response

async def _set_domain_login(self, command, socket):
    ''' Create or update a domain login. '''
    domain_login = command.login
    if not domain_login.HasField('domain'):
        raise InvalidRequestException('domain is required.')
    domain = domain_login.domain
    async with self._db_pool.connection() as conn:
        doc = await r.table('domain_login').get(domain).run(conn)
    if doc is None:
        if not domain_login.HasField('login_url'):
            raise InvalidRequestException('login_url is required to'
                ' create a domain login.')
        doc = {
            'domain': domain,
            'login_url': domain_login.login_url,
            'login_test': None,
        }
    if domain_login.HasField('login_url'):
        doc['login_url'] = domain_login.login_url
    if domain_login.HasField('login_test'):
        doc['login_test'] = domain_login.login_test
    doc['users'] = list()
    for user in domain_login.users:
        doc['users'].append({
            'username': user.username,
            'password': user.password,
            'working': user.working,
        })
    async with self._db_pool.connection() as conn:
        # replace() is supposed to upsert, but for some reason it doesn't,
        # so I'm calling insert() explicitly.
        response = await r.table('domain_login').replace(doc).run(conn)
        if response['replaced'] == 0:
            await r.table('domain_login').insert(doc).run(conn)
    return Response()

async def _get_job_items(self, command, socket):
    ''' Get a page of items (crawl responses) from a job. '''
    job_id = str(UUID(bytes=command.job_id))
    limit = command.page.limit
    offset = command.page.offset
    total_items, item_docs = await self._crawl_manager.get_job_items(
        job_id, command.include_success, command.include_error,
        command.include_exception, limit, offset)
    response = Response()
    response.list_items.total = total_items
    compression_ok = command.compression_ok
    for item_doc in item_docs:
        item = response.list_items.items.add()
        if item_doc['join'] is None:
            item.is_body_compressed = False
        elif item_doc['join']['is_compressed'] and not compression_ok:
            item.body = gzip.decompress(item_doc['join']['body'])
            item.is_body_compressed = False
        else:
            item.body = item_doc['join']['body']
            item.is_body_compressed = item_doc['join']['is_compressed']
        if 'content_type' in item_doc:
            item.content_type = item_doc['content_type']
        if 'exception' in item_doc:
            item.exception = item_doc['exception']
        if 'status_code' in item_doc:
            item.status_code = item_doc['status_code']
        # Headers are stored as a flat list of alternating keys and values,
        # so consume the iterator in pairs.
        header_iter = iter(item_doc.get('headers', []))
        for key in header_iter:
            value = next(header_iter)
            header = item.headers.add()
            header.key = key
            header.value = value
        item.cost = item_doc['cost']
        item.job_id = UUID(item_doc['job_id']).bytes
        item.completed_at = item_doc['completed_at'].isoformat()
        item.started_at = item_doc['started_at'].isoformat()
        item.duration = item_doc['duration']
        item.url = item_doc['url']
        item.url_can = item_doc['url_can']
        item.is_success = item_doc['is_success']
    return response

async def _profile(self, command, socket):
    ''' Run CPU profiler. '''
    profile = cProfile.Profile()
    profile.enable()
    await asyncio.sleep(command.duration)
    profile.disable()
    # pstats sorting only works when you use pstats printing... so we have
    # to build our own data structure in order to sort it.
    pr_stats = pstats.Stats(profile)
    stats = list()
    for key, value in pr_stats.stats.items():
        # Each entry maps (file, line, function) to a tuple of
        # (primitive call count, total call count, total time,
        # cumulative time, callers).
        stats.append({
            'file': key[0],
            'line_number': key[1],
            'function': key[2],
            'calls': value[1],
            'non_recursive_calls': value[0],
            'total_time': value[2],
            'cumulative_time': value[3],
        })
    try:
        stats.sort(key=operator.itemgetter(command.sort_by), reverse=True)
    except KeyError:
        raise InvalidRequestException(
            'Invalid sort key: {}'.format(command.sort_by))
    response = Response()
    response.performance_profile.total_calls = pr_stats.total_calls
    response.performance_profile.total_time = pr_stats.total_tt
    for stat in stats[:command.top_n]:
        function = response.performance_profile.functions.add()
        function.file = stat['file']
        function.line_number = stat['line_number']
        function.function = stat['function']
        function.calls = stat['calls']
        function.non_recursive_calls = stat['non_recursive_calls']
        function.total_time = stat['total_time']
        function.cumulative_time = stat['cumulative_time']
    return response

async def _delete_captcha_solver(self, command, socket):
    ''' Delete a CAPTCHA solver. '''
    if command.HasField('solver_id'):
        solver_id = str(UUID(bytes=command.solver_id))
    else:
        raise InvalidRequestException('solver_id is required.')
    response = Response()
    async with self._db_pool.connection() as conn:
        use_count = await (r.table('policy')
            .filter({'captcha_solver_id': solver_id})
            .count().run(conn))
        if use_count > 0:
            raise InvalidRequestException('Cannot delete CAPTCHA solver'
                ' because it is being used by a policy.')
        await r.table('captcha_solver').get(solver_id).delete().run(conn)
    return response

async def _list_jobs(self, command, socket):
    ''' Return a list of jobs. '''
    limit = command.page.limit
    offset = command.page.offset
    if command.HasField('started_after'):
        started_after = dateutil.parser.parse(command.started_after)
    else:
        started_after = None
    tag = command.tag if command.HasField('tag') else None
    schedule_id = str(UUID(bytes=command.schedule_id)) \
        if command.HasField('schedule_id') else None
    count, job_docs = await self._crawl_manager.list_jobs(
        limit, offset, started_after, tag, schedule_id)
    response = Response()
    response.list_jobs.total = count
    for job_doc in job_docs:
        job = response.list_jobs.jobs.add()
        job.job_id = UUID(job_doc['id']).bytes
        job.name = job_doc['name']
        for seed in job_doc['seeds']:
            job.seeds.append(seed)
        for tag in job_doc['tags']:
            job.tag_list.tags.append(tag)
        job.item_count = job_doc['item_count']
        job.http_success_count = job_doc['http_success_count']
        job.http_error_count = job_doc['http_error_count']
        job.exception_count = job_doc['exception_count']
        job.started_at = job_doc['started_at'].isoformat()
        if job_doc['completed_at'] is not None:
            job.completed_at = job_doc['completed_at'].isoformat()
        run_state = job_doc['run_state'].upper()
        job.run_state = protobuf.shared_pb2.JobRunState.Value(run_state)
        http_status_counts = job_doc['http_status_counts']
        for status_code, count in http_status_counts.items():
            job.http_status_counts[int(status_code)] = count
    return response

async def _list_domain_logins(self, command, socket):
    ''' Return a list of domain logins. '''
    limit = command.page.limit
    skip = command.page.offset
    response = Response()
    async with self._db_pool.connection() as conn:
        count = await r.table('domain_login').count().run(conn)
        cursor = await (r.table('domain_login').order_by(index='domain')
            .skip(skip).limit(limit).run(conn))
        response.list_domain_logins.total = count
        async for domain_doc in cursor:
            dl = response.list_domain_logins.logins.add()
            dl.domain = domain_doc['domain']
            dl.login_url = domain_doc['login_url']
            if domain_doc['login_test'] is not None:
                dl.login_test = domain_doc['login_test']
            # Not a very efficient way to count users, but a better way
            # isn't obvious...
            dl.auth_count = len(domain_doc['users'])
    return response

async def _unsubscribe(self, command, socket):
    ''' Handle an unsubscribe command. '''
    sub_id = command.subscription_id
    await self._subscription_manager.unsubscribe(socket, sub_id)
    return Response()

async def _delete_job_schedule(self, command, socket):
    ''' Delete a job schedule. '''
    schedule_id = str(UUID(bytes=command.schedule_id))
    await self._scheduler.delete_job_schedule(schedule_id)
    return Response()

async def _delete_policy(self, command, socket):
    ''' Delete a policy. '''
    policy_id = str(UUID(bytes=command.policy_id))
    await self._policy_manager.delete_policy(policy_id)
    return Response()