def step(self, name: str):
    """Record wall-clock timing for the step called ``name``.

    A generator with a single ``yield``.  Before yielding, a fresh dict is
    stored in ``self.timings[name]`` and stamped with ``start_time``; when
    the generator is resumed, ``finish_time`` and ``duration`` are added.
    All values come from ``time_msecs()``.

    Asserts that ``name`` has not been timed before.
    """
    assert name not in self.timings

    record: Dict[str, int] = {}
    self.timings[name] = record
    record['start_time'] = time_msecs()

    yield

    # Only reached when the generator is resumed normally after the yield.
    record['finish_time'] = time_msecs()
    record['duration'] = record['finish_time'] - record['start_time']
async def post_job_complete_1(self, job, run_duration):
    """Send ``job``'s final status to the batch driver, retrying until accepted.

    Steps:
      1. Fetch the job's full status (retried on all errors).
      2. If the job's format version keeps the full status in GCS, write the
         status file through ``self.log_store`` (retried on all errors).
      3. POST a reduced status document to the driver's ``job_complete``
         endpoint in a loop with jittered exponential backoff.

    Returns once the POST succeeds.  Re-raises ``asyncio.CancelledError`` and
    an ``aiohttp.ClientResponseError`` with status 404; every other exception
    is logged and retried.

    ``run_duration`` is compared against elapsed milliseconds, so it is
    presumably in milliseconds as well -- TODO confirm with the caller.
    """
    full_status = await retry_all_errors(
        f'error while getting status for {job}')(job.status)

    if job.format_version.has_full_status_in_gcs():
        await retry_all_errors(
            f'error while writing status file to gcs for {job}')(
                self.log_store.write_status_file,
                job.batch_id,
                job.job_id,
                job.attempt_id,
                json.dumps(full_status))

    # The database only stores the format-version-specific reduced status.
    db_status = job.format_version.db_status(full_status)

    status = {
        'batch_id': full_status['batch_id'],
        'job_id': full_status['job_id'],
        'attempt_id': full_status['attempt_id'],
        'state': full_status['state'],
        'start_time': full_status['start_time'],
        'end_time': full_status['end_time'],
        'status': db_status
    }

    body = {'status': status}

    start_time = time_msecs()
    delay_secs = 0.1
    while True:
        try:
            # A new session with a 5 second total timeout is used per attempt.
            async with aiohttp.ClientSession(
                    raise_for_status=True, timeout=aiohttp.ClientTimeout(total=5)) as session:
                await session.post(deploy_config.url(
                    'batch-driver', '/api/v1alpha/instances/job_complete'),
                    json=body, headers=self.headers)
                return
        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            raise
        except Exception as e:
            # A 404 is not retried; it is propagated to the caller.
            if isinstance(e, aiohttp.ClientResponseError) and e.status == 404:  # pylint: disable=no-member
                raise
            log.exception(f'failed to mark {job} complete, retrying')

        # Unlist the job once BOTH more than 3 minutes AND more than half the
        # run duration have elapsed (i.e. after the larger of the two).  The
        # POST itself keeps being retried afterwards.
        now = time_msecs()
        elapsed = now - start_time
        if (job.id in self.jobs
                and elapsed > 180 * 1000
                and elapsed > run_duration / 2):
            log.info(
                f'too much time elapsed marking {job} complete, removing from jobs, will keep retrying'
            )
            del self.jobs[job.id]
            self.last_updated = time_msecs()

        # Sleep with +/-30% jitter.
        await asyncio.sleep(delay_secs * random.uniform(0.7, 1.3))
        # exponentially back off, up to (expected) max of 2m
        delay_secs = min(delay_secs * 2, 2 * 60.0)
