class NAACCR_Patients(_NAACCR_JDBC):
    """Make a per-patient table for use in patient_mapping etc.
    """
    patient_ide_source = pv.StrParam(default='*****@*****.**')
    schema = pv.StrParam(default='NIGHTHERONDATA')
    z_design_id = pv.StrParam('keep unmapped patients')
    table_name = "NAACCR_PATIENTS"

    def requires(self) -> Dict[str, luigi.Task]:
        return dict(
            _NAACCR_JDBC.requires(self),
            HERON_Patient_Mapping=HERON_Patient_Mapping(
                patient_ide_source=self.patient_ide_source,
                schema=self.schema,
                db_url=self.db_url,
                jdbc_driver_jar=self.classpath,  # KLUDGE
                driver=self.driver,
                user=self.user,
                passkey=self.passkey))

    def _data(self, spark: SparkSession,
              naaccr_text_lines: DataFrame) -> DataFrame:
        patients = td.TumorKeys.patients(spark, naaccr_text_lines)
        cdw = self.account()
        patients = td.TumorKeys.with_patient_num(patients, spark, cdw,
                                                 self.schema,
                                                 self.patient_ide_source)
        return patients
class CopyRecords(SparkJDBCTask):
    src_table = pv.StrParam()
    dest_table = pv.StrParam()
    db_url_dest = pv.StrParam()
    mode = pv.StrParam(default='error')

    # IDEA (performance): we can go parallel if we set
    # partitionColumn, lowerBound, upperBound, numPartitions
    # ref. https://stackoverflow.com/a/35062411
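    # A hedged sketch of that idea (not in the original source): PySpark's
    # DataFrameReader.jdbc() accepts column/lowerBound/upperBound/numPartitions
    # for a partitioned read. The partition column and bounds below are
    # hypothetical placeholders:
    #
    #   src = spark.read.jdbc(
    #       url=self.db_url, table=self.src_table,
    #       column='UPLOAD_ID', lowerBound=1, upperBound=1000000,
    #       numPartitions=8,
    #       properties={'user': self.user, 'driver': self.driver})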

    def output(self) -> luigi.LocalTarget:
        """Since making a DB connection is awkward and
        scanning the destination table may be expensive,
        let's cache status in the current directory.
        """
        # ISSUE: assumes linux paths
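        # A portable alternative (an assumption, not in the original) would
        # build the path with pathlib:
        #   from pathlib import PurePath
        #   return luigi.LocalTarget(
        #       path=str(PurePath('task_status') / self.task_id))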
        return luigi.LocalTarget(path=f"task_status/{self.task_id}")

    def main_action(self, sc: SparkContext_T) -> None:
        spark = SparkSession(sc)
        with el.start_action(action_type=self.get_task_family(),
                             src=self.src_table,
                             dest=self.dest_table):
            src = self.account().rd(spark.read, self.src_table)
            with self.output().open(mode='w') as status:
                self.account(self.db_url_dest).wr(src.write,
                                                  self.dest_table,
                                                  mode=self.mode)
                status.write(self.task_id)
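

# A hedged usage sketch (not part of the original source): scheduling
# CopyRecords with luigi's in-process scheduler. All table names and
# connection values below are hypothetical placeholders; `passkey` names an
# environment variable that holds the database password.
def _copy_records_example() -> None:
    import luigi
    luigi.build(
        [CopyRecords(src_table='WORKSPACE.OBSERVATION_FACT_STAGE',
                     dest_table='NIGHTHERONDATA.OBSERVATION_FACT',
                     mode='append',
                     db_url='jdbc:oracle:thin:@src-host:1521:SRC',
                     db_url_dest='jdbc:oracle:thin:@dest-host:1521:DEID',
                     user='ETL_USER',
                     passkey='DB_PASSWORD')],
        local_scheduler=True)
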
class MigrateUpload(UploadRunTask):
    """
    TODO: explain why run() is trivial, i.e. why we
    don't get an upload_status record until the end. Or fix it.
    """
    upload_id = pv.IntParam()
    workspace_schema = pv.StrParam(default='HERON_ETL_1')
    i2b2_deid = pv.StrParam(default='BlueHeronData')
    db_url_deid = pv.StrParam()
    log_dest = pv.PathParam(significant=False)

    @property
    def label(self) -> str:
        return self.get_task_family()

    @property
    def source_cd(self) -> str:
        return self.workspace_schema

    def requires(self) -> Dict[str, luigi.Task]:
        if self.complete():
            return {}

        _configure_logging(self.log_dest)
        return {
            'id':
            CopyRecords(
                src_table=
                f'{self.workspace_schema}.observation_fact_{self.upload_id}',
                dest_table=f'{self.schema}.observation_fact',
                mode='append',
                db_url=self.db_url,
                db_url_dest=self.db_url,
                driver=self.driver,
                user=self.user,
                passkey=self.passkey),
            'deid':
            CopyRecords(
                src_table=
                f'{self.workspace_schema}.observation_fact_deid_{self.upload_id}',
                dest_table=f'{self.i2b2_deid}.observation_fact',
                mode='append',
                db_url=self.db_url,
                db_url_dest=self.db_url_deid,
                driver=self.driver,
                user=self.user,
                passkey=self.passkey),
        }

    def _upload_target(self) -> UploadTarget:
        return UploadTarget(self, self.schema, transform_name=self.task_id)

    def run_upload(self, conn: Connection, upload_id: int) -> None:
        pass
class NAACCR_Ontology2(_RunScriptTask):
    '''NAACCR Ontology: un-published code values

    i.e. code values that occur in tumor registry data but not in published ontologies
    '''
    # flat file attributes
    dateCaseReportExported = pv.DateParam()
    npiRegistryId = pv.StrParam()
    source_cd = pv.StrParam(default='*****@*****.**')

    z_design_id = pv.StrParam(default='length 50 (%s)' % _stable_hash(
        tr_ont.NAACCR_I2B2.ont_script.code,
        td.DataSummary.script.code))

    script_name = 'naaccr_concepts_mix.sql'
    script = res.read_text(heron_load, script_name)

    @property
    def classpath(self) -> str:
        return self.jdbc_driver_jar

    def _upload_target(self) -> 'UploadTarget':
        return UploadTarget(self, self.schema, transform_name=self.task_id)

    def requires(self) -> Dict[str, luigi.Task]:
        _configure_logging(self.log_dest)

        summary = NAACCR_Summary(
            db_url=self.db_url,
            user=self.user,
            passkey=self.passkey,
            dateCaseReportExported=self.dateCaseReportExported,
            npiRegistryId=self.npiRegistryId,
        )
        ont1 = NAACCR_Ontology1(
            db_url=self.db_url,
            user=self.user,
            passkey=self.passkey,
        )

        return dict(NAACCR_Ontology1=ont1,
                    NAACCR_Summary=summary)

    def run_upload(self, conn: Connection, upload_id: int) -> None:
        self.run_script(
            conn, self.script_name, self.script,
            variables=dict(upload_id=str(upload_id),
                           task_id=self.task_id),
            script_params=dict(
                upload_id=upload_id,
                task_id=self.task_id,
                source_cd=self.source_cd,
                update_date=self.dateCaseReportExported))
class CohortRIFTable(luigi.WrapperTask):
    cms_rif_schemas = ListParam(default=['CMS_DEID_2014', 'CMS_DEID_2015'])
    work_schema = pv.StrParam(description="Destination RIF schema")
    table_name = pv.StrParam()
    site_star_list = ListParam(description='DATA_KUMC,DATA_MCW,...')

    def requires(self) -> List[luigi.Task]:
        return [
            CohortRIFTablePart(cms_rif=cms_rif,
                               work_schema=self.work_schema,
                               site_star_list=self.site_star_list,
                               table_name=self.table_name)
            for cms_rif in self.cms_rif_schemas
        ]
class HelloNAACCR(SparkJDBCTask):
    """Verify connection to NAACCR ETL target DB.
    """
    schema = pv.StrParam(default='NIGHTHERONDATA')
    save_path = pv.StrParam(default='/tmp/upload_status.csv')

    def output(self) -> luigi.Target:
        return luigi.LocalTarget(self.save_path)

    def main_action(self, sparkContext: SparkContext_T) -> None:
        spark = SparkSession(sparkContext)
        upload_status = self.account().rd(spark.read,
                                          table=f'{self.schema}.upload_status')
        upload_status.write.save(self.save_path, format='csv')
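
# A hedged command-line sketch (not from the original source; the module name
# and connection values are hypothetical). Note that --passkey names an
# environment variable holding the password, not the password itself:
#
#   export DB_PASSWORD='...'
#   luigi --module your_etl_module HelloNAACCR \
#       --db-url 'jdbc:oracle:thin:@host:1521:SID' \
#       --user ETL_USER --passkey DB_PASSWORD \
#       --save-path /tmp/upload_status.csv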
class MigrateShiftedTable(et.UploadTask):
    script = Script.date_shift_normalize  # ISSUE: not used.
    source_table = pv.StrParam(description="source of shifted facts")
    source_task = SourceTaskParam()  # type: et.SourceTask

    @property
    def source(self) -> et.SourceTask:
        return self.source_task

    def run(self) -> None:
        upload = self._upload_target()
        with upload.job(
                self,
                label='migrate {src} to {star}.observation_fact'.format(
                    star=self.project.star_schema, src=self.source_table),
                user_id=et.make_url(self.account).username) as conn_id_r:
            conn, upload_id, result = conn_id_r

            migrate = 'insert /*+ parallel append */ into {star}.observation_fact select * from {src}'.format(
                star=self.project.star_schema, src=self.source_table)
            with conn._conn.begin():
                conn.execute(migrate)
                q = 'select count(*) from {src}'.format(src=self.source_table)
                rowcount = conn.execute(q).scalar()
            result[upload.table.c.loaded_record.name] = rowcount
class SiteI2B2(et.SourceTask, et.DBAccessTask):
    star_schema = pv.StrParam(description='BLUEHERONDATA_KUMC or the like')
    table_eg = 'patient_dimension'

    @property
    def source_cd(self) -> str:
        return "'%s'" % self.star_schema

    @property
    def download_date(self) -> datetime:
        with self.connection('download_date') as conn:
            t = conn.execute(
                '''
                select last_ddl_time from all_objects
                where owner=:owner and object_name=:table_eg
                ''',
                dict(owner=self.star_schema.upper(),
                     table_eg=self.table_eg.upper())).scalar()
        if not isinstance(t, datetime):
            raise ValueError(t)
        return t

    def _dbtarget(self) -> et.DBTarget:
        return et.SchemaTarget(self._make_url(self.account),
                               schema_name=self.star_schema,
                               table_eg=self.table_eg,
                               echo=self.echo)
class CohortRIF(luigi.WrapperTask):
    """Build subset CMS RIF tables where bene_id is from site cohorts
    in i2b2 patient sets.
    """
    cms_rif_schemas = ListParam(default=['CMS_DEID_2014', 'CMS_DEID_2015'])
    site_star_list = ListParam(description='DATA_KUMC,DATA_MCW,...')
    work_schema = pv.StrParam()

    def requires(self) -> List[luigi.Task]:
        table_names = ([
            'mbsf_ab_summary',
            'mbsf_d_cmpnts',
            'medpar_all',
            'bcarrier_claims',
            'bcarrier_line',
            'outpatient_base_claims',
            'outpatient_revenue_center',
            'pde',
            'pde_saf',
        ] if list(self.cms_rif_schemas) == ['CMS_DEID'] else [
            'mbsf_abcd_summary',
            'medpar_all',
            'bcarrier_claims_k',
            'bcarrier_line_k',
            'outpatient_base_claims_k',
            'outpatient_revenue_center_k',
            'pde',
        ])

        return [
            CohortRIFTable(cms_rif_schemas=self.cms_rif_schemas,
                           work_schema=self.work_schema,
                           site_star_list=self.site_star_list,
                           table_name=table_name) for table_name in table_names
        ]
class DateShiftFixPart(et.SqlScriptTask):
    script = Script.date_shift_2015_part
    cms_rif_schema = pv.StrParam(default='CMS_DEID_2015')
    upload_id = pv.IntParam()

    @property
    def variables(self) -> Environment:
        return dict(upload_id=str(self.upload_id))
class CDM_Copy(luigi.WrapperTask):
    workspace = pv.StrParam()
    pcornet_cdm = pv.StrParam()
    tables = [
        'DEMOGRAPHIC',
        'ENCOUNTER',
        'ENROLLMENT',
        'DIAGNOSIS',
        'PROCEDURES',
        'DISPENSING',
        'HARVEST',
    ]

    def requires(self) -> List[luigi.Task]:
        return [
            CopyTable(table_name=t,
                      src_schema=self.workspace,
                      dest_schema=self.pcornet_cdm) for t in self.tables
        ]
class NAACCR_Summary(_NAACCR_JDBC):
    table_name = "NAACCR_EXPORT_STATS"

    z_design_id = pv.StrParam('fill NaN (%s)' %
                              _stable_hash(td.DataSummary.script.code))

    def _data(self, spark: SparkSession,
              naaccr_text_lines: DataFrame) -> DataFrame:
        dd = tr_ont.ddictDF(spark)
        extract = td.naaccr_read_fwf(naaccr_text_lines, dd)
        return td.DataSummary.stats(extract, spark).na.fill(0, subset=['sd'])
class UploadTask(JDBCTask):
    '''A task with an associated `upload_status` record.
    '''
    jdbc_driver_jar = pv.StrParam(significant=False)
    schema = pv.StrParam(description='owner of upload_status table')

    @property
    def classpath(self) -> str:
        return self.jdbc_driver_jar

    def complete(self) -> bool:
        with task_action(self, 'complete') as ctx:
            result = self.output().exists()
            ctx.add_success_fields(result=result)
            return result

    def output(self) -> luigi.Target:
        return self._upload_target()

    @abstractmethod
    def _upload_target(self) -> 'UploadTarget': pass
class SparkJDBCTask(PySparkTask, JDBCTask):
    """Support for JDBC access from spark tasks
    """
    driver_memory = pv.StrParam(default='4g', significant=False)
    executor_memory = pv.StrParam(default='4g', significant=False)
    max_result_size = pv.StrParam(default='4g', significant=False)

    def setup(self, conf):
        conf.set("spark.driver.maxResultSize", self.max_result_size)

    @abstractmethod
    def output(self) -> luigi.Target:
        pass

    def complete(self) -> bool:
        with task_action(self, 'complete') as ctx:
            result = self.output().exists()
            ctx.add_success_fields(result=result)
            return result

    def main(self, sparkContext: SparkContext_T, *_args: Any) -> None:
        with task_action(self, 'main'):
            self.main_action(sparkContext)

    @abstractmethod
    def main_action(self, sparkContext: SparkContext_T) -> None:
        pass

    @property
    def classpath(self) -> str:
        return ':'.join(self.jars)

    @property
    def __password(self) -> str:
        from os import environ  # ISSUE: ambient
        return environ[self.passkey]

    def account(self, db_url: Opt[str] = None) -> td.Account:
        return td.Account(self.user, self.__password, db_url or self.db_url,
                          self.driver)
class NAACCR_Visits(_NAACCR_JDBC):
    """Make a per-tumor table for use in encounter_mapping etc.
    """
    design_id = pv.StrParam('patient_num')
    table_name = "NAACCR_TUMORS"
    encounter_num_start = pv.IntParam(description='see client.cfg')

    def _data(self, spark: SparkSession,
              naaccr_text_lines: DataFrame) -> DataFrame:
        tumors = td.TumorKeys.with_tumor_id(
            td.TumorKeys.pat_tmr(spark, naaccr_text_lines))
        tumors = td.TumorKeys.with_rownum(tumors,
                                          start=self.encounter_num_start)
        return tumors
class CopyTable(et.DBAccessTask):
    table_name = pv.StrParam()
    src_schema = pv.StrParam()
    dest_schema = pv.StrParam()

    def complete(self) -> bool:
        with self.connection('check ' + self.table_name) as lc:
            try:
                dest_qty = lc.execute('select count(*) from {dest}.{t}'.format(
                    t=self.table_name, dest=self.dest_schema)).scalar()
            except Exception:
                return False
            src_qty = lc.execute('select count(*) from {src}.{t}'.format(
                t=self.table_name, src=self.src_schema)).scalar()
        return src_qty == dest_qty

    def run(self) -> None:
        with self.connection('copy ' + self.table_name) as lc:
            lc.execute(
                'create table {dest}.{t} as select * from {src}.{t}'.format(
                    t=self.table_name,
                    src=self.src_schema,
                    dest=self.dest_schema))
class HERON_Patient_Mapping(UploadTask):
    patient_ide_source = pv.StrParam()
    # ISSUE: task id should depend on HERON release, too

    transform_name = 'load_epic_dimensions'

    def _upload_target(self) -> 'UploadTarget':
        return UploadTarget(self, self.schema, self.transform_name)

    @property
    def source_cd(self) -> str:
        return self.patient_ide_source

    def run(self) -> None:
        with task_action(self, 'run'):
            raise NotImplementedError('load_epic_dimensions is a paver task')
class NAACCR_Facts(_NAACCR_JDBC):
    table_name = "NAACCR_OBSERVATIONS"

    z_design_id = pv.StrParam(
        'nested fields (%s)' %
        _stable_hash(td.ItemObs.script.code, td.SEER_Recode.script.code,
                     td.SiteSpecificFactors.script1.code,
                     td.SiteSpecificFactors.script2.code))

    def _data(self, spark: SparkSession,
              naaccr_text_lines: DataFrame) -> DataFrame:
        dd = tr_ont.ddictDF(spark)
        extract = td.naaccr_read_fwf(naaccr_text_lines, dd)
        item = td.ItemObs.make(spark, extract)
        seer = td.SEER_Recode.make(spark, extract)
        ssf = td.SiteSpecificFactors.make(spark, extract)
        # ISSUE: make these separate tables?
        return item.union(seer).union(ssf)
class CDM_CMS_S7(luigi.Task):
    pcornet_cdm = 'CMS_CDM_11_15_7S'
    workspace = pv.StrParam()

    @property
    def copy_task(self) -> luigi.Task:
        return CDM_Copy(pcornet_cdm=self.pcornet_cdm, workspace=self.workspace)

    def requires(self) -> List[luigi.Task]:
        return [ShiftedDimensions()]

    def complete(self) -> bool:
        return (ShiftedDimensions().complete() and cms_i2p.I2P().complete()
                and self.copy_task.complete())

    def run(self) -> Iterable[luigi.Task]:
        # ISSUE: order matters
        yield cms_i2p.I2P()
        yield self.copy_task
class _NAACCR_JDBC(SparkJDBCTask):
    """Load data from a NAACCR flat file into a table via JDBC.

    Use a `task_id` column to manage freshness.
    """
    table_name: str
    dateCaseReportExported = pv.DateParam()
    npiRegistryId = pv.StrParam()

    def requires(self) -> Dict[str, luigi.Task]:
        return {
            'NAACCR_FlatFile':
            NAACCR_FlatFile(dateCaseReportExported=self.dateCaseReportExported,
                            npiRegistryId=self.npiRegistryId)
        }

    def _flat_file_task(self) -> NAACCR_FlatFile:
        return cast(NAACCR_FlatFile, self.requires()['NAACCR_FlatFile'])

    def output(self) -> luigi.Target:
        query = f"""
          (select 1 from {self.table_name}
           where task_id = '{self.task_id}')
        """
        return JDBCTableTarget(self, query)

    def main_action(self, sparkContext: SparkContext_T) -> None:
        quiet_logs(sparkContext)
        spark = SparkSession(sparkContext)
        ff = self._flat_file_task()
        naaccr_text_lines = spark.read.text(str(ff.flat_file))

        data = self._data(spark, naaccr_text_lines)
        # ISSUE: task_id is kinda long; how about just task_hash?
        # luigi_task_hash?
        data = data.withColumn('task_id', func.lit(self.task_id))
        data = td.case_fold(data)
        self.account().wr(data.write, self.table_name, mode='overwrite')

    def _data(self, spark: SparkSession,
              naaccr_text_lines: DataFrame) -> DataFrame:
        raise NotImplementedError('subclass must implement')
class NAACCR_Load(_RunScriptTask):
    '''Map and load NAACCR patients, tumors / visits, and facts.
    '''
    # flat file attributes
    dateCaseReportExported = pv.DateParam()
    npiRegistryId = pv.StrParam()

    # encounter mapping
    encounter_ide_source = pv.StrParam(default='*****@*****.**')
    project_id = pv.StrParam(default='BlueHeron')
    source_cd = pv.StrParam(default='*****@*****.**')

    # ISSUE: task_id should depend on dest schema / owner.
    z_design_id = pv.StrParam(default='nested fields')

    script_name = 'naaccr_facts_load.sql'
    script_deid_name = 'i2b2_facts_deid.sql'
    script = res.read_text(heron_load, script_name)
    script_deid = res.read_text(heron_load, script_deid_name)

    @property
    def classpath(self) -> str:
        return self.jdbc_driver_jar

    def _upload_target(self) -> 'UploadTarget':
        return UploadTarget(self, self.schema, transform_name=self.task_id)

    def requires(self) -> Dict[str, luigi.Task]:
        _configure_logging(self.log_dest)

        ff = NAACCR_FlatFile(
            dateCaseReportExported=self.dateCaseReportExported,
            npiRegistryId=self.npiRegistryId)

        parts = {
            cls.__name__:
            cls(db_url=self.db_url,
                user=self.user,
                passkey=self.passkey,
                dateCaseReportExported=self.dateCaseReportExported,
                npiRegistryId=self.npiRegistryId)
            for cls in [NAACCR_Patients, NAACCR_Visits, NAACCR_Facts]
        }
        return dict(parts, NAACCR_FlatFile=ff)

    def _flat_file_task(self) -> NAACCR_FlatFile:
        return cast(NAACCR_FlatFile, self.requires()['NAACCR_FlatFile'])

    def _patients_task(self) -> NAACCR_Patients:
        return cast(NAACCR_Patients, self.requires()['NAACCR_Patients'])

    def run_upload(self, conn: Connection, upload_id: int) -> None:
        ff = self._flat_file_task()
        pat = self._patients_task()

        # ISSUE: split these into separate tasks?
        for name, script in [(self.script_name, self.script),
                             (self.script_deid_name, self.script_deid)]:
            self.run_script(
                conn,
                name,
                script,
                variables=dict(upload_id=str(upload_id), task_id=self.task_id),
                script_params=dict(
                    upload_id=upload_id,
                    project_id=self.project_id,
                    task_id=self.task_id,
                    source_cd=self.source_cd,
                    download_date=ff.dateCaseReportExported,
                    patient_ide_source=pat.patient_ide_source,
                    encounter_ide_source=self.encounter_ide_source))
class CohortRIFTablePart(et.DBAccessTask, et.I2B2Task):
    """Build subset of one CMS RIF table where bene_id is from site cohorts.
    """
    cms_rif = pv.StrParam(description="Source RIF schema")
    work_schema = pv.StrParam(description="Destination RIF schema")
    table_name = pv.StrParam()
    site_star_list = ListParam(description='DATA_KUMC,DATA_MCW,...')
    parallel_degree = pv.IntParam(significant=False, default=16)

    @property
    def site_cohorts(self) -> SiteCohorts:
        return SiteCohorts(site_star_list=self.site_star_list)

    def requires(self) -> List[luigi.Task]:
        return [self.site_cohorts]

    def complete(self) -> bool:
        for t in self.requires():
            if not t.complete():
                return False

        with self.connection('rif_table_done') as conn:
            for cohort_id in self.site_cohorts._cohort_ids(
                    self.project.star_schema, conn):
                # We're not guaranteed that each site cohort intersects the CMS data,
                # but if it does, the CMS patient numbers are the low ones; hence min().
                lo, hi = conn.execute('''
                    select min(patient_num), max(patient_num)  from {i2b2}.qt_patient_set_collection
                    where result_instance_id = {id}
                    '''.format(i2b2=self.project.star_schema,
                               id=cohort_id)).first()
                try:
                    found = conn.execute('''
                        select 1 from {work}.{t} where bene_id between :lo and :hi and rownum = 1
                    '''.format(work=self.work_schema, t=self.table_name),
                                         params=dict(lo=lo, hi=hi)).scalar()
                except DatabaseError as oops:
                    conn.log.warning('complete query failed:', exc_info=oops)
                    return False
                if not found:
                    return False
        return True

    def run(self) -> None:
        with self.connection('rif_table') as conn:
            self._create(conn)

            cohort_ids = self.site_cohorts._cohort_ids(
                self.project.star_schema, conn)
            conn.execute('''
            insert /*+ append */ into {work}.{t}
            select /*+ parallel({degree}) */ * from {rif}.{t}
            where bene_id in (
              select patient_num from {i2b2}.qt_patient_set_collection
              where result_instance_id in ({cohorts})
            )
            '''.format(work=self.work_schema,
                       t=self.table_name,
                       degree=self.parallel_degree,
                       rif=self.cms_rif,
                       i2b2=self.project.star_schema,
                       cohorts=', '.join(str(i) for i in cohort_ids)))

    def _create(self, conn: et.LoggedConnection) -> None:
        try:
            conn.execute('''
            create table {work}.{t} as select * from {rif}.{t} where 1 = 0
            '''.format(work=self.work_schema,
                       t=self.table_name,
                       rif=self.cms_rif))
        except DatabaseError:
            pass  # perhaps it already exists...
class CMS_CDM_Report(et.DBAccessTask, et.I2B2Task):
    '''Make a report (spreadsheet) detailing how a CMS_CDM was produced.
    '''

    path = pv.StrParam(default='cms_cdm_report.xlsx')

    def output(self) -> luigi.LocalTarget:
        return luigi.LocalTarget(self.path)

    def run(self) -> None:
        # ISSUE: This isn't atomic like LocalTarget is supposed to be.
        writer = pd.ExcelWriter(self.path)
        query = rif_etl.read_sql_step
        with self.connection('reporting') as lc:
            self.inclusion_criteria(lc).to_excel(writer, 'Inclusion Criteria')
            self.cohorts(lc).to_excel(writer, 'Site Cohorts')
            query('select * from cohort_rif_summary',
                  lc).to_excel(writer, 'CMS RIF BC')
            query('select * from i2b2_bc_summary',
                  lc).to_excel(writer, 'I2B2 BC')
            query('select * from demographic_summary',
                  lc).to_excel(writer, 'DEMOGRAPHIC')
            query('select * from encounters_per_visit_patient',
                  lc).to_excel(writer, 'ENCOUNTER IID')
            query('select * from id_counts_by_table',
                  lc).to_excel(writer, 'ENROLLMENT ID')
            query('select * from dx_by_enc_type',
                  lc).to_excel(writer, 'DIAGNOSIS IVA')
            query('select * from px_per_enc_by_type',
                  lc).to_excel(writer, 'PROCEDURES IVB')
            query('select * from dispensing_trend_chart',
                  lc).to_excel(writer, 'DISPENSING IF')
            query('select * from id_counts_ranges_death',
                  lc).to_excel(writer, 'DEATH ID')
            query('select * from iic_illogical_dates',
                  lc).to_excel(writer, 'DEATH IIC')
            self.uploads(lc).to_excel(writer, 'I2B2 Tasks')
            query('select * from harvest',
                  lc).transpose().to_excel(writer, 'Harvest')
        writer.save()

    def inclusion_criteria(self, lc: et.LoggedConnection) -> pd.DataFrame:
        # ISSUE: sync with build_cohort.sql?
        # ISSUE: left out 'NAACCR|400:C509'
        return rif_etl.read_sql_step(
            '''
        select concept_cd, min(name_char) name_char
        from blueherondata_kumc_calamus.concept_dimension
        where concept_cd in (
          'SEER_SITE:26000',
          'NAACCR|400:C500',
          'NAACCR|400:C501',
          'NAACCR|400:C502',
          'NAACCR|400:C503',
          'NAACCR|400:C504',
          'NAACCR|400:C505',
          'NAACCR|400:C506',
          'NAACCR|400:C507',
          'NAACCR|400:C508',
          'NAACCR|400:C509'
        )
        group by concept_cd
        order by concept_cd
        ''', lc).set_index('concept_cd')

    def cohorts(self, lc: et.LoggedConnection) -> pd.DataFrame:
        cohorts = rif_etl.read_sql_step(
            '''
                select site_schema, result_instance_id, start_date, task_id, count(distinct patient_num)
                from site_cohorts
                group by site_schema, result_instance_id, start_date, task_id
                order by start_date desc
            ''', lc)
        cohorts = cohorts.append(
            rif_etl.read_sql_step(
                '''
                select count(distinct site_schema) site_schema
                     , max(result_instance_id) result_instance_id
                     , max(start_date) start_date
                     , 'Total' task_id
                     , count(distinct patient_num)
                 from site_cohorts''', lc)).set_index('task_id')
        return cohorts

    def uploads(self, lc: et.LoggedConnection) -> pd.DataFrame:
        return rif_etl.read_sql_step(
            '''
                 select *
                from upload_status up
                where load_status like 'OK%' and ((
                      loaded_record > 0
                  and substr(transform_name, -11) in (
                    select distinct task_id from site_cohorts
                  )
                ) or (
                  upload_label like 'migrate obs%'
                ) or (
                  message like 'UP#%' and upload_label like '% #1 of 1%'
                ))
                order by load_date desc
                ''', lc).set_index('upload_id')
class NAACCR_FlatFile(ManualTask):
    """A NAACCR flat file is determined by the registry, export date,
    and version.
    """
    naaccrRecordVersion = pv.IntParam(default=180)
    dateCaseReportExported = pv.DateParam()
    npiRegistryId = pv.StrParam()
    testData = pv.BoolParam(default=False, significant=False)
    flat_file = pv.PathParam(significant=False)
    record_qty_min = pv.IntParam(significant=False, default=1)

    def check_version_param(self) -> None:
        """Only version 18 (180) is currently supported.
        """
        if self.naaccrRecordVersion != 180:
            raise NotImplementedError()

    def complete(self) -> bool:
        with task_action(self, 'complete') as ctx:
            result = self.complete_action()
            ctx.add_success_fields(result=result)
            return result

    def complete_action(self) -> bool:
        """Check the first record, assuming all the others have
        the same export date and registry NPI.
        """
        self.check_version_param()

        with self.flat_file.open() as records:
            record0 = records.readline()
            qty = 1 + sum(1 for _ in records)
        log.info('record qty: %d (> %d? %s)', qty, self.record_qty_min,
                 qty >= self.record_qty_min)

        vOk = self._checkItem(record0, 'naaccrRecordVersion',
                              str(self.naaccrRecordVersion))
        regOk = self._checkItem(record0, 'npiRegistryId', self.npiRegistryId)
        dtOk = self._checkItem(record0, 'dateCaseReportExported',
                               self.dateCaseReportExported.strftime('%Y%m%d'))

        if vOk and regOk and dtOk and qty >= self.record_qty_min:
            return True
        else:
            if self.testData:
                log.warning('ignoring failed FlatFile check')
                return True
            return False

    @classmethod
    def _checkItem(cls, record: str, naaccrId: str, expected: str) -> bool:
        '''
        >>> npi = '1234567890'
        >>> record0 = ' ' * 19 + npi
        >>> NAACCR_FlatFile._checkItem(record0, 'npiRegistryId', npi)
        True
        >>> NAACCR_FlatFile._checkItem(record0, 'npiRegistryId', 'XXX')
        False
        '''
        itemDef = tr_ont.NAACCR1.itemDef(naaccrId)
        [startColumn, length
         ] = [int(itemDef.attrib[it]) for it in ['startColumn', 'length']]
        startColumn -= 1
        actual = record[startColumn:startColumn + length]
        if actual != expected:
            log.warning('%s: expected %s [%s:%s] = {%s} but found {%s}',
                        cls.__name__, naaccrId, startColumn,
                        startColumn + length, expected, actual)
        return actual == expected
class NAACCR_Ontology1(JDBCTask):
    table_name = pv.StrParam(default="NAACCR_ONTOLOGY")
    who_cache = pv.PathParam()
    z_design_id = pv.StrParam(
        default='2019-12-16 pystdlib %s' %
        _stable_hash(tr_ont.NAACCR_I2B2.ont_script.code),
        description='''
        mnemonic for latest visible change to output.
        Changing this causes task_id to change, which
        ensures the ontology gets rebuilt if necessary.
        '''.strip(),
    )
    naaccr_version = pv.IntParam(default=18)  # ISSUE: ignored?
    jdbc_driver_jar = pv.StrParam(significant=False)

    # based on custom_meta
    col_to_type = dict(
        c_hlevel='int',
        c_fullname='varchar(700)',
        c_name='varchar(2000)',
        c_visualattributes='varchar(3)',
        c_totalnum='int',
        c_basecode='varchar(50)',
        c_dimcode='varchar(700)',
        c_tooltip='varchar(900)',
        update_date='date',
        sourcesystem_cd='varchar(50)',
    )
    coltypes = ','.join(f'{name} {ty}' for (name, ty) in col_to_type.items())
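    # e.g. "c_hlevel int,c_fullname varchar(700),c_name varchar(2000),..."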

    @property
    def version_name(self) -> str:
        """version info that fits in an i2b2 name (50 characters)
        """
        task_hash = self.task_id.split('_')[
            -1]  # hmm... luigi doesn't export this
        return f'v{self.naaccr_version}-{task_hash}'

    @property
    def task_hash(self) -> str:
        return self.task_id.split('_')[-1]  # hmm... luigi doesn't export this

    def output(self) -> JDBCTableTarget:
        query = fr"""
          (select 1 from {self.table_name}
           where c_fullname = '{tr_ont.NAACCR_I2B2.top_folder}'
           and c_basecode = '{self.version_name}')
        """
        return JDBCTableTarget(self, query)

    @property
    def classpath(self) -> str:
        return self.jdbc_driver_jar

    @property
    def __password(self) -> str:
        from os import environ  # ISSUE: ambient
        return environ[self.passkey]

    def account(self) -> td.Account:
        from subprocess import Popen  # ISSUE: AMBIENT
        return td.Account('DEID',
                          self.user,
                          self.__password,
                          Popen,
                          url=self.db_url,
                          driver=self.driver)

    def run(self) -> None:
        conn = connect_mem(':memory:', detect_types=PARSE_COLNAMES)
        spark = DBSession(conn)
        update_date = dt.datetime.strptime(self.z_design_id[:10],
                                           '%Y-%m-%d').date()
        terms = tr_ont.NAACCR_I2B2.ont_view_in(spark,
                                               who_cache=self.who_cache,
                                               task_hash=self.task_hash,
                                               update_date=update_date)
        cdw = self.account()
        cdw.wr(self.table_name, td.case_fold(terms))
class BuildCohort(et.UploadTask):
    script = Script.build_cohort
    site_star_schema = pv.StrParam(description='DATA_KUMC or DAT_MCW etc.')
    inclusion_concept_cd = pv.StrParam(default='SEER_SITE:26000')
    dx_date_min = DateParam(default=datetime(2011, 1, 1, 0, 0, 0))
    _keys = None
    _key_seqs = ['QT_SQ_QRI_QRIID', 'QT_SQ_QM_QMID',
                 'QT_SQ_QRI_QRIID']  # ISSUE: same sequence?

    @property
    def source(self) -> 'SiteI2B2':
        return SiteI2B2(star_schema=self.site_star_schema)

    @property
    def variables(self) -> Environment:
        return dict(
            I2B2_STAR=self.project.star_schema,
            I2B2_STAR_SITE=self.source.star_schema,
        )

    def script_params(self, conn: et.LoggedConnection) -> Params:
        upload_params = et.UploadTask.script_params(self, conn)

        [result_instance_id, query_master_id,
         query_instance_id] = self._allocate_keys(conn)
        return dict(upload_params,
                    task_id=self.task_id,
                    inclusion_concept_cd=self.inclusion_concept_cd,
                    dx_date_min=self.dx_date_min,
                    result_instance_id=result_instance_id,
                    query_instance_id=query_instance_id,
                    query_master_id=query_master_id,
                    query_name='%s: %s' %
                    (self.site_star_schema, query_master_id),
                    user_id=self.task_family)

    def _allocate_keys(self, conn: et.LoggedConnection) -> List[int]:
        if self._keys is not None:
            return self._keys
        self._keys = out = []  # type: List[int]
        for seq_name in self._key_seqs:
            x = conn.execute(
                "select {schema}.{seq_name}.nextval from dual".format(
                    schema=self.project.star_schema,
                    seq_name=seq_name)).scalar()
            if not isinstance(x, int):
                raise ValueError(x)
            out.append(x)
        return out

    def loaded_record(self, conn: et.LoggedConnection, _bulk_rows: int) -> int:
        assert self._keys
        [result_instance_id, _, _] = self._keys
        size = conn.execute(
            '''
            select set_size from {i2b2_star}.qt_query_result_instance
            where result_instance_id = :result_instance_id
            '''.format(i2b2_star=self.project.star_schema),
            dict(result_instance_id=result_instance_id)).scalar()
        if not isinstance(size, int):
            raise ValueError(size)
        return size

    def result_instance_id(self, conn: et.LoggedConnection) -> int:
        '''Fetch patient set id after task has run.
        '''
        rid = conn.execute('''
            select max(result_instance_id)
            from {i2b2_star}.qt_query_result_instance
            where set_size > 0 and description = :task_id
            '''.format(i2b2_star=self.project.star_schema),
                           params=dict(task_id=self.task_id)).scalar()
        if not isinstance(rid, int):
            raise ValueError(rid)
        return rid
class JDBCTask(luigi.Task):
    db_url = pv.StrParam(description='see client.cfg', significant=False)
    driver = pv.StrParam(default="oracle.jdbc.OracleDriver", significant=False)
    user = pv.StrParam(description='see client.cfg', significant=False)
    passkey = pv.StrParam(description='see client.cfg', significant=False)

    @property
    @abstractmethod
    def classpath(self) -> str:
        pass

    @property
    def __password(self) -> str:
        from os import environ  # ISSUE: ambient
        return environ[self.passkey]

    @contextmanager
    def connection(self, action_type: str) -> Iterator[Connection]:
        with el.start_action(action_type=action_type,
                             url=self.db_url,
                             driver=self.driver,
                             user=self.user):
            with TheJVM.borrow(self.classpath) as jvm:
                jvm.java.lang.Class.forName(self.driver)
                conn = jvm.java.sql.DriverManager.getConnection(
                    self.db_url, self.user, self.__password)
                try:
                    # avoid: JdbcUtils: Requested isolation level 1 is not supported;
                    #        falling back to default isolation level 2
                    conn.setTransactionIsolation(2)
                    yield conn
                finally:
                    conn.close()
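
    # A hedged usage sketch (not from the original source): `connection` yields
    # a raw java.sql.Connection borrowed from TheJVM, so plain JDBC calls work:
    #
    #   with self.connection('count facts') as conn:
    #       stmt = conn.createStatement()
    #       rs = stmt.executeQuery('select count(*) from observation_fact')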

    # ISSUE: dates are not JSON serializable, so log_call doesn't grok.
    @el.log_call(include_args=['fname', 'variables'])
    def run_script(self,
                   conn: Connection,
                   fname: str,
                   sql_code: str,
                   variables: Opt[Environment] = None,
                   script_params: Opt[Params] = None) -> None:
        '''Run script inside a LoggedConnection event.

        @param variables: variables to define for this run
        @param script_params: parameters to bind for this run

        To see how a script can ignore errors, see :mod:`script_lib`.
        '''
        ignore_error = False
        run_params = dict(script_params or {}, task_id=self.task_id)

        for line, _comment, statement in SqlScript.each_statement(
                sql_code, variables):
            try:
                ignore_error = self.execute_statement(conn, fname, line,
                                                      statement, run_params,
                                                      ignore_error)
            except Exception as exc:
                # ISSUE: connection label should show sid etc.
                err = SqlScriptError(exc, fname, line, statement, self.task_id)
                if ignore_error:
                    log.warning('%(event)s: %(error)s',
                                dict(event='ignore', error=err))
                else:
                    raise err from None
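
    # A hedged illustration (not from the original source) of the distinction
    # documented above, mirroring the call in NAACCR_Ontology2.run_upload:
    # `variables` are interpolated into the SQL text by SqlScript.each_statement,
    # while `script_params` are bound as prepared-statement parameters
    # (see to_qmark / prepared below).
    #
    #   self.run_script(conn, self.script_name, self.script,
    #                   variables=dict(upload_id=str(upload_id),
    #                                  task_id=self.task_id),
    #                   script_params=dict(upload_id=upload_id,
    #                                      task_id=self.task_id,
    #                                      source_cd=self.source_cd))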

    @classmethod
    @el.log_call(include_args=['fname', 'line', 'ignore_error'])
    def execute_statement(cls, conn: Connection, fname: str, line: int,
                          statement: SQL, params: Params,
                          ignore_error: bool) -> bool:
        '''Log and execute one statement.
        '''
        sqlerror = SqlScript.sqlerror(statement)
        if sqlerror is not None:
            return sqlerror
        with cls.prepared(conn, statement, params) as stmt:
            stmt.execute()
        return ignore_error

    @classmethod
    @contextmanager
    def prepared(cls, conn: Connection, sql: str,
                 params: Params) -> Iterator[PreparedStatement]:
        sqlq, values = to_qmark(sql, params)
        stmt = conn.prepareStatement(sqlq)
        for ix0, value in enumerate(values):
            ix1 = ix0 + 1
            if value is None:
                VARCHAR = 12  # java.sql.Types.VARCHAR
                stmt.setNull(ix1, VARCHAR)
            elif isinstance(value, int):
                stmt.setInt(ix1, value)
            elif isinstance(value, str):
                stmt.setString(ix1, value)
            elif isinstance(value, dt.date):
                dv = TheJVM.sql_date(value)
                stmt.setDate(ix1, dv)
            elif isinstance(value, dt.datetime):
                tv = TheJVM.sql_timestamp(value)
                stmt.setTimestamp(ix1, tv)
        try:
            with el.start_action(action_type='prepared statement',
                                 sql=sql.strip(),
                                 sqlq=sqlq.strip(),
                                 params=', '.join(params.keys()),
                                 values=_json_ok(values)):
                yield stmt
        finally:
            stmt.close()