# Example #1
# 0
# File: rcr3.py  Project: dckc/grouse
class DateShiftFixPart(et.SqlScriptTask):
    """SQL script task running ``Script.date_shift_2015_part`` for one upload.

    The only script variable supplied is ``upload_id``, rendered as text.
    """
    script = Script.date_shift_2015_part
    cms_rif_schema = pv.StrParam(default='CMS_DEID_2015')
    upload_id = pv.IntParam()

    @property
    def variables(self) -> Environment:
        """Variables handed to the script (presumably bound/substituted by
        ``et.SqlScriptTask`` -- confirm there)."""
        upload_text = str(self.upload_id)
        return {'upload_id': upload_text}
class MigrateUpload(UploadRunTask):
    """Append one upload's workspace fact tables onto their destinations.

    Two ``CopyRecords`` requirements do the work:

    * ``{workspace_schema}.observation_fact_{upload_id}`` is appended to
      ``{schema}.observation_fact`` on ``db_url``;
    * ``{workspace_schema}.observation_fact_deid_{upload_id}`` is appended to
      ``{i2b2_deid}.observation_fact`` on ``db_url_deid``.

    TODO: explain why run() is trivial, i.e. why we
    don't get an upload_status record until the end. Or fix it.
    """
    upload_id = pv.IntParam()
    workspace_schema = pv.StrParam(default='HERON_ETL_1')
    i2b2_deid = pv.StrParam(default='BlueHeronData')
    db_url_deid = pv.StrParam()
    log_dest = pv.PathParam(significant=False)

    @property
    def label(self) -> str:
        """Upload label: this task's luigi task family."""
        return self.get_task_family()

    @property
    def source_cd(self) -> str:
        """Source code recorded for the upload: the workspace schema name."""
        return self.workspace_schema

    def requires(self) -> Dict[str, luigi.Task]:
        """Return the two append-copy tasks, or nothing once complete.

        Logging is configured (to ``log_dest``) only when there is work to
        schedule.
        """
        if self.complete():
            return {}

        _configure_logging(self.log_dest)

        def append_copy(src_table: str, dest_table: str,
                        db_url_dest: str) -> luigi.Task:
            # Both copies read from self.db_url with the same driver and
            # credentials; only the tables and destination URL differ.
            return CopyRecords(src_table=src_table,
                               dest_table=dest_table,
                               mode='append',
                               db_url=self.db_url,
                               db_url_dest=db_url_dest,
                               driver=self.driver,
                               user=self.user,
                               passkey=self.passkey)

        work = self.workspace_schema
        upload = self.upload_id
        return {
            'id': append_copy(f'{work}.observation_fact_{upload}',
                              f'{self.schema}.observation_fact',
                              self.db_url),
            'deid': append_copy(f'{work}.observation_fact_deid_{upload}',
                                f'{self.i2b2_deid}.observation_fact',
                                self.db_url_deid),
        }

    def _upload_target(self) -> UploadTarget:
        """Upload target in ``schema``, keyed by this task's id."""
        return UploadTarget(self, self.schema, transform_name=self.task_id)

    def run_upload(self, conn: Connection, upload_id: int) -> None:
        """No-op: the data movement is done by the ``CopyRecords`` tasks
        returned from ``requires()``."""
        return None
class NAACCR_Visits(_NAACCR_JDBC):
    """Make a per-tumor table for use in encounter_mapping etc.

    Rows come from ``td.TumorKeys.pat_tmr``. They are then given a tumor id
    (``with_tumor_id``) and a row number starting at ``encounter_num_start``
    (``with_rownum``).
    """
    design_id = pv.StrParam('patient_num')
    table_name = "NAACCR_TUMORS"
    encounter_num_start = pv.IntParam(description='see client.cfg')

    def _data(self, spark: SparkSession,
              naaccr_text_lines: DataFrame) -> DataFrame:
        keys = td.TumorKeys
        patient_tumors = keys.pat_tmr(spark, naaccr_text_lines)
        identified = keys.with_tumor_id(patient_tumors)
        return keys.with_rownum(identified, start=self.encounter_num_start)