async def run(self):
    """Serve the worker HTTP API until the worker has been idle long enough.

    Starts an aiohttp site on 0.0.0.0:5000, attempts activation (bounded by
    ``MAX_IDLE_TIME_MSECS``), then polls every 15 seconds until there are no
    jobs and ``MAX_IDLE_TIME_MSECS`` have passed since ``self.last_updated``.
    After that it sends a single, non-retried deactivate request to the batch
    driver.  The site and the app runner are always torn down in ``finally``.
    """
    app_runner = None
    site = None
    try:
        app = web.Application(client_max_size=HTTP_CLIENT_MAX_SIZE)
        app.add_routes([
            web.post('/api/v1alpha/batches/jobs/create', self.create_job),
            web.delete(
                '/api/v1alpha/batches/{batch_id}/jobs/{job_id}/delete',
                self.delete_job),
            web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/log',
                    self.get_job_log),
            web.get('/api/v1alpha/batches/{batch_id}/jobs/{job_id}/status',
                    self.get_job_status),
            web.get('/healthcheck', self.healthcheck)
        ])

        app_runner = web.AppRunner(app)
        await app_runner.setup()
        site = web.TCPSite(app_runner, '0.0.0.0', 5000)
        await site.start()

        try:
            # wait_for takes seconds; the constant is in milliseconds.
            await asyncio.wait_for(self.activate(), MAX_IDLE_TIME_MSECS / 1000)
        except asyncio.TimeoutError:
            log.exception(
                f'could not activate after trying for {MAX_IDLE_TIME_MSECS} ms, exiting'
            )
        else:
            # Keep running while any job is registered or the worker was
            # updated more recently than MAX_IDLE_TIME_MSECS ago.
            idle_duration = time_msecs() - self.last_updated
            while self.jobs or idle_duration < MAX_IDLE_TIME_MSECS:
                log.info(
                    f'n_jobs {len(self.jobs)} free_cores {self.cpu_sem.value / 1000} idle {idle_duration}'
                )
                await asyncio.sleep(15)
                idle_duration = time_msecs() - self.last_updated
            log.info(f'idle {idle_duration} ms, exiting')

            async with get_context_specific_ssl_client_session(
                    raise_for_status=True,
                    timeout=aiohttp.ClientTimeout(total=60)) as session:
                # Don't retry.  If it doesn't go through, the driver
                # monitoring loops will recover.  If the driver is
                # gone (e.g. testing a PR), this would go into an
                # infinite loop and the instance won't be deleted.
                await session.post(deploy_config.url(
                    'batch-driver', '/api/v1alpha/instances/deactivate'),
                    headers=self.headers)
            log.info('deactivated')
    finally:
        # Tear down only what was actually created before any failure.
        log.info('shutting down')
        if site:
            await site.stop()
            log.info('stopped site')
        if app_runner:
            await app_runner.cleanup()
            log.info('cleaned up app runner')
async def check_on_instance(self, instance):
    """Reconcile one instance's recorded state with what GCE reports.

    Does nothing if the instance reports itself active and healthy.
    Otherwise the GCE instance resource is fetched and:

    * a 404 removes the instance (reason ``does_not_exist``);
    * a pending instance still ``PROVISIONING`` more than 5 minutes after
      creation is deleted (``activation_timeout``);
    * a ``STOPPING``/``TERMINATED`` instance is deactivated;
    * a pending instance that is ``STAGING``/``RUNNING`` more than 5 minutes
      after ``lastStartTimestamp`` is deleted (``activation_timeout``);
    * an inactive instance is deleted.

    Finally the instance's last-updated timestamp is refreshed.

    Raises:
        aiohttp.ClientResponseError: for any non-404 error from the compute
            API.
        AssertionError: if a STAGING/RUNNING instance has no
            ``lastStartTimestamp``.
    """
    active_and_healthy = await instance.check_is_active_and_healthy()
    if active_and_healthy:
        return

    try:
        spec = await self.compute_client.get(
            f'/zones/{instance.zone}/instances/{instance.name}')
    except aiohttp.ClientResponseError as e:
        if e.status == 404:
            await self.remove_instance(instance, 'does_not_exist')
            return
        raise

    # PROVISIONING, STAGING, RUNNING, STOPPING, TERMINATED
    gce_state = spec['status']
    log.info(f'{instance} gce_state {gce_state}')

    if (gce_state == 'PROVISIONING'
            and instance.state == 'pending'
            and time_msecs() - instance.time_created > 5 * 60 * 1000):
        # log.error rather than log.exception: no exception is being handled
        # here, so log.exception would append a bogus "NoneType: None".
        log.error(
            f'{instance} did not provision within 5m after creation, deleting'
        )
        await self.call_delete_instance(instance, 'activation_timeout')

    if gce_state in ('STOPPING', 'TERMINATED'):
        log.info(
            f'{instance} live but stopping or terminated, deactivating')
        await instance.deactivate('terminated')

    if gce_state in ('STAGING', 'RUNNING'):
        last_start_timestamp = spec.get('lastStartTimestamp')
        assert last_start_timestamp is not None, f'lastStartTimestamp does not exist {spec}'
        # Convert the ISO-8601 timestamp into epoch milliseconds.
        last_start_time_msecs = dateutil.parser.isoparse(
            last_start_timestamp).timestamp() * 1000

        if instance.state == 'pending' and time_msecs() - last_start_time_msecs > 5 * 60 * 1000:
            # Same reasoning as above: not inside an exception handler.
            log.error(
                f'{instance} did not activate within 5m after starting, deleting'
            )
            await self.call_delete_instance(instance, 'activation_timeout')

    # Checked last so it also sees a state changed by the steps above.
    if instance.state == 'inactive':
        log.info(f'{instance} is inactive, deleting')
        await self.call_delete_instance(instance, 'inactive')

    await instance.update_timestamp()
async def update_timestamp(self):
    """Set this instance's ``last_updated`` to the current time.

    The new value is written to the ``instances`` table first.  The instance
    is then removed from its collection's bookkeeping, the in-memory
    ``_last_updated`` is changed, and the instance is re-added.
    """
    now = time_msecs()
    statement = 'UPDATE instances SET last_updated = %s WHERE name = %s;'
    await self.db.execute_update(statement, (now, self.name))

    # The remove/mutate/add order is preserved from the original.
    self.inst_coll.adjust_for_remove_instance(self)
    self._last_updated = now
    self.inst_coll.adjust_for_add_instance(self)
