def handle_response(self, resp):
    if resp.type == aiohttp.WSMsgType.CLOSE:
        raise aiohttp.ServerDisconnectedError(
            f'Socket was closed by server. (code={resp.data})')
    if resp.type == aiohttp.WSMsgType.ERROR:
        raise FatalError(f'Error raised while waiting for response from server: {resp}.')
    assert resp.type == aiohttp.WSMsgType.TEXT, resp.type
    return resp.data
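# Hedged caller sketch (illustration only; `_example_receive_json`, `backend`, and
# `socket` are invented names, not part of the original code): handle_response is
# meant to sit between a raw `await socket.receive()` and JSON decoding, turning
# CLOSE frames into ServerDisconnectedError and ERROR frames into FatalError.
async def _example_receive_json(backend, socket):
    msg = await socket.receive()            # aiohttp.WSMessage from an open websocket
    payload = backend.handle_response(msg)  # raises on CLOSE/ERROR frames, returns text payload
    return json.loads(payload)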
def hwe_normalize(call_expr):
    mt = matrix_table_source('hwe_normalize/call_expr', call_expr)
    mt = mt.select_entries(__gt=call_expr.n_alt_alleles())
    mt = mt.annotate_rows(__AC=agg.sum(mt.__gt),
                          __n_called=agg.count_where(hl.is_defined(mt.__gt)))
    mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called))

    n_variants = mt.count_rows()
    if n_variants == 0:
        raise FatalError("hwe_normalize: found 0 variants after filtering out monomorphic sites.")
    info(f"hwe_normalize: found {n_variants} variants after filtering out monomorphic sites.")

    mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called)
    mt = mt.annotate_rows(
        __hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) * n_variants / 2))
    mt = mt.unfilter_entries()
    normalized_gt = hl.or_else((mt.__gt - mt.__mean_gt) / mt.__hwe_scaled_std_dev, 0.0)
    return normalized_gt
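# Hedged usage sketch (illustration only): hwe_normalize supplies the normalized
# entry expression that hl.hwe_normalized_pca feeds into hl.pca. The dataset and
# `k` below are invented for the example.
def _example_hwe_normalized_pca():
    dataset = hl.balding_nichols_model(n_populations=3, n_samples=50, n_variants=200)
    eigenvalues, scores, loadings = hl.pca(hwe_normalize(dataset.GT), k=5)
    return eigenvalues, scores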
def add_reference(self, config):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/references/create',
        json=config,
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
def add_sequence(self, name, fasta_file, index_file):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/references/sequence/set',
        json={'name': name, 'fasta_file': fasta_file, 'index_file': index_file},
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
def _request_type(self, ir, kind):
    code = self._render(ir)
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/type/{kind}',
        json=code,
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        raise FatalError(resp.text)
    resp.raise_for_status()
    return resp.json()
def remove_sequence(self, name):
    resp = retry_response_returning_functions(
        requests.delete,
        f'{self.url}/references/sequence/delete',
        json={'name': name},
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
def parse_vcf_metadata(self, path):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/parse-vcf-metadata',
        json={'path': path},
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
    return resp.json()
async def _async_execute_untimed(self, ir):
    token = secret_alnum_string()
    with TemporaryDirectory(ensure_exists=False) as dir:
        async def create_inputs():
            with self.fs.open(dir + '/in', 'wb') as infile:
                write_int(infile, ServiceBackend.EXECUTE)
                write_str(infile, tmp_dir())
                write_str(infile, self.billing_project)
                write_str(infile, self.bucket)
                write_str(infile, self.render(ir))
                write_str(infile, token)

        async def create_batch():
            batch_attributes = self.batch_attributes
            if 'name' not in batch_attributes:
                batch_attributes = {**batch_attributes, 'name': 'execute(...)'}
            bb = self.async_bc.create_batch(token=token, attributes=batch_attributes)

            j = bb.create_jvm_job([
                'is.hail.backend.service.ServiceBackendSocketAPI2',
                os.environ['HAIL_SHA'],
                os.environ['HAIL_JAR_URL'],
                batch_attributes['name'],
                dir + '/in',
                dir + '/out',
            ], mount_tokens=True)
            return (j, await bb.submit(disable_progress_bar=self.disable_progress_bar))

        _, (j, b) = await asyncio.gather(create_inputs(), create_batch())
        status = await b.wait(disable_progress_bar=self.disable_progress_bar)
        if status['n_succeeded'] != 1:
            raise ValueError(f'batch failed {status} {await j.log()}')

        with self.fs.open(dir + '/out', 'rb') as outfile:
            success = read_bool(outfile)
            if success:
                s = read_str(outfile)
                try:
                    resp = json.loads(s)
                except json.decoder.JSONDecodeError as err:
                    raise ValueError(f'could not decode {s}') from err
            else:
                jstacktrace = read_str(outfile)
                raise FatalError(jstacktrace)

        typ = dtype(resp['type'])
        if typ == tvoid:
            x = None
        else:
            x = typ._convert_from_json_na(resp['value'])

        return x
def add_liftover(self, name, chain_file, dest_reference_genome):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/references/liftover/add',
        json={'name': name,
              'chain_file': chain_file,
              'dest_reference_genome': dest_reference_genome},
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
def execute(self, ir, timed=False):
    code = self._render(ir)
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/execute',
        json=code,
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        raise FatalError(resp.text)
    resp.raise_for_status()
    resp_json = resp.json()
    typ = dtype(resp_json['type'])
    value = typ._convert_from_json_na(resp_json['value'])
    # FIXME put back timings
    return (value, None) if timed else value
def unphased_diploid_gt_index(self):
    """Return the genotype index for unphased, diploid calls.

    Returns
    -------
    :obj:`int`
    """
    if self.ploidy != 2 or self.phased:
        raise FatalError(
            "'unphased_diploid_gt_index' is only valid for unphased, diploid calls. Found {}."
            .format(repr(self)))
    a0 = self._alleles[0]
    a1 = self._alleles[1]
    assert a0 <= a1
    # triangular-number index of the unordered allele pair; integer division keeps
    # the return type an int, as the docstring promises
    return a1 * (a1 + 1) // 2 + a0
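# Hedged illustration of the triangular-number indexing (standard VCF genotype ordering):
#   0/0 -> 0, 0/1 -> 1, 1/1 -> 2, 0/2 -> 3, 1/2 -> 4, 2/2 -> 5, ...
# e.g. for 1/2: a0 = 1, a1 = 2, so a1 * (a1 + 1) // 2 + a0 = 2 * 3 // 2 + 1 = 4.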
def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/import-fam',
        json={
            'path': path,
            'quant_pheno': quant_pheno,
            'delimiter': delimiter,
            'missing': missing
        },
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
    return resp.json()
def index_bgen(self, files, index_file_map, rg, contig_recoding, skip_invalid_loci):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/index-bgen',
        json={
            'files': files,
            'index_file_map': index_file_map,
            'rg': rg,
            'contig_recoding': contig_recoding,
            'skip_invalid_loci': skip_invalid_loci
        },
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
    return resp.json()
def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/references/create/fasta',
        json={
            'name': name,
            'fasta_file': fasta_file,
            'index_file': index_file,
            'x_contigs': x_contigs,
            'y_contigs': y_contigs,
            'mt_contigs': mt_contigs,
            'par': par
        },
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        resp_json = resp.json()
        raise FatalError(resp_json['message'])
    resp.raise_for_status()
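# Hedged sketch (not part of the module): the REST wrappers above all share the same
# request-and-raise pattern. A hypothetical shared helper would look like this;
# `_post_json` and `route` are invented names.
def _post_json(self, route, payload):
    resp = retry_response_returning_functions(
        requests.post,
        f'{self.url}/{route}',
        json=payload,
        headers=self.headers)
    if resp.status_code == 400 or resp.status_code == 500:
        # most endpoints report user (400) and internal (500) errors as {'message': ...}
        raise FatalError(resp.json()['message'])
    resp.raise_for_status()
    return resp.json()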
async def async_request(self, endpoint, **data):
    data['token'] = secret_alnum_string()
    session = await self.session()
    async with session.ws_connect(f'{self.url}/api/v1alpha/{endpoint}') as socket:
        await socket.send_str(json.dumps(data))
        response = await socket.receive()
        await socket.send_str('bye')
        if response.type == aiohttp.WSMsgType.ERROR:
            raise ValueError(f'bad response: {endpoint}; {data}; {response}')
        if response.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED):
            warnings.warn(f'retrying after losing connection {endpoint}; {data}; {response}')
            raise TransientError()
        assert response.type == aiohttp.WSMsgType.TEXT
        result = json.loads(response.data)
        if result['status'] != 200:
            raise FatalError(f'Error from server: {result["value"]}')
        return result['value']
def load_references_from_dataset(self, path):
    token = secret_alnum_string()
    with TemporaryDirectory(ensure_exists=False) as dir:
        with self.fs.open(dir + '/in', 'wb') as infile:
            write_int(infile, ServiceBackend.LOAD_REFERENCES_FROM_DATASET)
            write_str(infile, tmp_dir())
            write_str(infile, self.billing_project)
            write_str(infile, self.bucket)
            write_str(infile, path)

        batch_attributes = self.batch_attributes
        if 'name' not in batch_attributes:
            batch_attributes = {**batch_attributes, 'name': 'load_references_from_dataset(...)'}
        bb = self.bc.create_batch(token=token, attributes=batch_attributes)

        j = bb.create_jvm_job([
            'is.hail.backend.service.ServiceBackendSocketAPI2',
            os.environ['HAIL_SHA'],
            os.environ['HAIL_JAR_URL'],
            batch_attributes['name'],
            dir + '/in',
            dir + '/out',
        ], mount_tokens=True)
        b = bb.submit(disable_progress_bar=self.disable_progress_bar)
        status = b.wait(disable_progress_bar=self.disable_progress_bar)
        if status['n_succeeded'] != 1:
            raise ValueError(f'batch failed {status} {j.log()}')

        with self.fs.open(dir + '/out', 'rb') as outfile:
            success = read_bool(outfile)
            if success:
                s = read_str(outfile)
                try:
                    # FIXME: do we not have to parse the result?
                    return json.loads(s)
                except json.decoder.JSONDecodeError as err:
                    raise ValueError(f'could not decode {s}') from err
            else:
                jstacktrace = read_str(outfile)
                raise FatalError(jstacktrace)
def _make_tsm_from_call(call_expr,
                        block_size,
                        mean_center=False,
                        hwe_normalize=False):
    mt = matrix_table_source('_make_tsm/entry_expr', call_expr)
    mt = mt.select_entries(__gt=call_expr.n_alt_alleles())
    if mean_center or hwe_normalize:
        mt = mt.annotate_rows(__AC=agg.sum(mt.__gt),
                              __n_called=agg.count_where(hl.is_defined(mt.__gt)))
        mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called))

        n_variants = mt.count_rows()
        if n_variants == 0:
            raise FatalError("_make_tsm: found 0 variants after filtering out monomorphic sites.")
        info(f"_make_tsm: found {n_variants} variants after filtering out monomorphic sites.")

        mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called)
        mt = mt.unfilter_entries()

        mt = mt.select_entries(__x=hl.or_else(mt.__gt - mt.__mean_gt, 0.0))

        if hwe_normalize:
            mt = mt.annotate_rows(
                __hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) * n_variants / 2))
            mt = mt.select_entries(__x=mt.__x / mt.__hwe_scaled_std_dev)
    else:
        mt = mt.select_entries(__x=mt.__gt)

    A, ht = mt_to_table_of_ndarray(mt.__x, block_size, return_checkpointed_table_also=True)
    A = A.persist()
    return TallSkinnyMatrix(A, A.ndarray, ht, list(mt.col_key))
def blockmatrix_type(self, bmir):
    token = secret_alnum_string()
    with TemporaryDirectory(ensure_exists=False) as dir:
        with self.fs.open(dir + '/in', 'wb') as infile:
            write_int(infile, ServiceBackend.BLOCK_MATRIX_TYPE)
            write_str(infile, tmp_dir())
            write_str(infile, self.render(bmir))

        batch_attributes = self.batch_attributes
        if 'name' not in batch_attributes:
            batch_attributes = {**batch_attributes, 'name': 'blockmatrix_type(...)'}
        bb = self.bc.create_batch(token=token, attributes=batch_attributes)

        j = bb.create_jvm_job([
            'is.hail.backend.service.ServiceBackendSocketAPI2',
            os.environ['HAIL_SHA'],
            os.environ['HAIL_JAR_URL'],
            batch_attributes['name'],
            dir + '/in',
            dir + '/out',
        ], mount_tokens=True)
        b = bb.submit(disable_progress_bar=self.disable_progress_bar)
        status = b.wait(disable_progress_bar=self.disable_progress_bar)
        if status['n_succeeded'] != 1:
            raise ValueError(f'batch failed {status} {j.log()}')

        with self.fs.open(dir + '/out', 'rb') as outfile:
            success = read_bool(outfile)
            if success:
                s = read_str(outfile)
                try:
                    return tblockmatrix._from_json(json.loads(s))
                except json.decoder.JSONDecodeError as err:
                    raise ValueError(f'could not decode {s}') from err
            else:
                jstacktrace = read_str(outfile)
                raise FatalError(jstacktrace)
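# Hedged summary (comments only, reconstructed from the reads/writes above) of the
# driver file protocol used by the jvm jobs in this file:
#   input  ('<dir>/in'):  write_int(<command code, e.g. ServiceBackend.BLOCK_MATRIX_TYPE>)
#                         followed by write_str(...) for each command-specific argument
#   output ('<dir>/out'): read_bool(success);
#                         if success, read_str(...) carries a JSON payload,
#                         otherwise read_str(...) carries the Java stack trace.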
async def _rpc(self,
               name: str,
               inputs: Callable[[afs.WritableStream, str], Awaitable[None]]):
    timings = Timings()
    token = secret_alnum_string()
    iodir = TemporaryDirectory(ensure_exists=False).name  # FIXME: actually cleanup
    with TemporaryDirectory(ensure_exists=False) as _:
        with timings.step("write input"):
            async with await self._async_fs.create(iodir + '/in') as infile:
                nonnull_flag_count = sum(v is not None for v in self.flags.values())
                await write_int(infile, nonnull_flag_count)
                for k, v in self.flags.items():
                    if v is not None:
                        await write_str(infile, k)
                        await write_str(infile, v)
                await inputs(infile, token)

        with timings.step("submit batch"):
            batch_attributes = self.batch_attributes
            if 'name' not in batch_attributes:
                batch_attributes = {**batch_attributes, 'name': name}
            bb = self.async_bc.create_batch(token=token, attributes=batch_attributes)

            j = bb.create_jvm_job([
                ServiceBackend.DRIVER,
                os.environ['HAIL_SHA'],
                os.environ['HAIL_JAR_URL'],
                batch_attributes['name'],
                iodir + '/in',
                iodir + '/out',
            ], mount_tokens=True, resources={'preemptible': False, 'memory': 'standard'})
            b = await bb.submit(disable_progress_bar=self.disable_progress_bar)

        with timings.step("wait batch"):
            try:
                status = await b.wait(disable_progress_bar=self.disable_progress_bar)
            except Exception:
                await b.cancel()
                raise

        with timings.step("parse status"):
            if status['n_succeeded'] != 1:
                job_status = await j.status()
                if 'status' in job_status:
                    if 'error' in job_status['status']:
                        job_status['status']['error'] = yaml_literally_shown_str(
                            job_status['status']['error'].strip())
                logs = await j.log()
                for k in logs:
                    logs[k] = yaml_literally_shown_str(logs[k].strip())
                message = {'service_backend_debug_info': self.debug_info(),
                           'batch_status': status,
                           'job_status': job_status,
                           'log': logs}
                log.error(yaml.dump(message))
                raise FatalError(message)

        with timings.step("read output"):
            async with await self._async_fs.open(iodir + '/out') as outfile:
                success = await read_bool(outfile)
                if success:
                    json_bytes = await read_bytes(outfile)
                    try:
                        return token, orjson.loads(json_bytes), timings
                    except orjson.JSONDecodeError as err:
                        raise FatalError(
                            f'batch id was {b.id}\ncould not decode {json_bytes}') from err
                else:
                    jstacktrace = await read_str(outfile)
                    maybe_id = ServiceBackend.HAIL_BATCH_FAILURE_EXCEPTION_MESSAGE_RE.match(
                        jstacktrace)
                    if maybe_id:
                        batch_id = maybe_id.groups()[0]
                        b2 = await self.async_bc.get_batch(batch_id)
                        b2_status = await b2.status()
                        assert b2_status['state'] != 'success'
                        failed_jobs = []
                        async for j in b2.jobs():
                            if j['state'] != 'Success':
                                logs, job = await asyncio.gather(
                                    self.async_bc.get_job_log(j['batch_id'], j['job_id']),
                                    self.async_bc.get_job(j['batch_id'], j['job_id']),
                                )
                                full_status = job._status
                                if 'status' in full_status:
                                    if 'error' in full_status['status']:
                                        full_status['status']['error'] = yaml_literally_shown_str(
                                            full_status['status']['error'].strip())
                                main_log = logs.get('main', '')
                                failed_jobs.append({
                                    'partial_status': j,
                                    'status': full_status,
                                    'log': yaml_literally_shown_str(main_log.strip()),
                                })
                        message = {
                            'id': b.id,
                            'service_backend_debug_info': self.debug_info(),
                            'stacktrace': yaml_literally_shown_str(jstacktrace.strip()),
                            'cause': {'id': batch_id,
                                      'batch_status': b2_status,
                                      'failed_jobs': failed_jobs}
                        }
                        log.error(yaml.dump(message))
                        raise FatalError(orjson.dumps(message).decode('utf-8'))
                    raise FatalError(f'batch id was {b.id}\n' + jstacktrace)