class NAACCR_FlatFile(ManualTask):
    """A NAACCR flat file is determined by the registry, export date,
    and version.

    The task is complete when the file's first record matches those
    parameters and the file holds at least ``record_qty_min`` records.
    """
    naaccrRecordVersion = pv.IntParam(default=180)
    dateCaseReportExported = pv.DateParam()
    npiRegistryId = pv.StrParam()
    testData = pv.BoolParam(default=False, significant=False)
    flat_file = pv.PathParam(significant=False)
    record_qty_min = pv.IntParam(significant=False, default=1)

    def check_version_param(self) -> None:
        """Only version 18 (180) is currently supported.

        Raises NotImplementedError for any other ``naaccrRecordVersion``.
        """
        if self.naaccrRecordVersion != 180:
            raise NotImplementedError()

    def complete(self) -> bool:
        """Run ``complete_action`` inside a ``task_action`` context and
        record its result as a success field."""
        with task_action(self, 'complete') as ctx:
            result = self.complete_action()
            ctx.add_success_fields(result=result)
            return result

    def complete_action(self) -> bool:
        """Check the first record, assuming all the others have
        the same export date and registry NPI.

        Also requires at least ``record_qty_min`` records (lines) in the
        file. When ``testData`` is set, a failed check is logged and treated
        as success.

        Raises NotImplementedError for unsupported record versions.
        """
        self.check_version_param()

        with self.flat_file.open() as records:
            record0 = records.readline()
            # readline() returns '' only at EOF, so an empty file has zero
            # records. Count the rest lazily; readlines() would load the
            # whole (potentially very large) file into memory.
            qty = (1 if record0 else 0) + sum(1 for _ in records)
        qty_ok = qty >= self.record_qty_min
        log.info('record qty: %d (>= %d? %s)', qty, self.record_qty_min,
                 qty_ok)

        # Evaluate every item check (not short-circuited) so each mismatch
        # gets logged.
        vOk = self._checkItem(record0, 'naaccrRecordVersion',
                              str(self.naaccrRecordVersion))
        regOk = self._checkItem(record0, 'npiRegistryId', self.npiRegistryId)
        dtOk = self._checkItem(record0, 'dateCaseReportExported',
                               self.dateCaseReportExported.strftime('%Y%m%d'))

        if vOk and regOk and dtOk and qty_ok:
            return True
        if self.testData:
            log.warning('ignoring failed FlatFile check')
            return True
        return False

    @classmethod
    def _checkItem(cls, record: str, naaccrId: str, expected: str) -> bool:
        '''Compare fixed-width field ``naaccrId`` of ``record`` to ``expected``.

        The field's position comes from the NAACCR item definition, whose
        ``startColumn`` is 1-based. A mismatch is logged with the 0-based
        slice that was examined.

        >>> npi = '1234567890'
        >>> record0 = ' ' * 19 + npi
        >>> NAACCR_FlatFile._checkItem(record0, 'npiRegistryId', npi)
        True
        >>> NAACCR_FlatFile._checkItem(record0, 'npiRegistryId', 'XXX')
        False
        '''
        itemDef = tr_ont.NAACCR1.itemDef(naaccrId)
        startColumn, length = (int(itemDef.attrib[it])
                               for it in ['startColumn', 'length'])
        start = startColumn - 1  # 1-based column -> 0-based slice index
        end = start + length
        actual = record[start:end]
        ok = actual == expected
        if not ok:
            log.warning('%s: expected %s [%s:%s] = {%s} but found {%s}',
                        cls.__name__, naaccrId, start, end, expected, actual)
        return ok
class NAACCR_Ontology1(JDBCTask):
    """Compute NAACCR ontology terms and write them to ``table_name``.

    Terms come from ``tr_ont.NAACCR_I2B2.ont_view_in``. The task is complete
    when ``table_name`` has a row for the NAACCR top folder whose
    ``c_basecode`` equals this task's ``version_name``.
    """
    table_name = pv.StrParam(default="NAACCR_ONTOLOGY")
    who_cache = pv.PathParam()
    z_design_id = pv.StrParam(
        default='2019-12-16 pystdlib %s' %
        _stable_hash(tr_ont.NAACCR_I2B2.ont_script.code),
        description='''
        mnemonic for latest visible change to output.
        Changing this causes task_id to change, which
        ensures the ontology gets rebuilt if necessary.
        '''.strip(),
    )
    # ISSUE: only used as a label in version_name; run() does not pass it
    # to the ontology builder.
    naaccr_version = pv.IntParam(default=18)
    jdbc_driver_jar = pv.StrParam(significant=False)

    # based on custom_meta
    col_to_type = dict(
        c_hlevel='int',
        c_fullname='varchar(700)',
        c_name='varchar(2000)',
        c_visualattributes='varchar(3)',
        c_totalnum='int',
        c_basecode='varchar(50)',
        c_dimcode='varchar(700)',
        c_tooltip='varchar(900)',
        update_date='date',
        sourcesystem_cd='varchar(50)',
    )
    # "name type,name type,..." column spec built from col_to_type.
    coltypes = ','.join(f'{name} {ty}' for (name, ty) in col_to_type.items())

    @property
    def task_hash(self) -> str:
        """Text after the last ``_`` in ``task_id`` (luigi's param hash)."""
        return self.task_id.split('_')[-1]  # hmm... luigi doesn't export this

    @property
    def version_name(self) -> str:
        """version info that fits in an i2b2 name (50 characters)
        """
        return f'v{self.naaccr_version}-{self.task_hash}'

    def output(self) -> JDBCTableTarget:
        """Target that exists once the versioned top-folder row is present."""
        query = fr"""
          (select 1 from {self.table_name}
           where c_fullname = '{tr_ont.NAACCR_I2B2.top_folder}'
           and c_basecode = '{self.version_name}')
        """
        return JDBCTableTarget(self, query)

    @property
    def classpath(self) -> str:
        """JDBC driver jar to put on the classpath."""
        return self.jdbc_driver_jar

    @property
    def __password(self) -> str:
        """Password read from the environment variable named by ``passkey``."""
        from os import environ  # ISSUE: ambient
        return environ[self.passkey]

    def account(self) -> td.Account:
        """Build the 'DEID' account used to write the ontology table."""
        from subprocess import Popen  # ISSUE: AMBIENT
        return td.Account('DEID',
                          self.user,
                          self.__password,
                          Popen,
                          url=self.db_url,
                          driver=self.driver)

    def run(self) -> None:
        """Compute the ontology terms and write them to ``table_name``.

        ``update_date`` is parsed from the leading ``YYYY-MM-DD`` of
        ``z_design_id``, so a ValueError results if that prefix is not a
        date.
        """
        conn = connect_mem(':memory:', detect_types=PARSE_COLNAMES)
        spark = DBSession(conn)
        update_date = dt.datetime.strptime(self.z_design_id[:10],
                                           '%Y-%m-%d').date()
        terms = tr_ont.NAACCR_I2B2.ont_view_in(spark,
                                               who_cache=self.who_cache,
                                               task_hash=self.task_hash,
                                               update_date=update_date)
        cdw = self.account()
        cdw.wr(self.table_name, td.case_fold(terms))