async def check(instance):
    """Check on ``instance`` if its last update is more than 60s old."""
    since_last_updated = time_msecs() - instance.last_updated
    if since_last_updated <= 60 * 1000:
        # Updated recently enough; nothing to do.
        return

    log.info(
        f'checking on {instance}, last updated {since_last_updated / 1000}s ago'
    )
    await self.check_on_instance(instance)
async def check_on_instance(self, instance):
    """Reconcile one instance's recorded state with what GCE reports.

    Returns immediately for an active and healthy instance.  Otherwise the
    GCE instance resource is fetched: a 404 removes the instance, a
    STOPPING/TERMINATED instance is deactivated, a pending instance that is
    STAGING/RUNNING more than 5 minutes after creation is deleted, and an
    inactive instance is deleted.  The last-updated timestamp is refreshed
    at the end.
    """
    if await instance.check_is_active_and_healthy():
        return

    try:
        spec = await self.compute_client.get(
            f'/zones/{instance.zone}/instances/{instance.name}')
    except aiohttp.ClientResponseError as e:
        if e.status != 404:
            raise
        await self.remove_instance(instance, 'does_not_exist')
        return

    # PROVISIONING, STAGING, RUNNING, STOPPING, TERMINATED
    gce_state = spec['status']
    log.info(f'{instance} gce_state {gce_state}')

    if gce_state in ('STOPPING', 'TERMINATED'):
        log.info(
            f'{instance} live but stopping or terminated, deactivating')
        await instance.deactivate('terminated')

    activation_timed_out = (
        gce_state in ('STAGING', 'RUNNING')
        and instance.state == 'pending'
        and time_msecs() - instance.time_created > 5 * 60 * 1000)
    if activation_timed_out:
        # FIXME shouldn't count time in PROVISIONING
        log.info(f'{instance} did not activate within 5m, deleting')
        await self.call_delete_instance(instance, 'activation_timeout')

    # Evaluated after the steps above so it sees any state they changed.
    if instance.state == 'inactive':
        log.info(f'{instance} is inactive, deleting')
        await self.call_delete_instance(instance, 'inactive')

    await instance.update_timestamp()
async def cancel_with_error_handling(app, batch_id, job_id, attempt_id,
                                     instance_name, id):
    """Mark a creating job as cancelled and delete its instance.

    The job is marked complete with state ``'Cancelled'``, reason
    ``'cancelled'``, no status, no start time and an empty resource list.
    If the instance is known to the instance collection manager it is then
    deleted; otherwise a warning is logged.  Any ``Exception`` is logged at
    info level (with traceback) and not propagated.
    """
    try:
        await mark_job_complete(
            app, batch_id, job_id, attempt_id, instance_name,
            'Cancelled', None, None, time_msecs(), 'cancelled', [])

        instance = self.inst_coll_manager.get_instance(instance_name)
        if instance is not None:
            await instance.inst_coll.call_delete_instance(
                instance, 'cancelled')
        else:
            log.warning(
                f'in cancel_cancelled_creating_jobs: unknown instance {instance_name}'
            )
    except Exception:
        log.info(
            f'cancelling creating job {id} on instance {instance_name}',
            exc_info=True)
async def mark_job_complete(app, batch_id, job_id, attempt_id, instance_name,
                            new_state, status, start_time, end_time, reason,
                            resources):
    """Record a job attempt as complete and propagate the consequences.

    Calls the ``mark_job_complete`` stored procedure, wakes the scheduler and
    cancel loops, adjusts the in-memory free cores of the (active) instance,
    records the attempt's resources, and -- if the job actually changed from
    a non-complete state -- notifies batch completion.  For an active
    instance whose collection is not a pool, the instance is killed in the
    background.

    Args:
        app: application mapping holding ``db``, ``inst_coll_manager``,
            ``task_manager`` and the two state-change notifiers.
        status: JSON-serialisable status, or ``None``.
        resources: forwarded to ``add_attempt_resources``.

    Raises:
        Exception: whatever the stored-procedure call raises (after logging).
    """
    scheduler_state_changed: Notice = app['scheduler_state_changed']
    cancel_ready_state_changed: asyncio.Event = app[
        'cancel_ready_state_changed']
    db: Database = app['db']
    inst_coll_manager: 'InstanceCollectionManager' = app['inst_coll_manager']
    task_manager: BackgroundTaskManager = app['task_manager']

    id = (batch_id, job_id)

    log.info(f'marking job {id} complete new_state {new_state}')

    now = time_msecs()

    try:
        rv = await db.execute_and_fetchone(
            'CALL mark_job_complete(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s);',
            (batch_id, job_id, attempt_id, instance_name, new_state,
             json.dumps(status) if status is not None else None,
             start_time, end_time, reason, now))
    except Exception:
        log.exception(
            f'error while marking job {id} complete on instance {instance_name}'
        )
        raise

    scheduler_state_changed.notify()
    cancel_ready_state_changed.set()

    instance = None

    if instance_name:
        instance = inst_coll_manager.get_instance(instance_name)
        if instance:
            if rv['delta_cores_mcpu'] != 0 and instance.state == 'active':
                # may also create scheduling opportunities, set above
                instance.adjust_free_cores_in_memory(rv['delta_cores_mcpu'])
        else:
            # Log the name that was looked up: `instance` is always falsy on
            # this branch, so interpolating it only ever printed "None".
            log.warning(
                f'mark_complete for job {id} from unknown {instance_name}')

    await add_attempt_resources(db, batch_id, job_id, attempt_id, resources)

    if rv['rc'] != 0:
        log.info(f'mark_job_complete returned {rv} for job {id}')
        return

    old_state = rv['old_state']
    if old_state in complete_states:
        log.info(f'old_state {old_state} complete for job {id}, doing nothing')
        # already complete, do nothing
        return

    log.info(f'job {id} changed state: {rv["old_state"]} => {new_state}')

    await notify_batch_job_complete(db, batch_id)

    if instance and not instance.inst_coll.is_pool and instance.state == 'active':
        task_manager.ensure_future(instance.kill())
