Code example #1
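# Standard-library and third-party imports assumed by this excerpt; project-local
# names (Cloudwatch, Predictor, STATUS_DICT, predictor_str) are defined elsewhere
# in the cds-stack repo and are not reproduced here.
import asyncio
import datetime as dt
import logging
import os

import humanize
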
class PredictorManager:
    def __init__(self, alert_message_queue, event_loop):
        self.predictors = {}
        self.alert_message_queue = alert_message_queue
        self.predict_task_futures = {}
        self.loop = event_loop
        self.cloudwatch_logger = Cloudwatch()
        self.model = os.getenv('suppression_model', 'trews')
        # Start monitoring task
        self.loop.create_task(self.monitor_predictors())

    async def monitor_predictors(self):
        ''' Periodically monitor predictors and push their current status to CloudWatch (the loop below sleeps one second between pushes) '''

        while True:
            # Build list of tuples of cloudwatch info (name, value, unit)
            metric_tuples = []

            # Get overall predictor info
            metric_tuples.append(
                ('push_num_predictors', int(len(self.predictors)), 'Count'))

            # Send individual predictor info to cloudwatch
            for pred_id, pred in self.predictors.items():
                metric_tuples += [
                    ('push_predictor_{}_{}_{}_status'.format(*pred_id),
                     STATUS_DICT[pred.status], 'None'),
                ]
            logging.info("cloudwatch metrics: {}".format(metric_tuples))
            # Send all info to cloudwatch
            self.cloudwatch_logger.push_many(
                dimension_name='LMCPredictors',
                metric_names=[metric[0] for metric in metric_tuples],
                metric_values=[metric[1] for metric in metric_tuples],
                metric_units=[metric[2] for metric in metric_tuples])
            await asyncio.sleep(1)

    async def register(self, reader, writer, msg):
        ''' Register connection from a predictor '''

        # Create predictor object
        pred = Predictor(reader, writer, msg['status'], msg['node_index'],
                         msg['partition_index'], msg['model_type'],
                         msg['is_active'], msg['ip_address'], self)

        # Cancel any existing predictor with same id
        if pred.id in self.predictors:
            self.predictors[pred.id].shutdown = True

        # Save predictor in data structure
        self.predictors[pred.id] = pred
        logging.info("Registered {}".format(pred))

        # Start listener loop
        return await pred.listen(self.alert_message_queue)

    def get_partition_ids(self):
        return {p.partition_index for p in self.predictors.values()}

    def get_model_types(self):
        return {p.model_type for p in self.predictors.values()}

    def cancel_predict_tasks(self, job_id):
        ''' Cancel the existing tasks for previous ETL '''
        logging.info("cancel the existing tasks for previous ETL")
        for future in self.predict_task_futures.get(job_id, []):
            future.cancel()
            logging.info("{} cancelled".format(future))

    def create_predict_tasks(self, hosp, time, job_id, active_encids=None):
        ''' Start all predictors '''
        logging.info("Starting all predictors for ETL {} {} {} {}".format(
            hosp, time, job_id, active_encids))
        self.predict_task_futures[job_id] = []
        for pid in self.get_partition_ids():
            for model in self.get_model_types():
                future = asyncio.ensure_future(self.run_predict(
                    pid, model, hosp, time, job_id, active_encids),
                                               loop=self.loop)
                self.predict_task_futures[job_id].append(future)
        logging.info("Started {} predictors".format(
            len(self.predict_task_futures[job_id])))

    async def run_predict(self,
                          partition_id,
                          model_type,
                          hosp,
                          time,
                          job_id,
                          active_encids,
                          active=True):
        ''' Start a predictor for a given partition id and model '''
        backoff = 1

        # Start the predictor
        logging.info(
            "start the predictors for {} at {}: partition_id {} model_type {}".
            format(hosp, time, partition_id, model_type))
        while True:
            pred = self.predictors.get((partition_id, model_type, active))
            if pred and pred.status != 'DEAD':
                try:
                    predictor_started = await pred.start_predictor(
                        hosp, time, job_id, active_encids)
                    break
                except ConnectionRefusedError as e:
                    err = e
            else:
                err = '{} dead'.format(
                    predictor_str(partition_id, model_type, active))

            if self.model == 'trews-jit':
                return
            else:
                # Log reason for error
                logging.error("{} -- trying {} predictor {}".format(
                    err, 'backup' if active else 'active',
                    humanize.naturaltime(dt.datetime.now() +
                                         dt.timedelta(seconds=backoff))))

                # Switch to backup if active (and vice versa)
                active = not active
                await asyncio.sleep(backoff)
                backoff = backoff * 2 if backoff < 64 else backoff

        # Check status
        if not predictor_started:
            return None
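
A minimal usage sketch (hedged: the job id, hospital, and timestamp below are illustrative, and the project-local Cloudwatch and Predictor classes must be importable as in the excerpt above). The manager starts its monitoring loop on construction; an ETL completion would then cancel any stale prediction tasks and create new ones:

import asyncio

loop = asyncio.get_event_loop()
alert_queue = asyncio.Queue(loop=loop)
manager = PredictorManager(alert_queue, loop)  # schedules monitor_predictors()

# Hypothetical ETL completion: restart prediction tasks for the new job.
manager.cancel_predict_tasks(job_id='etl_hcgh_example')
manager.create_predict_tasks(hosp='HCGH', time='2018-01-01T00:00:00',
                             job_id='etl_hcgh_example')
loop.run_forever()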
Code example #2
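# Standard-library and third-party imports assumed by this excerpt; project-local
# names (Cloudwatch, Database, PredictorManager, protocol) come from elsewhere in
# the cds-stack repo.
import asyncio
import datetime as dt
import logging
import os

from dateutil import parser
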
class AlertServer:
    def __init__(self,
                 event_loop,
                 alert_server_port=31000,
                 alert_dns='0.0.0.0'):
        self.db = Database()
        self.loop = event_loop
        self.alert_message_queue = asyncio.Queue(loop=event_loop)
        self.predictor_manager = PredictorManager(self.alert_message_queue,
                                                  self.loop)
        self.alert_server_port = alert_server_port
        self.alert_dns = alert_dns
        self.channel = os.getenv('etl_channel', 'on_opsdx_dev_etl')
        self.suppression_tasks = {}
        self.model = os.getenv('suppression_model', 'trews')
        self.TREWS_ETL_SUPPRESSION = int(os.getenv('TREWS_ETL_SUPPRESSION', 0))
        self.notify_web = int(os.getenv('notify_web', 0))
        self.lookbackhours = int(os.getenv('TREWS_ETL_HOURS', 24))
        self.nprocs = int(os.getenv('nprocs', 2))
        self.hospital_to_predict = os.getenv('hospital_to_predict', 'HCGH')
        self.push_based = bool(int(os.getenv('push_based', 0)))  # env values are strings; bool('0') would be True
        self.workspace = os.getenv('workspace', 'workspace')
        self.cloudwatch_logger = Cloudwatch()
        self.job_status = {}

    async def async_init(self):
        self.db_pool = await self.db.get_connection_pool()

    async def convert_enc_ids_to_pat_ids(self, enc_ids):
        ''' Return a list of pat_ids from their corresponding enc_ids '''
        async with self.db_pool.acquire() as conn:
            sql = '''
      SELECT distinct pat_id FROM pat_enc where enc_id
      in ({})
      '''.format(','.join([str(i) for i in enc_ids]))
            pat_ids = await conn.fetch(sql)
            return pat_ids

    async def suppression(self, pat_id, tsp):
        ''' Alert suppression task for a single patient
        and notify frontend that the patient has updated'''
        async def criteria_ready(conn, pat_id, tsp):
            '''
      Criteria are ready when either:
      1. criteria has been updated after tsp, or
      2. there is no new data in criteria_meas within lookbackhours (so ETL will not update criteria)
      '''
            sql = '''
      SELECT count(*) > 0
        or (select count(*) = 0 from criteria_meas m
            where m.pat_id = '{pat_id}' and now() - tsp < (select value::interval from parameters where name = 'lookbackhours')) ready
       FROM criteria where pat_id = '{pat_id}'
      and update_date > '{tsp}'::timestamptz
      '''.format(pat_id=pat_id, tsp=tsp)
            cnt = await conn.fetch(sql)

            return cnt[0]['ready']

        async with self.db_pool.acquire() as conn:
            n = 0
            N = 60

            logging.info("enter suppression task for {} - {}".format(
                pat_id, tsp))
            while not await criteria_ready(conn, pat_id, tsp):
                await asyncio.sleep(10)
                n += 1
                logging.info("retry criteria_ready {} times for {}".format(
                    n, pat_id))
                if n >= N:
                    break
            if n < N:
                logging.info("criteria is ready for {}".format(pat_id))
                sql = '''
        select update_suppression_alert('{pat_id}', '{channel}', '{model}', '{notify}');
        '''.format(pat_id=pat_id,
                   channel=self.channel,
                   model=self.model,
                   notify=self.notify_web)
                logging.info("suppression sql: {}".format(sql))
                await conn.fetch(sql)
                logging.info(
                    "generate suppression alert for {}".format(pat_id))
            else:
                logging.info("criteria is not ready for {}".format(pat_id))

    def garbage_collect_suppression_tasks(self, hosp):
        for task in self.suppression_tasks.get(hosp, []):
            task.cancel()
        self.suppression_tasks[hosp] = []

    async def alert_queue_consumer(self):
        '''
    Check message queue and process messages
    '''
        logging.info("alert_queue_consumer started")
        while True:
            msg = await self.alert_message_queue.get()
            logging.info("alert_message_queue recv msg: {}".format(msg))
            # Predictor finished
            if msg.get('type') == 'FIN':
                if self.model == 'lmc' or self.model == 'trews-jit':
                    if self.TREWS_ETL_SUPPRESSION == 1:
                        suppression_future = asyncio.ensure_future(
                            self.run_suppression(msg), loop=self.loop)
                    elif self.TREWS_ETL_SUPPRESSION == 2:
                        suppression_future = asyncio.ensure_future(
                            self.run_suppression_mode_2(msg), loop=self.loop)
                else:
                    logging.error("Unknown model: {}".format(self.model))
                # self.suppression_tasks[msg['hosp']].append(suppression_future)
                # logging.info("create {model} suppression task for {}".format(self.model,msg['hosp']))
        logging.info("alert_queue_consumer quit")

    async def run_suppression_mode_2(self, msg):
        t_fin = dt.datetime.now()
        # if msg['hosp']+msg['time'] in self.job_status:
        if msg['job_id'] in self.job_status:
            t_start = self.job_status[msg['job_id']]['t_start']
            self.cloudwatch_logger.push_many(
                dimension_name='AlertServer',
                metric_names=[
                    'prediction_time_{}{}'.format(
                        msg['hosp'], '_push' if self.push_based else ''),
                    'prediction_enc_cnt_in_{}{}'.format(
                        msg['hosp'], '_push' if self.push_based else ''),
                    'prediction_enc_cnt_out_{}{}'.format(
                        msg['hosp'], '_push' if self.push_based else '')
                ],
                metric_values=[(t_fin - t_start).total_seconds(),
                               len(msg['enc_ids']),
                               len(msg['predicted_enc_ids'])],
                metric_units=['Seconds', 'Count', 'Count'])
        logging.info("start to run suppression mode 2 for msg {}".format(msg))
        tsp = msg['time']
        enc_id_str = ','.join([str(i) for i in msg['enc_ids'] if i])
        hospital = msg['hosp']
        logging.info("received FIN for enc_ids: {}".format(enc_id_str))
        # calculate criteria here
        # NOTE: we trust the enc_ids from the FIN msg
        async with self.db_pool.acquire() as conn:
            if self.notify_web:
                if self.push_based:
                    job_id = msg['job_id']
                    await self.calculate_criteria_enc(conn, msg['enc_ids'])
                    sql = '''
          with pats as (
            select p.enc_id, p.pat_id from pat_enc p
            where p.enc_id in ({enc_ids})
          ),
          refreshed as (
            insert into refreshed_pats (refreshed_tsp, pats)
            select now(), jsonb_agg(pat_id) from pats
            returning id
          )
          select pg_notify('{channel}', 'invalidate_cache_batch:' || id || ':' || '{model}') from refreshed;
          '''.format(channel=self.channel,
                     model=self.model,
                     enc_ids=enc_id_str)
                else:
                    await self.calculate_criteria_hospital(conn, hospital)
                    sql = '''
          with pats as (
            select p.enc_id, p.pat_id from pat_enc p
            inner join get_latest_enc_ids('{hosp}') e on p.enc_id = e.enc_id
          ),
          refreshed as (
            insert into refreshed_pats (refreshed_tsp, pats)
            select now(), jsonb_agg(pat_id) from pats
            returning id
          )
          select pg_notify('{channel}', 'invalidate_cache_batch:' || id || ':' || '{model}') from refreshed;
        '''.format(channel=self.channel,
                   model=self.model,
                   hosp=msg['hosp'])
                logging.info("trews alert sql: {}".format(sql))
                await conn.fetch(sql)
                logging.info("generated trews alert for {}".format(hospital))
        logging.info(
            "completed suppression mode 2 for msg {}".format(msg))
        t_end = dt.datetime.now()
        if msg['job_id'] in self.job_status:
            t_start = self.job_status[msg['job_id']]['t_start']
            self.cloudwatch_logger.push_many(
                dimension_name='AlertServer',
                metric_names=[
                    'e2e_time_{}{}'.format(msg['hosp'],
                                           '_push' if self.push_based else ''),
                    'criteria_time_{}{}'.format(
                        msg['hosp'], '_push' if self.push_based else ''),
                ],
                metric_values=[
                    (t_end - t_start).total_seconds(),
                    (t_end - t_fin).total_seconds(),
                ],
                metric_units=['Seconds', 'Seconds'])
            self.job_status.pop(msg['job_id'], None)

    async def run_suppression(self, msg):
        # Wait for Advance Criteria Snapshot to finish and then start generating notifications
        logging.info("start to run suppression for msg {}".format(msg))
        tsp = msg['time']
        pat_ids = await self.convert_enc_ids_to_pat_ids(msg['enc_ids'])
        pats_str = ','.join([str(i) for i in pat_ids])
        hospital = msg['hosp']
        logging.info("received FIN for enc_ids: {}".format(pats_str))

        async def criteria_ready(conn, enc_ids, tsp):
            '''
      criteria are ready when
      1. criteria has been updated after tsp
      2. there is no new data in criteria_meas within lookbackhours (ETL will not update criteria)
      '''
            sql = '''
      with pats as (
        select distinct enc_id from criteria where enc_id in ({enc_ids})
      ),
      updated_pats as (
        select distinct enc_id from criteria where enc_id in ({enc_ids}) and update_date >= '{tsp}'::timestamptz
      )
      SELECT * from pats except select * from updated_pats
      '''.format(enc_ids=enc_ids, tsp=tsp)
            cnt = await conn.fetch(sql)
            if cnt is None or len(cnt) == 0:
                logging.info("criteria is ready")
                return True
            else:
                logging.info("criteria is not ready ({})".format(len(cnt)))
                return False

        async with self.db_pool.acquire() as conn:
            n = 0
            N = 60

            logging.info("enter suppression task for {}".format(msg))
            while not await criteria_ready(conn, pats_str, tsp):
                await asyncio.sleep(10)
                n += 1
                logging.info("retry criteria_ready {} times for {}".format(
                    n, pats_str))
                if n >= N:
                    break
            if n < N:
                if self.notify_web:
                    sql = '''
          with pats as (
            select enc_id, pat_id from pat_enc where enc_id in ({pats})
          ),
          alerts as (
            select update_suppression_alert(enc_id, '{channel}', '{model}', 'false') from pats),
          refreshed as (
            insert into refreshed_pats (refreshed_tsp, pats)
            select now(), jsonb_agg(pat_id) from pats
            returning id
          )
          select pg_notify('{channel}', 'invalidate_cache_batch:' || id || ':' || '{model}') from refreshed;
          '''.format(channel=self.channel, model=self.model, pats=pats_str)
                else:
                    sql = '''
          with pats as (
            select enc_id from pat_enc where enc_id in ({pats})
          )
          select update_suppression_alert(enc_id, '{channel}', '{model}', 'false') from pats
          '''.format(channel=self.channel, model=self.model, pats=pats_str)
                logging.info("lmc suppression sql: {}".format(sql))
                await conn.fetch(sql)
                logging.info(
                    "generated suppression alert for {}".format(hospital))
            else:
                logging.info("criteria is not ready for {}".format(pats_str))

    async def distribute_calculate_criteria(self, conn, job_id):
        server = 'dev_db' if 'dev' in self.channel else 'prod_db'
        hospital = None
        if 'hcgh' in job_id:
            hospital = 'HCGH'
        elif 'bmc' in job_id:
            hospital = 'BMC'
        elif 'jhh' in job_id:
            hospital = 'JHH'
        else:
            logging.error("Invalid job id: {}".format(job_id))
        if hospital:
            sql = "select garbage_collection('{}');".format(hospital)
            logging.info("calculate_criteria sql: {}".format(sql))
            await conn.fetch(sql)
            sql = '''
      select distribute_advance_criteria_snapshot_for_job('{server}', {hours}, '{job_id}', {nprocs});
      '''.format(server=server,
                 hours=self.lookbackhours,
                 job_id=job_id,
                 nprocs=self.nprocs)
            logging.info("calculate_criteria sql: {}".format(sql))
            await conn.fetch(sql)

    async def distribute_calculate_criteria_hospital(self, conn, hospital):
        server = 'dev_db' if 'dev' in self.channel else 'prod_db'
        sql = "select garbage_collection('{}');".format(hospital)
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)
        sql = '''
    select distribute_advance_criteria_snapshot_for_online_hospital('{server}', '{hospital}', {nprocs});
    '''.format(server=server, hospital=hospital, nprocs=self.nprocs)
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)

    async def calculate_criteria_hospital(self, conn, hospital):
        sql = "select garbage_collection('{}', '{}');".format(
            hospital, self.workspace)
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)
        sql = '''
    select advance_criteria_snapshot(enc_id) from get_latest_enc_ids('{hospital}');
    '''.format(hospital=hospital)
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)

    async def calculate_criteria_enc(self, conn, enc_ids):
        sql = "select garbage_collection_enc_ids(array[{}],'{}')".format(
            ','.join([str(enc_id) for enc_id in enc_ids]), self.workspace)
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)
        sql = 'select advance_criteria_snapshot_enc_ids(array[{}])'.format(
            ','.join([str(enc_id) for enc_id in enc_ids]))
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)
        logging.info("complete calculate_criteria_enc")

    async def calculate_criteria_push(self, conn, job_id, excluded=''):  # excluded is interpolated into SQL, so default to an empty string rather than None
        sql = '''
    select garbage_collection(enc_id, '{workspace}')
    from (select distinct enc_id from {workspace}.cdm_t
          where job_id = '{job_id}' {where}) e;
    '''.format(workspace=self.workspace, job_id=job_id, where=excluded)
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)
        sql = '''
    select advance_criteria_snapshot_batch(
      'select distinct enc_id from {workspace}.cdm_t where job_id = ''{job_id}'' {where}'
    );
    '''.format(workspace=self.workspace, job_id=job_id, where=excluded)
        # sql = '''
        # select advance_criteria_snapshot(enc_id)
        # from (select distinct enc_id from {workspace}.cdm_t
        #       where job_id = '{job_id}' {where}) e;
        # '''.format(workspace=self.workspace, job_id=job_id, where=excluded)
        logging.info("calculate_criteria sql: {}".format(sql))
        await conn.fetch(sql)
        logging.info("complete calculate_criteria_enc")

    async def get_enc_ids_to_predict(self, job_id):
        async with self.db_pool.acquire() as conn:
            # rule to select predictable enc_ids:
            # 1. has changes delta twf > 0
            # 2. adult
            # 3. HCGH patients only
            sql = '''
        select distinct t.enc_id
        from {workspace}.cdm_t t
        inner join cdm_twf twf on t.enc_id = twf.enc_id
        left join {workspace}.{job_id}_cdm_twf twf_delta on twf_delta.enc_id = t.enc_id
        left join {workspace}.{job_id}_cdm_twf_del twf_del on twf_del.enc_id = t.enc_id
        inner join cdm_s s on twf.enc_id = s.enc_id
        inner join cdm_s s2 on twf.enc_id = s2.enc_id
        where s.fid = 'age' and s.value::float >= 18.0
        and s2.fid = 'hospital' and s2.value = 'HCGH'
        and job_id = '{job_id}' and (twf_delta.enc_id is not null or twf_del.enc_id is not null)
      '''.format(workspace=self.workspace, job_id=job_id)
            res = await conn.fetch(sql)
            predict_enc_ids = [row[0] for row in res]
            return predict_enc_ids

    async def run_trews_alert(self, job_id, hospital, excluded_enc_ids=None):
        async with self.db_pool.acquire() as conn:
            if self.push_based and hospital == 'PUSH':
                # calculate criteria here
                excluded = ''
                if excluded_enc_ids:
                    excluded = 'and enc_id not in ({})'.format(','.join(
                        [str(id) for id in excluded_enc_ids]))
                await self.calculate_criteria_push(conn,
                                                   job_id,
                                                   excluded=excluded)
                if self.notify_web:
                    sql = '''
          with pats as (
            select e.enc_id, p.pat_id from (
              select distinct enc_id from {workspace}.cdm_t
              where job_id = '{job_id}'
              {where}
            ) e
            inner join pat_enc p on e.enc_id = p.enc_id
          ),
          refreshed as (
            insert into refreshed_pats (refreshed_tsp, pats)
            select now(), jsonb_agg(pat_id) from pats
            returning id
          )
          select pg_notify('{channel}', 'invalidate_cache_batch:' || id || ':' || '{model}') from refreshed;
          '''.format(channel=self.channel,
                     model=self.model,
                     where=excluded,
                     workspace=self.workspace,
                     job_id=job_id)
                    logging.info("trews alert sql: {}".format(sql))
                    await conn.fetch(sql)
                    logging.info(
                        "generated trews alert for {} without prediction".
                        format(hospital))
            elif self.TREWS_ETL_SUPPRESSION == 2:
                # calculate criteria here
                await self.calculate_criteria_hospital(conn, hospital)
                if self.notify_web:
                    sql = '''
          with pats as (
            select e.enc_id, p.pat_id from get_latest_enc_ids('{hospital}') e inner join pat_enc p on e.enc_id = p.enc_id
          ),
          refreshed as (
            insert into refreshed_pats (refreshed_tsp, pats)
            select now(), jsonb_agg(pat_id) from pats
            returning id
          )
          select pg_notify('{channel}', 'invalidate_cache_batch:' || id || ':' || '{model}') from refreshed;
          '''.format(channel=self.channel, model=self.model, hospital=hospital)
                    logging.info("trews alert sql: {}".format(sql))
                    await conn.fetch(sql)
                    logging.info(
                        "generated trews alert for {}".format(hospital))
            elif self.TREWS_ETL_SUPPRESSION == 1:
                if self.notify_web:
                    sql = '''
          with pats as (
            select e.enc_id, p.pat_id from get_latest_enc_ids('{hospital}') e inner join pat_enc p on e.enc_id = p.enc_id
          ),
          alerts as (
            select update_suppression_alert(enc_id, '{channel}', '{model}', 'false') from pats),
          refreshed as (
            insert into refreshed_pats (refreshed_tsp, pats)
            select now(), jsonb_agg(pat_id) from pats
            returning id
          )
          select pg_notify('{channel}', 'invalidate_cache_batch:' || id || ':' || '{model}') from refreshed;
            '''.format(channel=self.channel,
                       model=self.model,
                       hospital=hospital)
                else:
                    sql = '''
          select update_suppression_alert(enc_id, '{channel}', '{model}', 'false') from
          (select distinct t.enc_id from cdm_t t
          inner join get_latest_enc_ids('{hospital}') h on h.enc_id = t.enc_id
          where now() - tsp < (select value::interval from parameters where name = 'lookbackhours')) sub;
            '''.format(channel=self.channel,
                       model=self.model,
                       hospital=hospital)
                logging.info("trews suppression sql: {}".format(sql))
                await conn.fetch(sql)
                logging.info(
                    "generate trews suppression alert for {}".format(hospital))

    async def connection_handler(self, reader, writer):
        ''' Alert server connection handler '''
        addr = writer.transport.get_extra_info('peername')
        sock = writer.transport.get_extra_info('socket')

        if not addr:
            logging.error(
                'Connection made without a valid remote address, (Timeout %s)'
                % str(sock.gettimeout()))
            return
        else:
            logging.debug('Connection from %s (Timeout %s)' %
                          (str(addr), str(sock.gettimeout())))

        # Get the message that started this callback function
        message = await protocol.read_message(reader, writer)
        logging.info("connection_handler: recv msg from {} type {}".format(
            message.get('from'), message.get('type')))
        if message.get('from') == 'predictor':
            return await self.predictor_manager.register(
                reader, writer, message)

        elif message.get('type') == 'ETL':
            self.cloudwatch_logger.push_many(
                dimension_name='AlertServer',
                metric_names=['etl_done_{}'.format(message['hosp'])],
                metric_values=[1],
                metric_units=['Count'])
            # self.job_status[message['hosp'] + message['time']] = {
            #   'msg': message, 't_start': dt.datetime.now()
            # }
            if self.model == 'lmc' or self.model == 'trews-jit':
                job_id_items = message['job_id'].split('_')
                t_start = parser.parse(job_id_items[-1] if len(job_id_items) ==
                                       4 else job_id_items[-2])
                if self.push_based:
                    # create predict task for predictor
                    predict_enc_ids = await self.get_enc_ids_to_predict(
                        message['job_id'])
                    if predict_enc_ids:
                        self.job_status[message['job_id']] = {
                            't_start': t_start
                        }
                        self.predictor_manager.cancel_predict_tasks(
                            job_id=message['job_id'])
                        self.predictor_manager.create_predict_tasks(
                            hosp=message['hosp'],
                            time=message['time'],
                            job_id=message['job_id'],
                            active_encids=predict_enc_ids)
                    else:
                        logging.info("predict_enc_ids is None or empty")
                    # create criteria update task for patients who do not need to predict
                    t_fin = dt.datetime.now()
                    await self.run_trews_alert(
                        message['job_id'],
                        message['hosp'],
                        excluded_enc_ids=predict_enc_ids)
                    t_end = dt.datetime.now()
                    self.cloudwatch_logger.push_many(
                        dimension_name='AlertServer',
                        metric_names=[
                            'e2e_time_{}{}'.format(
                                message['hosp'],
                                '_short' if self.push_based else ''),
                            'criteria_time_{}{}'.format(
                                message['hosp'],
                                '_short' if self.push_based else ''),
                        ],
                        metric_values=[
                            (t_end - t_start).total_seconds(),
                            (t_end - t_fin).total_seconds(),
                        ],
                        metric_units=['Seconds', 'Seconds'])
                elif message.get('hosp') in self.hospital_to_predict:
                    if self.model == 'lmc':
                        self.garbage_collect_suppression_tasks(message['hosp'])
                    self.job_status[message['job_id']] = {'t_start': t_start}
                    self.predictor_manager.cancel_predict_tasks(
                        job_id=message['hosp'])
                    self.predictor_manager.create_predict_tasks(
                        hosp=message['hosp'],
                        time=message['time'],
                        job_id=message['job_id'])
                else:
                    logging.info("skip prediction for msg: {}".format(message))
                    t_fin = dt.datetime.now()
                    await self.run_trews_alert(message['job_id'],
                                               message['hosp'])
                    t_end = dt.datetime.now()
                    self.cloudwatch_logger.push_many(
                        dimension_name='AlertServer',
                        metric_names=[
                            'e2e_time_{}'.format(message['hosp']),
                            'criteria_time_{}'.format(message['hosp']),
                        ],
                        metric_values=[
                            (t_end - t_start).total_seconds(),
                            (t_end - t_fin).total_seconds(),
                        ],
                        metric_units=['Seconds', 'Seconds'])
            elif self.model == 'trews':
                await self.run_trews_alert(message['job_id'], message['hosp'])
            else:
                logging.error("Unknown suppression model {}".format(
                    self.model))
        else:
            logging.error("Don't know how to process this message")

    def start(self):
        ''' Start the alert server and queue consumer '''
        self.loop.run_until_complete(self.async_init())
        consumer_future = asyncio.ensure_future(self.alert_queue_consumer())
        server_future = self.loop.run_until_complete(
            asyncio.start_server(self.connection_handler,
                                 self.alert_dns,
                                 self.alert_server_port,
                                 loop=self.loop))
        logging.info('Serving on {}'.format(
            server_future.sockets[0].getsockname()))
        # Run server until Ctrl+C is pressed
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            print("Exiting")
            consumer_future.cancel()
            # Close the server
            logging.info('received stop signal, cancelling tasks...')
            for task in asyncio.Task.all_tasks():
                task.cancel()
            logging.info('bye, exiting in a minute...')
            server_future.close()
            self.loop.run_until_complete(server_future.wait_closed())
            self.loop.stop()
        finally:
            self.loop.close()
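
A minimal launch sketch, assuming the environment variables read in __init__ (etl_channel, suppression_model, TREWS_ETL_SUPPRESSION, notify_web, ...) are set appropriately and that the project-local Database and Cloudwatch classes are available as in the class above; the port and bind address are the defaults shown in the constructor:

import asyncio

loop = asyncio.get_event_loop()
server = AlertServer(loop, alert_server_port=31000, alert_dns='0.0.0.0')
server.start()  # runs async_init(), the alert queue consumer, and the TCP server until Ctrl+C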
Code example #3
File: jhapi.py Project: cloud-cds/cds-stack
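# Standard-library and third-party imports assumed by this excerpt; project-local
# names (Cloudwatch, servers, flowsheet_ids, lab_procedure_ids, component_ids) come
# from elsewhere in the cds-stack repo.
import asyncio
import itertools
import logging
import random
import sys
import datetime as dt
from datetime import date
from time import sleep

import pandas as pd
import pytz
from aiohttp import ClientSession
from dateutil.parser import parse
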
class JHAPIConfig:
  def __init__(self, hospital, lookback_hours, jhapi_server, jhapi_id,
               jhapi_secret, lookback_days=None, op_lookback_days=None):
    if jhapi_server not in servers:
      raise ValueError("Incorrect server provided")
    if int(lookback_hours) > 72:
      raise ValueError("Lookback hours must be less than 72 hours")
    self.jhapi_server = jhapi_server
    self.server = servers[jhapi_server]
    self.hospital = hospital
    self.lookback_hours = int(lookback_hours)
    self.lookback_days = int(lookback_days) if lookback_days else int(int(lookback_hours)/24.0 + 1)
    self.op_lookback_days = op_lookback_days
    tomorrow = dt.datetime.now() + dt.timedelta(days=1)
    self.from_date = tomorrow.strftime('%Y-%m-%d')
    self.dateFrom = (tomorrow - dt.timedelta(days=self.lookback_days)).strftime('%Y-%m-%d')
    self.dateTo = tomorrow.strftime('%Y-%m-%d')
    self.headers = {
      'client_id': jhapi_id,
      'client_secret': jhapi_secret,
      'User-Agent': ''
    }
    self.cloudwatch_logger = Cloudwatch()



  def make_requests(self, ctxt, endpoint, payloads, http_method='GET', url_type=None, server_type='internal'):
    # Define variables
    server = self.server if server_type == 'internal' else servers['{}-{}'.format(self.jhapi_server, server_type)]
    url = "{}{}".format(server, endpoint)
    request_settings = self.generate_request_settings(http_method, url, payloads, url_type)
    semaphore = asyncio.Semaphore(ctxt.flags.JHAPI_SEMAPHORE, loop=ctxt.loop)
    base = ctxt.flags.JHAPI_BACKOFF_BASE
    max_backoff = ctxt.flags.JHAPI_BACKOFF_MAX
    session_attempts = ctxt.flags.JHAPI_ATTEMPTS_SESSION
    request_attempts = ctxt.flags.JHAPI_ATTEMPTS_REQUEST
    # Asynchronous task to make a single request
    async def fetch(session, sem, setting):
      success = 0
      error = 0
      for i in range(request_attempts):
        try:
          async with sem:
            async with session.request(**setting) as response:
              if response.status != 200:
                body = await response.text()
                logging.error("Status={}\tMessage={}\tRequest={}".format(response.status, body, setting))
                response = None
                error += 1
              else:
                response = await response.json()
                success += 1
              break
        except IOError as e:
          if i < request_attempts - 1 and e.errno not in (104,): # 104: connection reset by peer
            logging.error(e)
            wait_time = min(((base**i) + random.uniform(0, 1)), max_backoff)
            await asyncio.sleep(wait_time)  # back off without blocking the event loop
          else:
            logging.error("Request IOError: Request={}".format(setting))
            raise Exception("Failed to request URL {}".format(url))
        except Exception as e:
          logging.error("Request exception: Request={}".format(setting))
          if i < request_attempts - 1 and str(e) != 'Session is closed':
            logging.error(e)
            wait_time = min(((base**i) + random.uniform(0, 1)), max_backoff)
            await asyncio.sleep(wait_time)  # back off without blocking the event loop
          else:
            raise Exception("Failed to request URL {}".format(url))
      return response, i+1, success, error


    # Get the client session and create a task for each request
    async def run(request_settings, semaphore, loop):
      async with ClientSession(headers=self.headers, loop=ctxt.loop) as session:
        tasks = [asyncio.ensure_future(fetch(session, semaphore, setting),
                                       loop=loop) for setting in request_settings]
        return await asyncio.gather(*tasks)

    # Start the run task to make all requests
    for attempt in range(session_attempts):
      try:
        task = run(request_settings, semaphore, ctxt.loop)
        future = asyncio.ensure_future(task, loop=ctxt.loop)
        ctxt.loop.run_until_complete(future)
        break
      except Exception as e:
        if attempt < session_attempts - 1:
          logging.error("Session Error Caught for URL {}, retrying... {} times".format(url, attempt+1))
          logging.exception(e)
          wait_time = min(((base**attempt) + random.uniform(0, 1)), max_backoff)
          sleep(wait_time)
        else:
          raise Exception("Session failed for URL {}".format(url))

    # Push number of requests to cloudwatch
    logging.info("Made {} requests".format(sum(x[1] for x in future.result())))
    self.cloudwatch_logger.push(
      dimension_name = 'ETL',
      metric_name    = 'requests_made',
      value          = sum(x[1] for x in future.result()),
      unit           = 'Count'
    )
    label = self.hospital + '_' + endpoint.replace('/', '_') + '_' + http_method
    self.cloudwatch_logger.push_many(
      dimension_name  = 'ETL',
      metric_names    = ['{}_success'.format(label), '{}_error'.format(label), 'jh_api_request_success', 'jh_api_request_error'],
      metric_values   = [sum(x[2] for x in future.result()), sum(x[3] for x in future.result()), sum(x[2] for x in future.result()), sum(x[3] for x in future.result())],
      metric_units    = ['Count','Count','Count','Count']
    )
    # Return responses
    return [x[0] for x in future.result()]



  def generate_request_settings(self, http_method, url, payloads=None, url_type=None):
    request_settings = []
    if url_type == 'rest' and http_method == 'GET':
      for payload in payloads:
        u = url + payload
        request_settings.append({'method': http_method,'url': u})
    else:
      for payload in payloads:
        setting = {
          'method': http_method,
          'url': url
        }
        if payload is not None:
          key = 'params' if http_method == 'GET' else 'json'
          setting[key] = payload
        request_settings.append(setting)

    return request_settings
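
  # Illustrative (hypothetical host and IDs) inputs and outputs for the two branches above:
  #   generate_request_settings('GET', 'https://host/patients/mrn/', ['E100'], url_type='rest')
  #     -> [{'method': 'GET', 'url': 'https://host/patients/mrn/E100'}]
  #   generate_request_settings('POST', 'https://host/patients/flowsheetrows', [{'PatientID': 'E100'}])
  #     -> [{'method': 'POST', 'url': 'https://host/patients/flowsheetrows', 'json': {'PatientID': 'E100'}}]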


  def extract_bedded_patients(self, ctxt, hospital, limit=None):
    resource = '/facilities/hospital/' + self.hospital + '/beddedpatients'
    responses = self.make_requests(ctxt, resource, [None], 'GET')
    if limit:
      ctxt.log.info("max_num_pats = {}".format(limit))
    df = pd.DataFrame(responses[0]).head(limit) if limit else pd.DataFrame(responses[0])
    if df.empty:
      ctxt.log.error("No beddedpatients.")
      sys.exit()
    return df.assign(hospital = hospital)

  def extract_ed_patients(self, ctxt, hospital, limit=None):
    resource = '/facilities/hospital/' + self.hospital + '/edptntlist?eddept=ADULT'
    responses = self.make_requests(ctxt, resource, [None], 'GET')
    if limit:
      ctxt.log.info("max_num_pats = {}".format(limit))
    df = pd.DataFrame(responses[0]).head(limit) if limit else pd.DataFrame(responses[0])
    return df.assign(hospital = hospital)

  def combine(self, response_list, to_merge):
    if not isinstance(response_list, list):
      raise TypeError("First argument must be a list of responses")
    dfs = pd.DataFrame()
    for idx, df in enumerate(response_list):
      if not df.empty:
        dfs = pd.concat([dfs, df.assign(index_col=idx)])
    if dfs.empty:
      return dfs
    return pd.merge(dfs, to_merge, how='inner', left_on='index_col',
            right_index=True, sort=False).drop('index_col', axis=1)

  def extract_ed_patients_mrn(self, ctxt, ed_patients):
    resource = '/patients/mrn/'
    payloads = [row['pat_id'] for i, row in ed_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'GET', url_type='rest')
    def calculate_age(born):
      today = date.today()
      return today.year - born.year - ((today.month, today.day) < (born.month, born.day))

    for r in responses:
      pat_id = [pid["ID"] for pid in r[0]['IDs'] if pid['Type'] == 'EMRN'][0]
      sex = r[0]['Sex']
      gender = 0 if sex == 'Female' else 1
      dob = parse(r[0]["DateOfBirth"])
      age = calculate_age(dob)
      ed_patients.loc[ed_patients.pat_id == pat_id,'age'] = age
      ed_patients.loc[ed_patients.pat_id == pat_id,'gender'] = gender
    return ed_patients

  def extract_chiefcomplaint(self, ctxt, beddedpatients):
    resource = '/patients/getdata/chiefcomplaint'
    payloads = [{
      "ContactID": {
        "ID": pat['visit_id'],
        "Type": "CSN"
      },
      "DataFormat": None,
      "Items": [
        {
          "ItemNumber": "18100",
          "LineRange": {
            "From": 1,
            "To": 10
          }
        }
      ],
      "RecordID": {
        "ID": pat['pat_id'],
        "Type":"EMRN"
      }
    } for _, pat in beddedpatients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'POST', server_type='epic')
    for r in responses:
      if r:
        raw_items = r['Items'][0]
        new_items = '[' + ','.join(["{{\"reason\" : \"{}\"}}".format(reason) for reason in [item['Value'] for item in raw_items['Lines'] if item['LineNumber'] > 0]]) + ']'
        r['Items'] = new_items
        r['RecordIDs'] = None
        r['ContactIDs'] = [id for id in r['ContactIDs'] if id['Type'] == 'CSN']
    dfs = [pd.DataFrame(r) for r in responses]
    df = self.combine(dfs, beddedpatients[['pat_id', 'visit_id']])
    return df

  def extract_flowsheets(self, ctxt, bedded_patients):
    resource = '/patients/flowsheetrows'
    flowsheet_row_ids = []
    for fid, internal_id_list in flowsheet_ids:
      for internal_id in internal_id_list:
        flowsheet_row_ids.append({'ID': str(internal_id),
                      'Type': 'Internal'})
    payloads = [{
      'ContactID':        pat['visit_id'],
      'ContactIDType':    'CSN',
      'FlowsheetRowIDs':  flowsheet_row_ids,
      'LookbackHours':    self.lookback_hours,
      'PatientID':        pat['pat_id'],
      'PatientIDType':    'EMRN'
    } for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'POST')
    dfs = [pd.DataFrame(r) for r in responses]
    return self.combine(dfs, bedded_patients[['pat_id', 'visit_id']])


  def extract_active_procedures(self, ctxt, bedded_patients):
    resource = '/facilities/hospital/' + self.hospital + '/orders/activeprocedures'
    payloads = [{'csn': pat['visit_id']} for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'GET')
    dfs = [pd.DataFrame(r) for r in responses]
    return self.combine(dfs, bedded_patients[['pat_id', 'visit_id']])

  def extract_treatmentteam(self, ctxt, bedded_patients):
    resource = '/patients/treatmentteam'
    payloads = [{
      'id': pat['visit_id'],
      'idtype': 'csn'
    } for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'GET')
    dfs = [pd.DataFrame(r['TreatmentTeam'] if r else None) for r in responses]
    return self.combine(dfs, bedded_patients[['pat_id', 'visit_id']])

  def extract_lab_orders(self, ctxt, bedded_patients):
    resource = '/patients/labs/procedure'
    procedure_types = []
    for _, ids in lab_procedure_ids:
      procedure_types += ({'Type': 'INTERNAL', 'ID': str(x)} for x in ids)
    payloads = [{
      'Id':                   pat['pat_id'],
      'IdType':               'patient',
      'FromDate':             self.from_date,
      'MaxNumberOfResults':   200,
      'NumberDaysToLookBack': self.lookback_days,
      'ProcedureTypes':       procedure_types
    } for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'POST')
    dfs = [pd.DataFrame(r['ProcedureResults'] if r else None) for r in responses]
    return self.combine(dfs, bedded_patients[['pat_id', 'visit_id']])


  def extract_lab_results(self, ctxt, bedded_patients):
    resource = '/patients/labs/component'
    component_types = []
    for _, cidl in component_ids:
      component_types += ({'Type': 'INTERNAL', 'Value': str(x)} for x in cidl)
    payloads = [{
      'Id':                   pat['pat_id'],
      'IdType':               'patient',
      'FromDate':             self.from_date,
      'MaxNumberOfResults':   200,
      'NumberDaysToLookBack': self.lookback_days,
      'ComponentTypes':       component_types
    } for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'POST')
    dfs = [pd.DataFrame(r['ResultComponents'] if r else None) for r in responses]
    return self.combine(dfs, bedded_patients[['pat_id', 'visit_id']])


  def extract_loc_history(self, ctxt, bedded_patients):
    resource = '/patients/adtlocationhistory'
    payloads = [{
      'id': pat['visit_id'],
      'type': 'CSN'
    } for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'GET')
    dfs = [pd.DataFrame(r) for r in responses]
    return self.combine(dfs, bedded_patients[['pat_id', 'visit_id']])


  def extract_med_orders(self, ctxt, bedded_patients):
    resource = '/patients/medications'
    payloads = [{
      'id':           pat['pat_id'],
      'dayslookback': str(self.lookback_days),
      'searchtype':   'IP'
    } for _, pat in bedded_patients.iterrows()] + \
    [{
      'id':           pat['pat_id'],
      'dayslookback': str(self.op_lookback_days),
      'searchtype':   'OP'
    } for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'GET')
    dfs = [pd.DataFrame(r) for r in responses]
    half = len(dfs)//2
    med_ip = self.combine(dfs[:half], bedded_patients[['pat_id', 'visit_id']])
    med_op = self.combine(dfs[half:], bedded_patients[['pat_id', 'visit_id']])
    def remove_inpatient(val):
      if val is None or str(val) == 'nan':
        return val
      for item in val:
        if item == 'OrderMode' and val[item] == 'Outpatient':
          return val
      return {}
    med_op['MedicationOrders'] = med_op['MedicationOrders'].apply(remove_inpatient)
    return pd.concat([med_ip, med_op]).reset_index(drop=True)


  def extract_med_admin(self, ctxt, med_orders):
    if med_orders is None or med_orders.empty:
      return None
    else:
      resource = '/patients/medicationadministrationhistory'
      payloads = [{
        'ContactID':        order['visit_id'],
        'ContactIDType':    'CSN',
        'OrderIDs':         list(itertools.chain.from_iterable(order['ids'])),
        'PatientID':        order['pat_id']
      } for _, order in med_orders.iterrows()]
      responses = self.make_requests(ctxt, resource, payloads, 'POST')
      dfs = [pd.DataFrame(r) for r in responses]
      return self.combine(dfs, med_orders[['pat_id', 'visit_id']])


  def extract_notes(self, ctxt, bedded_patients):
    resource = '/patients/documents/list'
    payloads = [{
      'id'       : pat['pat_id'],
      'dateFrom' : self.dateFrom,
      'dateTo'   : self.dateTo
    } for _, pat in bedded_patients.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'GET')
    logging.info('#NOTES PAYLOADS: %s' % len(payloads))
    logging.info('#NOTES RESPONSES: %s' % len(responses))
    dfs = [pd.DataFrame(r['DocumentListData'] if r else None) for r in responses]
    df = self.combine(dfs, bedded_patients[['pat_id']])
    if not df.empty:
      not_empty_idx = df.Key.str.len() > 0
      df = df[not_empty_idx].reset_index()
    return df


  def extract_note_texts(self, ctxt, notes):
    if not notes.empty:
      resource = '/patients/documents/text'
      payloads = [{ 'key' : note['Key'] } for _, note in notes.iterrows()]
      responses = self.make_requests(ctxt, resource, payloads, 'GET')
      logging.info('#NOTE TEXTS PAYLOADS: %s' % len(payloads))
      logging.info('#NOTE TEXTS RESPONSES: %s' % len(responses))
      dfs = [
        pd.DataFrame([{'DocumentText': r['DocumentText']}] if r else None)
        for r in responses
      ]
      return self.combine(dfs, notes[['Key']])
    return pd.DataFrame()


  def extract_contacts(self, ctxt, pat_id_list):
    if not pat_id_list:
      return pd.DataFrame()
    resource = '/patients/contacts'
    pat_id_df = pd.DataFrame(pat_id_list)
    # Get rid of fake patients by filtering out incorrect pat_ids
    pat_id_df = pat_id_df[pat_id_df['pat_id'].str.contains('E.*')]
    payloads = [{
      'id'       : pat['visit_id'],
      'idtype'   : 'csn',
      'dateFrom' : self.dateFrom,
      'dateTo'   : self.dateTo,
    } for _, pat in pat_id_df.iterrows()]
    responses = self.make_requests(ctxt, resource, payloads, 'GET')
    response_dfs = [pd.DataFrame(r['Contacts'] if r else None) for r in responses]
    dfs = pd.concat(response_dfs)
    if dfs.empty:
      return None
    else:
      return pd.merge(pat_id_df, dfs, left_on='visit_id', right_on='CSN')


  def push_notifications(self, ctxt, notifications):
    if ctxt.flags.TREWS_ETL_EPIC_NOTIFICATIONS:
      logging.info("pushing notifications to epic")
      resource = '/patients/addflowsheetvalue'
      load_tz='US/Eastern'
      t_utc = dt.datetime.utcnow().replace(tzinfo=pytz.utc)
      current_time = str(t_utc.astimezone(pytz.timezone(load_tz)))
      payloads = [{
        'PatientID':            n['pat_id'],
        'ContactID':            n['visit_id'],
        'UserID':               'WSEPSIS',
        'FlowsheetID':          '9490',
        'Value':                n['count'],
        'InstantValueTaken':    current_time,
        'FlowsheetTemplateID':  '304700006',
      } for n in notifications]
      for payload in payloads:
        logging.info('%s NOTIFY %s %s %s' % (payload['InstantValueTaken'],
                                             payload['PatientID'],
                                             payload['ContactID'],
                                             payload['Value']))
      self.make_requests(ctxt, resource, payloads, 'POST')
      logging.info("pushed notifications to epic")
    else:
      logging.info("not pushing notifications to epic")