Example #1
 def handle_response(self, resp):
     if resp.type == aiohttp.WSMsgType.CLOSE:
         raise aiohttp.ServerDisconnectedError(f'Socket was closed by server. (code={resp.data})')
     if resp.type == aiohttp.WSMsgType.ERROR:
         raise FatalError(f'Error raised while waiting for response from server: {resp}.')
     assert resp.type == aiohttp.WSMsgType.TEXT, resp.type
     return resp.data
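
Example #1 normalizes a single websocket frame: a CLOSE frame surfaces as aiohttp.ServerDisconnectedError (carrying the close code), an ERROR frame as FatalError, and only TEXT frames return their payload. A minimal, hypothetical caller might look like the sketch below; the handler object and the open client websocket are assumptions, not part of the original snippet.

import aiohttp

async def receive_text(handler, ws: aiohttp.ClientWebSocketResponse) -> str:
    # One receive per call: handle_response rejects CLOSE/ERROR frames
    # and returns the payload of a TEXT frame.
    msg = await ws.receive()
    return handler.handle_response(msg)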
Example #2
def hwe_normalize(call_expr):
    mt = matrix_table_source('hwe_normalize/call_expr', call_expr)
    mt = mt.select_entries(__gt=call_expr.n_alt_alleles())
    mt = mt.annotate_rows(__AC=agg.sum(mt.__gt),
                          __n_called=agg.count_where(hl.is_defined(mt.__gt)))
    mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called))

    n_variants = mt.count_rows()
    if n_variants == 0:
        raise FatalError(
            "hwe_normalize: found 0 variants after filtering out monomorphic sites."
        )
    info(
        f"hwe_normalize: found {n_variants} variants after filtering out monomorphic sites."
    )

    mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called)
    mt = mt.annotate_rows(__hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt *
                                                       (2 - mt.__mean_gt) *
                                                       n_variants / 2))
    mt = mt.unfilter_entries()

    normalized_gt = hl.or_else(
        (mt.__gt - mt.__mean_gt) / mt.__hwe_scaled_std_dev, 0.0)
    return normalized_gt
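
The scaling in Example #2 is the standard one for PCA and relatedness estimation: with __mean_gt = AC / n_called = 2p, a genotype has variance 2p(1 - p) under Hardy-Weinberg equilibrium, and the extra factor of n_variants gives each normalized entry variance 1/n_variants. A dense NumPy sketch of the same arithmetic (the function and variable names are mine, not Hail's, and the sketch skips the zero-variant error the original raises):

import numpy as np

def hwe_normalize_dense(G: np.ndarray) -> np.ndarray:
    # G: variants-by-samples matrix of alt-allele counts in {0, 1, 2},
    # with np.nan marking missing calls.
    ac = np.nansum(G, axis=1)                       # __AC per variant
    n_called = np.sum(~np.isnan(G), axis=1)         # __n_called per variant
    keep = (ac > 0) & (ac < 2 * n_called)           # drop monomorphic sites
    G, ac, n_called = G[keep], ac[keep], n_called[keep]

    m = G.shape[0]                                  # n_variants after filtering
    mean_gt = ac / n_called                         # 2 * alt-allele frequency
    std = np.sqrt(mean_gt * (2 - mean_gt) * m / 2)  # __hwe_scaled_std_dev
    X = (G - mean_gt[:, None]) / std[:, None]
    return np.nan_to_num(X)                         # missing entries -> 0.0

With this normalization, X.T @ X should estimate the genetic relatedness matrix directly, with no further division by the variant count.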
Example #3
 def add_reference(self, config):
     resp = retry_response_returning_functions(
         requests.post,
         f'{self.url}/references/create', json=config, headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
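
Examples #3, #4, #6, #7, #9, and #12-#14 all repeat the shape above: send JSON through retry_response_returning_functions, turn a 400 or 500 body into a FatalError carrying the server's own message, and let raise_for_status cover every other failure (Examples #5 and #10 deviate slightly, raising resp.text instead of the parsed message). If one were refactoring, the shared tail could be hoisted into a helper like this hypothetical sketch, which assumes the same FatalError the originals import:

def _check_hail_response(resp):
    # Hypothetical helper, not part of the original API: surface the
    # server-reported message on 400/500, defer the rest to requests.
    if resp.status_code in (400, 500):
        raise FatalError(resp.json()['message'])
    resp.raise_for_status()
    return resp

Each endpoint method would then shrink to the request call plus _check_hail_response(resp).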
Example #4
 def add_sequence(self, name, fasta_file, index_file):
     resp = retry_response_returning_functions(
         requests.post,
         f'{self.url}/references/sequence/set',
         json={'name': name, 'fasta_file': fasta_file, 'index_file': index_file},
         headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
Example #5
    def _request_type(self, ir, kind):
        code = self._render(ir)
        resp = retry_response_returning_functions(
            requests.post,
            f'{self.url}/type/{kind}', json=code, headers=self.headers)
        if resp.status_code == 400 or resp.status_code == 500:
            raise FatalError(resp.text)
        resp.raise_for_status()

        return resp.json()
Example #6
 def remove_sequence(self, name):
     resp = retry_response_returning_functions(
         requests.delete,
         f'{self.url}/references/sequence/delete',
         json={'name': name},
         headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
Example #7
 def parse_vcf_metadata(self, path):
     resp = retry_response_returning_functions(
         requests.post,
         f'{self.url}/parse-vcf-metadata',
         json={'path': path},
         headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
     return resp.json()
Example #8
    async def _async_execute_untimed(self, ir):
        token = secret_alnum_string()
        with TemporaryDirectory(ensure_exists=False) as dir:
            async def create_inputs():
                with self.fs.open(dir + '/in', 'wb') as infile:
                    write_int(infile, ServiceBackend.EXECUTE)
                    write_str(infile, tmp_dir())
                    write_str(infile, self.billing_project)
                    write_str(infile, self.bucket)
                    write_str(infile, self.render(ir))
                    write_str(infile, token)

            async def create_batch():
                batch_attributes = self.batch_attributes
                if 'name' not in batch_attributes:
                    batch_attributes = {**batch_attributes, 'name': 'execute(...)'}
                bb = self.async_bc.create_batch(token=token, attributes=batch_attributes)

                j = bb.create_jvm_job([
                    'is.hail.backend.service.ServiceBackendSocketAPI2',
                    os.environ['HAIL_SHA'],
                    os.environ['HAIL_JAR_URL'],
                    batch_attributes['name'],
                    dir + '/in',
                    dir + '/out',
                ], mount_tokens=True)
                return (j, await bb.submit(disable_progress_bar=self.disable_progress_bar))

            _, (j, b) = await asyncio.gather(create_inputs(), create_batch())

            status = await b.wait(disable_progress_bar=self.disable_progress_bar)
            if status['n_succeeded'] != 1:
                raise ValueError(f'batch failed {status} {await j.log()}')

            with self.fs.open(dir + '/out', 'rb') as outfile:
                success = read_bool(outfile)
                if success:
                    s = read_str(outfile)
                    try:
                        resp = json.loads(s)
                    except json.decoder.JSONDecodeError as err:
                        raise ValueError(f'could not decode {s}') from err
                else:
                    jstacktrace = read_str(outfile)
                    raise FatalError(jstacktrace)

        typ = dtype(resp['type'])
        if typ == tvoid:
            x = None
        else:
            x = typ._convert_from_json_na(resp['value'])

        return x
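
One detail worth noting in Example #8: writing the driver's input file and creating the batch are independent, so they are overlapped with asyncio.gather, whose results come back in call order; that is why the unpacking reads `_, (j, b)`. A self-contained sketch of the pattern (the coroutine names are illustrative only):

import asyncio

async def stage_inputs() -> None:
    # Stand-in for create_inputs(): write the request somewhere shared.
    await asyncio.sleep(0.1)

async def launch_job() -> str:
    # Stand-in for create_batch(): returns a handle to wait on.
    await asyncio.sleep(0.1)
    return 'batch-handle'

async def main() -> None:
    # Both coroutines run concurrently; gather preserves argument order.
    _, handle = await asyncio.gather(stage_inputs(), launch_job())
    print(handle)

asyncio.run(main())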
Example #9
 def add_liftover(self, name, chain_file, dest_reference_genome):
     resp = retry_response_returning_functions(
         requests.post,
         f'{self.url}/references/liftover/add',
         json={'name': name, 'chain_file': chain_file,
               'dest_reference_genome': dest_reference_genome},
         headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
Example #10
    def execute(self, ir, timed=False):
        code = self._render(ir)
        resp = retry_response_returning_functions(requests.post,
                                                  f'{self.url}/execute',
                                                  json=code,
                                                  headers=self.headers)
        if resp.status_code == 400 or resp.status_code == 500:
            raise FatalError(resp.text)
        resp.raise_for_status()
        resp_json = resp.json()
        typ = dtype(resp_json['type'])
        value = typ._convert_from_json_na(resp_json['value'])
        # FIXME put back timings

        return (value, None) if timed else value
Example #11
    def unphased_diploid_gt_index(self):
        """Return the genotype index for unphased, diploid calls.

        Returns
        -------
        :obj:`int`
        """

        if self.ploidy != 2 or self.phased:
            raise FatalError(
                "'unphased_diploid_gt_index' is only valid for unphased, diploid calls. Found {}."
                .format(repr(self)))
        a0 = self._alleles[0]
        a1 = self._alleles[1]
        assert a0 <= a1
        return a1 * (a1 + 1) // 2 + a0
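
The return expression is the triangular-number encoding of an unordered diploid genotype: for sorted alleles a0 <= a1, index = a1 * (a1 + 1) / 2 + a0, which reproduces the VCF GT ordering 0/0, 0/1, 1/1, 0/2, 1/2, 2/2 -> 0..5. A standalone check (floor division used deliberately, matching the int return type promised by the docstring):

def unphased_diploid_gt_index(a0: int, a1: int) -> int:
    # Triangular-number encoding of an unordered allele pair, a0 <= a1.
    assert 0 <= a0 <= a1
    return a1 * (a1 + 1) // 2 + a0

assert [unphased_diploid_gt_index(a0, a1)
        for a1 in range(3) for a0 in range(a1 + 1)] == [0, 1, 2, 3, 4, 5]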
Example #12
 def import_fam(self, path: str, quant_pheno: bool, delimiter: str, missing: str):
     resp = retry_response_returning_functions(
         requests.post,
         f'{self.url}/import-fam',
         json={
             'path': path,
             'quant_pheno': quant_pheno,
             'delimiter': delimiter,
             'missing': missing
         },
         headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
     return resp.json()
Example #13
 def index_bgen(self, files, index_file_map, rg, contig_recoding, skip_invalid_loci):
     resp = retry_response_returning_functions(
         requests.post,
         f'{self.url}/index-bgen',
         json={
             'files': files,
             'index_file_map': index_file_map,
             'rg': rg,
             'contig_recoding': contig_recoding,
             'skip_invalid_loci': skip_invalid_loci
         },
         headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
     return resp.json()
Example #14
 def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par):
     resp = retry_response_returning_functions(
         requests.post,
         f'{self.url}/references/create/fasta',
         json={
             'name': name,
             'fasta_file': fasta_file,
             'index_file': index_file,
             'x_contigs': x_contigs,
             'y_contigs': y_contigs,
             'mt_contigs': mt_contigs,
             'par': par
         }, headers=self.headers)
     if resp.status_code == 400 or resp.status_code == 500:
         resp_json = resp.json()
         raise FatalError(resp_json['message'])
     resp.raise_for_status()
Example #15
 async def async_request(self, endpoint, **data):
     data['token'] = secret_alnum_string()
     session = await self.session()
     async with session.ws_connect(f'{self.url}/api/v1alpha/{endpoint}') as socket:
         await socket.send_str(json.dumps(data))
         response = await socket.receive()
         await socket.send_str('bye')
         if response.type == aiohttp.WSMsgType.ERROR:
             raise ValueError(f'bad response: {endpoint}; {data}; {response}')
         if response.type in (aiohttp.WSMsgType.CLOSE,
                              aiohttp.WSMsgType.CLOSED):
             warnings.warn(f'retrying after losing connection {endpoint}; {data}; {response}')
             raise TransientError()
         assert response.type == aiohttp.WSMsgType.TEXT
         result = json.loads(response.data)
         if result['status'] != 200:
             raise FatalError(f'Error from server: {result["value"]}')
         return result['value']
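
Example #15's client expects a one-shot envelope of the form {'status': <int>, 'value': <payload>} over the websocket, followed by a 'bye' from the client. A hypothetical aiohttp server handler satisfying that contract, purely for illustration (the echo body and route are assumptions, not the actual service):

from aiohttp import web

async def endpoint_handler(request: web.Request) -> web.WebSocketResponse:
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    data = await ws.receive_json()   # the client's single request
    # Echo the request back; a real handler would dispatch on the endpoint.
    await ws.send_json({'status': 200, 'value': data})
    await ws.receive()               # drain the client's 'bye'
    await ws.close()
    return ws

app = web.Application()
app.add_routes([web.get('/api/v1alpha/echo', endpoint_handler)])
# web.run_app(app, port=8080)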
Example #16
    def load_references_from_dataset(self, path):
        token = secret_alnum_string()
        with TemporaryDirectory(ensure_exists=False) as dir:
            with self.fs.open(dir + '/in', 'wb') as infile:
                write_int(infile, ServiceBackend.LOAD_REFERENCES_FROM_DATASET)
                write_str(infile, tmp_dir())
                write_str(infile, self.billing_project)
                write_str(infile, self.bucket)
                write_str(infile, path)

            batch_attributes = self.batch_attributes
            if 'name' not in batch_attributes:
                batch_attributes = {**batch_attributes, 'name': 'load_references_from_dataset(...)'}
            bb = self.bc.create_batch(token=token, attributes=batch_attributes)

            j = bb.create_jvm_job([
                'is.hail.backend.service.ServiceBackendSocketAPI2',
                os.environ['HAIL_SHA'],
                os.environ['HAIL_JAR_URL'],
                batch_attributes['name'],
                dir + '/in',
                dir + '/out',
            ], mount_tokens=True)
            b = bb.submit(disable_progress_bar=self.disable_progress_bar)
            status = b.wait(disable_progress_bar=self.disable_progress_bar)
            if status['n_succeeded'] != 1:
                raise ValueError(f'batch failed {status} {j.log()}')

            with self.fs.open(dir + '/out', 'rb') as outfile:
                success = read_bool(outfile)
                if success:
                    s = read_str(outfile)
                    try:
                        # FIXME: do we not have to parse the result?
                        return json.loads(s)
                    except json.decoder.JSONDecodeError as err:
                        raise ValueError(f'could not decode {s}') from err
                else:
                    jstacktrace = read_str(outfile)
                    raise FatalError(jstacktrace)
Example #17
def _make_tsm_from_call(call_expr,
                        block_size,
                        mean_center=False,
                        hwe_normalize=False):
    mt = matrix_table_source('_make_tsm/entry_expr', call_expr)
    mt = mt.select_entries(__gt=call_expr.n_alt_alleles())
    if mean_center or hwe_normalize:
        mt = mt.annotate_rows(__AC=agg.sum(mt.__gt),
                              __n_called=agg.count_where(hl.is_defined(
                                  mt.__gt)))
        mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called))

        n_variants = mt.count_rows()
        if n_variants == 0:
            raise FatalError(
                "_make_tsm: found 0 variants after filtering out monomorphic sites."
            )
        info(
            f"_make_tsm: found {n_variants} variants after filtering out monomorphic sites."
        )

        mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called)
        mt = mt.unfilter_entries()

        mt = mt.select_entries(__x=hl.or_else(mt.__gt - mt.__mean_gt, 0.0))

        if hwe_normalize:
            mt = mt.annotate_rows(
                __hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt *
                                             (2 - mt.__mean_gt) * n_variants /
                                             2))
            mt = mt.select_entries(__x=mt.__x / mt.__hwe_scaled_std_dev)
    else:
        mt = mt.select_entries(__x=mt.__gt)

    A, ht = mt_to_table_of_ndarray(mt.__x,
                                   block_size,
                                   return_checkpointed_table_also=True)
    A = A.persist()
    return TallSkinnyMatrix(A, A.ndarray, ht, list(mt.col_key))
Example #18
    def blockmatrix_type(self, bmir):
        token = secret_alnum_string()
        with TemporaryDirectory(ensure_exists=False) as dir:
            with self.fs.open(dir + '/in', 'wb') as infile:
                write_int(infile, ServiceBackend.BLOCK_MATRIX_TYPE)
                write_str(infile, tmp_dir())
                write_str(infile, self.render(bmir))

            batch_attributes = self.batch_attributes
            if 'name' not in batch_attributes:
                batch_attributes = {**batch_attributes, 'name': 'blockmatrix_type(...)'}
            bb = self.bc.create_batch(token=token, attributes=batch_attributes)

            j = bb.create_jvm_job([
                'is.hail.backend.service.ServiceBackendSocketAPI2',
                os.environ['HAIL_SHA'],
                os.environ['HAIL_JAR_URL'],
                batch_attributes['name'],
                dir + '/in',
                dir + '/out',
            ], mount_tokens=True)
            b = bb.submit(disable_progress_bar=self.disable_progress_bar)
            status = b.wait(disable_progress_bar=self.disable_progress_bar)
            if status['n_succeeded'] != 1:
                raise ValueError(f'batch failed {status} {j.log()}')

            with self.fs.open(dir + '/out', 'rb') as outfile:
                success = read_bool(outfile)
                if success:
                    s = read_str(outfile)
                    try:
                        return tblockmatrix._from_json(json.loads(s))
                    except json.decoder.JSONDecodeError as err:
                        raise ValueError(f'could not decode {s}') from err
                else:
                    jstacktrace = read_str(outfile)
                    raise FatalError(jstacktrace)
Example #19
    async def _rpc(self, name: str, inputs: Callable[[afs.WritableStream, str], Awaitable[None]]):
        timings = Timings()
        token = secret_alnum_string()
        iodir = TemporaryDirectory(ensure_exists=False).name  # FIXME: actually cleanup
        with TemporaryDirectory(ensure_exists=False) as _:
            with timings.step("write input"):
                async with await self._async_fs.create(iodir + '/in') as infile:
                    nonnull_flag_count = sum(v is not None for v in self.flags.values())
                    await write_int(infile, nonnull_flag_count)
                    for k, v in self.flags.items():
                        if v is not None:
                            await write_str(infile, k)
                            await write_str(infile, v)
                    await inputs(infile, token)

            with timings.step("submit batch"):
                batch_attributes = self.batch_attributes
                if 'name' not in batch_attributes:
                    batch_attributes = {**batch_attributes, 'name': name}
                bb = self.async_bc.create_batch(token=token,
                                                attributes=batch_attributes)

                j = bb.create_jvm_job([
                    ServiceBackend.DRIVER,
                    os.environ['HAIL_SHA'],
                    os.environ['HAIL_JAR_URL'],
                    batch_attributes['name'],
                    iodir + '/in',
                    iodir + '/out',
                ], mount_tokens=True, resources={'preemptible': False, 'memory': 'standard'})
                b = await bb.submit(
                    disable_progress_bar=self.disable_progress_bar)

            with timings.step("wait batch"):
                try:
                    status = await b.wait(
                        disable_progress_bar=self.disable_progress_bar)
                except Exception:
                    await b.cancel()
                    raise

            with timings.step("parse status"):
                if status['n_succeeded'] != 1:
                    job_status = await j.status()
                    if 'status' in job_status and 'error' in job_status['status']:
                        job_status['status']['error'] = yaml_literally_shown_str(
                            job_status['status']['error'].strip())
                    logs = await j.log()
                    for k in logs:
                        logs[k] = yaml_literally_shown_str(logs[k].strip())
                    message = {
                        'service_backend_debug_info': self.debug_info(),
                        'batch_status': status,
                        'job_status': job_status,
                        'log': logs
                    }
                    log.error(yaml.dump(message))
                    raise FatalError(message)

            with timings.step("read output"):
                async with await self._async_fs.open(iodir + '/out') as outfile:
                    success = await read_bool(outfile)
                    if success:
                        json_bytes = await read_bytes(outfile)
                        try:
                            return token, orjson.loads(json_bytes), timings
                        except orjson.JSONDecodeError as err:
                            raise FatalError(
                                f'batch id was {b.id}\ncould not decode {json_bytes}'
                            ) from err
                    else:
                        jstacktrace = await read_str(outfile)
                        maybe_id = ServiceBackend.HAIL_BATCH_FAILURE_EXCEPTION_MESSAGE_RE.match(
                            jstacktrace)
                        if maybe_id:
                            batch_id = maybe_id.groups()[0]
                            b2 = await self.async_bc.get_batch(batch_id)
                            b2_status = await b2.status()
                            assert b2_status['state'] != 'success'
                            failed_jobs = []
                            async for j in b2.jobs():
                                if j['state'] != 'Success':
                                    logs, job = await asyncio.gather(
                                        self.async_bc.get_job_log(
                                            j['batch_id'], j['job_id']),
                                        self.async_bc.get_job(
                                            j['batch_id'], j['job_id']),
                                    )
                                    full_status = job._status
                                    if 'status' in full_status and 'error' in full_status['status']:
                                        full_status['status']['error'] = yaml_literally_shown_str(
                                            full_status['status']['error'].strip())
                                    main_log = logs.get('main', '')
                                    failed_jobs.append({
                                        'partial_status': j,
                                        'status': full_status,
                                        'log': yaml_literally_shown_str(main_log.strip()),
                                    })
                            message = {
                                'id': b.id,
                                'service_backend_debug_info': self.debug_info(),
                                'stacktrace': yaml_literally_shown_str(jstacktrace.strip()),
                                'cause': {
                                    'id': batch_id,
                                    'batch_status': b2_status,
                                    'failed_jobs': failed_jobs
                                }
                            }
                            log.error(yaml.dump(message))
                            raise FatalError(orjson.dumps(message).decode('utf-8'))
                        raise FatalError(f'batch id was {b.id}\n' + jstacktrace)
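
A note on yaml_literally_shown_str, used throughout Example #19: it is presumably a str subclass registered with a PyYAML representer so multi-line stack traces and logs render as literal blocks (|) rather than quoted scalars. The standard recipe looks like this (a sketch of the technique, not Hail's actual code):

import yaml

class yaml_literally_shown_str(str):
    # Marker type: dump this string in YAML literal block style.
    pass

def _literal_representer(dumper, s):
    return dumper.represent_scalar('tag:yaml.org,2002:str', s, style='|')

yaml.add_representer(yaml_literally_shown_str, _literal_representer)

print(yaml.dump({'log': yaml_literally_shown_str('line 1\nline 2\n')}))
# log: |
#   line 1
#   line 2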