async def post_job_complete(self, job, run_duration):
    """Report ``job`` as complete, then always drop it from ``self.jobs``.

    Exceptions from ``post_job_complete_1`` propagate to the caller; the
    cleanup in ``finally`` runs either way.  ``self.last_updated`` is only
    refreshed when the job was still registered.
    """
    try:
        await self.post_job_complete_1(job, run_duration)
    finally:
        log.info(f'{job} marked complete, removing from jobs')
        still_registered = job.id in self.jobs
        if still_registered:
            self.jobs.pop(job.id)
            self.last_updated = time_msecs()
async def monitor_instances(self):
    """Check on the least recently updated instance if it is over 60s stale."""
    if not self.instances_by_last_updated:
        return

    # 0 is the smallest (oldest)
    instance = self.instances_by_last_updated[0]
    since_last_updated = time_msecs() - instance.last_updated
    if since_last_updated <= 60 * 1000:
        return

    log.info(f'checking on {instance}, last updated {since_last_updated / 1000}s ago')
    await self.check_on_instance(instance)
def __init__(self, transfer: Union[Transfer, List[Transfer]]):
    """Start a report for one transfer or for a list of transfers.

    The start time is stamped immediately; end time, duration and exception
    start out unset.  ``_transfer_report`` is a single ``TransferReport``
    for a single ``Transfer``, otherwise a list with one report per element.
    """
    self._start_time = time_msecs()
    self._end_time = None
    self._duration = None
    self._transfer_report = (
        TransferReport(transfer)
        if isinstance(transfer, Transfer)
        else [TransferReport(single) for single in transfer]
    )
    self._exception: Optional[Exception] = None
async def __aenter__(self):
    """Enter the step: update the container state and start its timer.

    Raises ``JobDeletedError`` if the container's job has been deleted.
    When this step carries a state, the transition is logged and applied to
    the container.  A new timing dict holding ``start_time`` is stored both
    on ``self`` and under ``self.name`` in the container's timing table.
    """
    if self.container.job.deleted:
        raise JobDeletedError()

    if self.state:
        log.info(f'{self.container} state changed: {self.container.state} => {self.state}')
        self.container.state = self.state

    self.timing = {}
    self.timing.update(start_time=time_msecs())
    self.container.timing[self.name] = self.timing
async def activate_instance_1(request, instance):
    """Activate ``instance`` with the IP address from the JSON request body.

    The instance is activated with the current time, marked healthy, and the
    token returned by ``instance.activate`` is sent back as JSON.
    """
    payload = await request.json()
    worker_ip = payload['ip_address']

    log.info(f'activating {instance}')
    token = await instance.activate(worker_ip, time_msecs())
    await instance.mark_healthy()

    response_body = {'token': token}
    return web.json_response(response_body)
def __init__(self):
    """Initialise worker bookkeeping: cores, semaphore, pool and job table."""
    total_mcpu = CORES * 1000
    self.cores_mcpu = total_mcpu
    self.last_updated = time_msecs()
    self.cpu_sem = FIFOWeightedSemaphore(total_mcpu)
    self.pool = concurrent.futures.ThreadPoolExecutor()
    self.jobs = dict()

    # filled in during activation
    self.log_store = None
    self.headers = None
async def post_job_complete(self, job):
    """Report ``job`` as complete, then always drop it from ``self.jobs``.

    Any ``Exception`` from ``post_job_complete_1`` is logged (with stack
    info) and suppressed.  ``self.last_updated`` is only refreshed when the
    job was still registered.
    """
    try:
        await self.post_job_complete_1(job)
    except Exception:
        log.exception(f'error while marking {job} complete', stack_info=True)
    finally:
        log.info(f'{job} marked complete, removing from jobs')
        if job.id in self.jobs:
            self.jobs.pop(job.id)
            self.last_updated = time_msecs()