# Example #6
# 0
# File: rcr3.py  Project: dckc/grouse
class CohortRIFTablePart(et.DBAccessTask, et.I2B2Task):
    """Build subset of one CMS RIF table where bene_id is from site cohorts.

    Rows of ``{cms_rif}.{table_name}`` whose ``bene_id`` is a ``patient_num``
    in any of the site cohorts' patient sets are appended to
    ``{work_schema}.{table_name}``. The destination table is created (empty,
    with the source's columns) first if possible.
    """
    cms_rif = pv.StrParam(description="Source RIF schema")
    work_schema = pv.StrParam(description="Destination RIF schema")
    table_name = pv.StrParam()
    site_star_list = ListParam(description='DATA_KUMC,DATA_MCW,...')
    parallel_degree = pv.IntParam(significant=False, default=16)

    @property
    def site_cohorts(self) -> SiteCohorts:
        """The SiteCohorts task for ``site_star_list``."""
        return SiteCohorts(site_star_list=self.site_star_list)

    def requires(self) -> List[luigi.Task]:
        return [self.site_cohorts]

    def complete(self) -> bool:
        """True when every site cohort has at least one row in the work table.

        For each cohort, the work table is probed for any ``bene_id`` between
        that cohort's min and max ``patient_num``. A failing probe query
        (e.g. the table does not exist yet) is logged and counts as
        incomplete.
        """
        for t in self.requires():
            if not t.complete():
                return False

        with self.connection('rif_table_done') as conn:
            for cohort_id in self.site_cohorts._cohort_ids(
                    self.project.star_schema, conn):
                # We're not guaranteed that each site cohort intersects the CMS data,
                # but if it does, the CMS patient numbers are the low ones; hence min().
                lo, hi = conn.execute('''
                    select min(patient_num), max(patient_num)  from {i2b2}.qt_patient_set_collection
                    where result_instance_id = {id}
                    '''.format(i2b2=self.project.star_schema,
                               id=cohort_id)).first()
                try:
                    found = conn.execute('''
                        select 1 from {work}.{t} where bene_id between :lo and :hi and rownum = 1
                    '''.format(work=self.work_schema, t=self.table_name),
                                         params=dict(lo=lo, hi=hi)).scalar()
                except DatabaseError as oops:
                    conn.log.warning('complete query failed:', exc_info=oops)
                    return False
                if not found:
                    return False
        return True

    def run(self) -> None:
        """Create the work table if needed, then append the cohort subset."""
        with self.connection('rif_table') as conn:
            self._create(conn)

            cohort_ids = self.site_cohorts._cohort_ids(
                self.project.star_schema, conn)
            # NOTE(review): an empty cohort_ids list yields "in ()", which is
            # invalid SQL; complete() returns True in that case, so run() is
            # presumably not reached -- confirm.
            conn.execute('''
            insert /*+ append */ into {work}.{t}
            select /*+ parallel({degree}) */ * from {rif}.{t}
            where bene_id in (
              select patient_num from {i2b2}.qt_patient_set_collection
              where result_instance_id in ({cohorts})
            )
            '''.format(work=self.work_schema,
                       t=self.table_name,
                       degree=self.parallel_degree,
                       rif=self.cms_rif,
                       i2b2=self.project.star_schema,
                       cohorts=', '.join(str(i) for i in cohort_ids)))

    def _create(self, conn: et.LoggedConnection) -> None:
        """Best-effort: create an empty copy of the source table's structure.

        Failure is not fatal (most often the table already exists). It is
        logged rather than silently discarded, so other causes such as
        permissions are visible.
        """
        try:
            conn.execute('''
            create table {work}.{t} as select * from {rif}.{t} where 1 = 0
            '''.format(work=self.work_schema,
                       t=self.table_name,
                       rif=self.cms_rif))
        except DatabaseError as oops:
            # perhaps it already exists...
            conn.log.warning('create table failed (perhaps it already exists):',
                             exc_info=oops)