async def post_create_resource(request, userdata):  # pylint: disable=unused-argument
    """Create a resource from a multipart form submission.

    Parts are consumed in order.  The ``_csrf`` part must match the
    ``_csrf`` cookie and must arrive before any ``file`` part; each named
    file is streamed into storage under a fresh attachment id; every other
    part is collected as a text form field.  A row is then inserted into
    ``atgu_resources`` and a redirect to the new resource page is returned.

    Raises:
        web.HTTPUnauthorized: if the CSRF token is missing, mismatched, or
            not seen before a file part.
        KeyError: if one of the ``title``/``description``/``contents``/
            ``tags`` fields was not submitted.
    """
    db = request.app['db']
    storage_client = request.app['storage_client']

    checked_csrf = False
    # attachment id -> original filename
    attachments = {}
    # plain form fields: part name -> text
    post = {}
    reader = aiohttp.MultipartReader(request.headers, request.content)
    while True:
        part = await reader.next()  # pylint: disable=not-callable
        if not part:
            break
        if part.name == '_csrf':
            # check csrf token
            # form fields are delivered in order, the _csrf hidden field should appear first
            # https://stackoverflow.com/questions/7449861/multipart-upload-form-is-order-guaranteed
            token1 = request.cookies.get('_csrf')
            token2 = await part.text()
            if token1 is None or token2 is None or token1 != token2:
                log.info('request made with invalid csrf tokens')
                raise web.HTTPUnauthorized()
            checked_csrf = True
        elif part.name == 'file':
            # Refuse to store anything before the CSRF check has passed.
            if not checked_csrf:
                raise web.HTTPUnauthorized()
            filename = part.filename
            # A file part without a filename is skipped.
            if not filename:
                continue
            attachment_id = secret_alnum_string()
            async with await storage_client.insert_object(
                    BUCKET, f'atgu/attachments/{attachment_id}') as f:
                # Stream the upload chunk by chunk.
                while True:
                    chunk = await part.read_chunk()
                    if not chunk:
                        break
                    await f.write(chunk)
            attachments[attachment_id] = filename
        else:
            post[part.name] = await part.text()

    # Covers submissions that contained no _csrf part at all.
    if not checked_csrf:
        raise web.HTTPUnauthorized()

    now = time_msecs()
    id = await db.execute_insertone(
        ''' INSERT INTO `atgu_resources` (`time_created`, `title`, `description`, `contents`, `tags`, `attachments`, `time_updated`) VALUES (%s, %s, %s, %s, %s, %s, %s); ''',
        (now, post['title'], post['description'], post['contents'], post['tags'], json.dumps(attachments), now),
    )

    return web.HTTPFound(deploy_config.external_url('atgu', f'/resources/{id}'))
async def main_loop(self):
    """Run ``self.handler`` forever, restarting it with backoff on failure.

    The inner loop clears ``self.event``, awaits the handler and, when the
    handler returns a truthy value, waits for the event to be set again.
    If the handler raises an ``Exception`` the error is logged, the loop
    sleeps for a jittered delay, and the inner loop restarts.

    Backoff: the delay doubles after each failure, reduced by half of the
    time the inner loop ran beyond the sleep just taken, and is clamped to
    ``[self.min_delay_secs, 30.0]`` seconds.  A loop that ran for a long
    time before failing therefore restarts with a short delay.
    """
    delay_secs = self.min_delay_secs
    while True:
        # Taken outside the try so it is always bound in the handler below.
        start_time = time_msecs()
        try:
            while True:
                self.event.clear()
                should_wait = await self.handler()
                if should_wait:
                    await self.event.wait()
        except Exception:
            end_time = time_msecs()

            log.exception('caught exception in event handler loop')

            t = delay_secs * random.uniform(0.7, 1.3)
            await asyncio.sleep(t)

            # time_msecs() is in milliseconds: divide (not multiply) to get
            # seconds.
            ran_for_secs = (end_time - start_time) / 1000
            # Credit for having run longer than the sleep.  This must be
            # max(0, ...): with min(0, ...) the term was never positive, so
            # the delay could only grow and the lower clamp was dead code.
            credit_secs = max(0, (ran_for_secs - t) / 2)
            delay_secs = min(
                max(self.min_delay_secs, 2 * delay_secs - credit_secs),
                30.0)
async def activate_instance_1(request, instance):
    """Activate ``instance`` and reply with its token and the GSA key.

    The IP address comes from the JSON request body.  After activating with
    the current time and marking the instance healthy, the JSON document at
    ``/gsa-key/key.json`` is read and returned alongside the token.
    """
    payload = await request.json()
    worker_ip = payload['ip_address']

    log.info(f'activating {instance}')
    token = await instance.activate(worker_ip, time_msecs())
    await instance.mark_healthy()

    with open('/gsa-key/key.json', 'r') as key_file:
        key = json.load(key_file)

    return web.json_response({'token': token, 'key': key})
async def delete_orphaned_disks(
        compute_client: aiogoogle.GoogleComputeClient,
        zones: List[str],
        inst_coll_manager: InstanceCollectionManager,
        namespace: str,
):
    """Delete disks in ``namespace`` that no longer serve a live instance.

    For every disk in every zone whose ``namespace`` label matches, the disk
    is deleted when one of these holds (checked in this order):

    * the instance named by its ``instance-name`` label is unknown;
    * it has never been attached and was created over 60 minutes ago;
    * it was detached over 5 minutes ago.

    A 404 on deletion is ignored; other deletion errors are logged and the
    scan continues.  ``asyncio.CancelledError`` is propagated.
    """
    log.info('deleting orphaned disks')

    params = {'filter': f'(labels.namespace = {namespace})'}

    for zone in zones:
        async for disk in await compute_client.list(f'/zones/{zone}/disks', params=params):
            disk_name = disk['name']
            instance_name = disk['labels']['instance-name']
            instance = inst_coll_manager.get_instance(instance_name)

            creation_timestamp_msecs = parse_timestamp_msecs(
                disk.get('creationTimestamp'))
            last_attach_timestamp_msecs = parse_timestamp_msecs(
                disk.get('lastAttachTimestamp'))
            last_detach_timestamp_msecs = parse_timestamp_msecs(
                disk.get('lastDetachTimestamp'))

            now_msecs = time_msecs()
            # log.error rather than log.exception below: no exception is
            # being handled at these points, so log.exception would append
            # a meaningless "NoneType: None" traceback.
            if instance is None:
                log.error(
                    f'deleting disk {disk_name} from instance that no longer exists'
                )
            elif last_attach_timestamp_msecs is None and now_msecs - creation_timestamp_msecs > 60 * 60 * 1000:
                log.error(
                    f'deleting disk {disk_name} that has not attached within 60 minutes'
                )
            elif last_detach_timestamp_msecs is not None and now_msecs - last_detach_timestamp_msecs > 5 * 60 * 1000:
                log.error(
                    f'deleting detached disk {disk_name} that has not been cleaned up within 5 minutes'
                )
            else:
                # Disk is still in use or too young to judge.
                continue

            try:
                await compute_client.delete_disk(
                    f'/zones/{zone}/disks/{disk_name}')
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # Already gone: nothing left to delete.
                if isinstance(e, aiohttp.ClientResponseError) and e.status == 404:  # pylint: disable=no-member
                    continue
                log.exception(
                    f'error while deleting orphaned disk {disk_name}')
async def create_instance_with_error_handling(
        batch_id: int, job_id: int, attempt_id: str, record: dict, id: Tuple[int, int]
):
    """Create a job-private instance for a job and mark the job as creating.

    The machine spec is derived from the record's JSON ``spec`` according to
    its ``format_version``.  Any ``Exception`` is logged and suppressed.
    """
    try:
        job_spec = json.loads(record['spec'])
        format_version = BatchFormatVersion(record['format_version'])
        machine_spec = format_version.get_spec_machine_spec(job_spec)

        created = await self.create_instance(machine_spec)
        instance, total_resources_on_instance = created
        log.info(f'created {instance} for {(batch_id, job_id)}')

        await mark_job_creating(
            self.app,
            batch_id,
            job_id,
            attempt_id,
            instance,
            time_msecs(),
            total_resources_on_instance,
        )
    except Exception:
        log.exception(f'while creating job private instance for job {id}', exc_info=True)
async def create(app, name, activation_token, worker_cores_mcpu, zone):
    """Insert a new pending instance row and return the matching ``Instance``.

    A new URL-safe token is generated.  ``time_created`` and
    ``last_updated`` are both set to the current time, and free cores start
    equal to total cores.
    """
    database = app['db']

    state = 'pending'
    now = time_msecs()
    token = secrets.token_urlsafe(32)

    row = (
        name, state, activation_token, token,
        worker_cores_mcpu, worker_cores_mcpu,
        now, now, INSTANCE_VERSION, zone,
    )
    await database.just_execute(
        ''' INSERT INTO instances (name, state, activation_token, token, cores_mcpu, free_cores_mcpu, time_created, last_updated, version, zone) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s); ''',
        row)

    return Instance(
        app, name, state,
        worker_cores_mcpu, worker_cores_mcpu,
        now, 0, now, None,
        INSTANCE_VERSION, zone)
async def mark_job_complete(app, batch_id, job_id, attempt_id, instance_name,
                            new_state, status, start_time, end_time, reason):
    """Record a job attempt as complete and propagate the consequences.

    Calls the ``mark_job_complete`` stored procedure, sets the scheduler and
    cancel events, adjusts the in-memory free cores of the (active)
    instance, and -- if the job actually changed from a non-complete state --
    notifies batch completion.

    Args:
        app: application mapping holding ``db``, ``inst_pool`` and the two
            state-change events.
        status: JSON-serialisable status, or ``None``.

    Raises:
        Exception: whatever the stored-procedure call raises (after logging).
    """
    scheduler_state_changed = app['scheduler_state_changed']
    cancel_ready_state_changed = app['cancel_ready_state_changed']
    db = app['db']
    inst_pool = app['inst_pool']

    id = (batch_id, job_id)

    log.info(f'marking job {id} complete new_state {new_state}')

    now = time_msecs()

    try:
        rv = await db.execute_and_fetchone(
            'CALL mark_job_complete(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s);',
            (batch_id, job_id, attempt_id, instance_name, new_state,
             json.dumps(status) if status is not None else None,
             start_time, end_time, reason, now))
    except Exception:
        log.exception(
            f'error while marking job {id} complete on instance {instance_name}'
        )
        raise

    scheduler_state_changed.set()
    cancel_ready_state_changed.set()

    if instance_name:
        instance = inst_pool.name_instance.get(instance_name)
        if instance:
            if rv['delta_cores_mcpu'] != 0 and instance.state == 'active':
                # may also create scheduling opportunities, set above
                instance.adjust_free_cores_in_memory(rv['delta_cores_mcpu'])
        else:
            # Log the name that was looked up: `instance` is always falsy on
            # this branch, so interpolating it only ever printed "None".
            log.warning(
                f'mark_complete for job {id} from unknown {instance_name}')

    if rv['rc'] != 0:
        log.info(f'mark_job_complete returned {rv} for job {id}')
        return

    old_state = rv['old_state']
    if old_state in complete_states:
        log.info(f'old_state {old_state} complete for job {id}, doing nothing')
        # already complete, do nothing
        return

    log.info(f'job {id} changed state: {rv["old_state"]} => {new_state}')

    await notify_batch_job_complete(app, db, batch_id)
async def create_instance_with_error_handling(
        batch_id, job_id, attempt_id, record, id):
    """Create a job-private instance for a job and mark the job as creating.

    The machine spec is derived from the record's JSON ``spec`` according to
    its ``format_version``.  Any ``Exception`` is logged at info level (with
    traceback) and suppressed.
    """
    try:
        job_spec = json.loads(record['spec'])
        format_version = BatchFormatVersion(record['format_version'])
        machine_spec = format_version.get_spec_machine_spec(job_spec)

        created = await self.create_instance(batch_id, job_id, machine_spec)
        instance, resources = created

        await mark_job_creating(
            self.app, batch_id, job_id, attempt_id,
            instance, time_msecs(), resources)
    except Exception:
        log.info(f'creating job private instance for job {id}',
                 exc_info=True)
async def create(app, inst_coll, name, activation_token, worker_cores_mcpu,
                 zone, machine_type, preemptible, worker_config: WorkerConfig):
    """Insert a new pending instance row and return the matching ``Instance``.

    A new URL-safe token is generated.  ``time_created`` and
    ``last_updated`` are both set to the current time, and free cores start
    equal to total cores.  The worker config is stored as base64-encoded
    JSON.
    """
    db: Database = app['db']

    state = 'pending'
    now = time_msecs()
    token = secrets.token_urlsafe(32)

    config_json = json.dumps(worker_config.config)
    encoded_config = base64.b64encode(config_json.encode()).decode()

    row = (
        name, state, activation_token, token,
        worker_cores_mcpu, worker_cores_mcpu,
        now, now, INSTANCE_VERSION, zone,
        inst_coll.name, machine_type, preemptible,
        encoded_config,
    )
    await db.just_execute(
        ''' INSERT INTO instances (name, state, activation_token, token, cores_mcpu, free_cores_mcpu, time_created, last_updated, version, zone, inst_coll, machine_type, preemptible, worker_config) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); ''',
        row,
    )

    return Instance(
        app, inst_coll, name, state,
        worker_cores_mcpu, worker_cores_mcpu,
        now, 0, now, None,
        INSTANCE_VERSION, zone, machine_type, preemptible, worker_config,
    )
async def ui_get_job(request, userdata):
    """Render the job page: status, log and a per-attempt timing table.

    Each attempt row has its start/end times converted to strings (or
    removed when absent) and gains a humanised ``duration``.  For an attempt
    without an end time the duration is measured up to the current time.
    """
    app = request.app
    db = app['db']
    batch_id = int(request.match_info['batch_id'])
    job_id = int(request.match_info['job_id'])
    user = userdata['username']

    job_status = await _get_job(app, batch_id, job_id, user)

    attempts = []
    async for attempt in db.select_and_fetchall(
            ''' SELECT * FROM attempts WHERE batch_id = %s AND job_id = %s ''',
            (batch_id, job_id)):
        attempts.append(attempt)

    for attempt in attempts:
        start_time = attempt['start_time']
        if start_time:
            attempt['start_time'] = time_msecs_str(start_time)
        else:
            del attempt['start_time']

        end_time = attempt['end_time']
        if end_time is None:
            del attempt['end_time']
        else:
            attempt['end_time'] = time_msecs_str(end_time)

        if start_time is None:
            continue

        # elapsed time if attempt is still running
        effective_end = time_msecs() if end_time is None else end_time
        attempt['duration'] = humanize_timedelta_msecs(
            max(effective_end - start_time, 0))

    job_log = await _get_job_log(app, batch_id, job_id, user)
    page_context = {
        'batch_id': batch_id,
        'job_id': job_id,
        'job_log': job_log,
        'attempts': attempts,
        'job_status': json.dumps(job_status, indent=2)
    }
    return await render_template('batch', request, userdata, 'job.html', page_context)
async def delete_orphaned_disks(self):
    """Delete disks in the default namespace that serve no live instance.

    For every disk in every monitored zone whose ``namespace`` label matches
    ``DEFAULT_NAMESPACE``, the disk is deleted when one of these holds
    (checked in this order):

    * the instance named by its ``instance-name`` label is unknown;
    * it has never been attached and was created over 60 minutes ago;
    * it was detached over 5 minutes ago.

    A 404 on deletion is ignored; other ``aiohttp.ClientResponseError``s are
    logged and the scan continues.
    """
    log.info('deleting orphaned disks')

    params = {'filter': f'(labels.namespace = {DEFAULT_NAMESPACE})'}

    for zone in self.zone_monitor.zones:
        async for disk in await self.compute_client.list(
                f'/zones/{zone}/disks', params=params):
            disk_name = disk['name']
            instance_name = disk['labels']['instance-name']
            instance = self.inst_coll_manager.get_instance(instance_name)

            creation_timestamp_msecs = parse_timestamp_msecs(
                disk.get('creationTimestamp'))
            last_attach_timestamp_msecs = parse_timestamp_msecs(
                disk.get('lastAttachTimestamp'))
            last_detach_timestamp_msecs = parse_timestamp_msecs(
                disk.get('lastDetachTimestamp'))

            now_msecs = time_msecs()
            # log.error rather than log.exception below: no exception is
            # being handled at these points, so log.exception would append
            # a meaningless "NoneType: None" traceback.
            if instance is None:
                log.error(
                    f'deleting disk {disk_name} from instance that no longer exists'
                )
            elif (last_attach_timestamp_msecs is None
                  and now_msecs - creation_timestamp_msecs > 60 * 60 * 1000):
                log.error(
                    f'deleting disk {disk_name} that has not attached within 60 minutes'
                )
            elif (last_detach_timestamp_msecs is not None
                  and now_msecs - last_detach_timestamp_msecs > 5 * 60 * 1000):
                log.error(
                    f'deleting detached disk {disk_name} that has not been cleaned up within 5 minutes'
                )
            else:
                # Disk is still in use or too young to judge.
                continue

            try:
                await self.compute_client.delete_disk(
                    f'/zones/{zone}/disks/{disk_name}')
            except aiohttp.ClientResponseError as e:
                # Already gone: nothing left to delete.
                if e.status == 404:
                    continue
                log.exception(
                    f'error while deleting orphaned disk {disk_name}')
async def instance_monitoring_loop(self):
    """Once a second, check on the stalest instance if it is over 60s old.

    Runs forever.  ``asyncio.CancelledError`` propagates; any other
    ``Exception`` from an iteration is logged and the loop continues.
    """
    log.info('starting instance monitoring loop')

    while True:
        try:
            by_age = self.instances_by_last_updated
            if by_age:
                # 0 is the smallest (oldest)
                instance = by_age[0]
                since_last_updated = time_msecs() - instance.last_updated
                if since_last_updated > 60 * 1000:
                    log.info(f'checking on {instance}, last updated {since_last_updated / 1000}s ago')
                    await self.check_on_instance(instance)
        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            raise
        except Exception:
            log.exception('in monitor instances loop')

        await asyncio.sleep(1)
async def deactivate(self, reason, timestamp=None):
    """Mark this instance inactive in the database and in memory.

    No-op when the instance is already inactive or deleted.  A falsy
    ``timestamp`` is replaced with the current time.  After the stored
    procedure succeeds, the instance is removed from the pool's bookkeeping,
    its state and free cores are reset, it is re-added, and the scheduler
    is signalled.
    """
    already_down = self._state == 'inactive' or self._state == 'deleted'
    if already_down:
        return

    when = timestamp or time_msecs()
    await check_call_procedure(
        self.db,
        'CALL deactivate_instance(%s, %s, %s);',
        (self.name, reason, when))

    # The remove/mutate/add order is preserved from the original.
    self.instance_pool.adjust_for_remove_instance(self)
    self._state = 'inactive'
    self._free_cores_mcpu = self.cores_mcpu
    self.instance_pool.adjust_for_add_instance(self)

    # there might be jobs to reschedule
    self.scheduler_state_changed.set()
async def _fetch(self, session: httpx.ClientSession) -> Tuple[str, int]:
    """Exchange the AAD access token for an ACR refresh token.

    Returns:
        ``(refresh_token, expiration_time_ms)``, where the expiration is the
        current time plus one hour, in milliseconds.
    """
    # https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#calling-post-oauth2exchange-to-get-an-acr-refresh-token
    data = {
        'grant_type': 'access_token',
        'service': self.acr_url,
        'access_token': await self.aad_access_token.token(session),
    }
    exchange_url = f'https://{self.acr_url}/oauth2/exchange'
    form_headers = {'Content-Type': 'application/x-www-form-urlencoded'}

    async with await request_retry_transient_errors(
        session,
        'POST',
        exchange_url,
        headers=form_headers,
        data=data,
        timeout=aiohttp.ClientTimeout(total=60),  # type: ignore
    ) as resp:
        payload = await resp.json()
        refresh_token: str = payload['refresh_token']
        # token expires in 3 hours so we refresh after 1 hour
        one_hour_ms = 60 * 60 * 1000
        expiration_time_ms = time_msecs() + one_hour_ms
        return refresh_token, expiration_time_ms