def score_matches(self, m_new, current_day, rng=None):
        """ This function checks a new risk message against all previous messages, and assigns to the closest one in a brute force manner"""
        best_cluster = hash_to_cluster(m_new)
        best_message = None
        best_score = -1
        for i in range(current_day-3, current_day+1):
            for cluster_id, messages in self.clusters_by_day[i].items():
                for m_enc in messages:
                    obs_uid, risk, day, unobs_uid = decode_message(m_enc)
                    if m_new.uid == obs_uid and m_new.day == day:
                        best_cluster = cluster_id
                        best_message = m_enc
                        best_score = 3
                        break
                    elif compare_uids(m_new.uid, obs_uid, 1) and m_new.day - 1 == day and m_new.risk == risk and best_score < 2:
                        best_cluster = cluster_id
                        best_message = m_enc
                        best_score = 2
                    elif compare_uids(m_new.uid, obs_uid, 2) and m_new.day - 2 == day and best_score < 1:
                        best_cluster = cluster_id
                        best_message = m_enc
                        best_score = 1
                    elif compare_uids(m_new.uid, obs_uid, 3) and m_new.day - 3 == day and best_score < 0:
                        best_cluster = cluster_id
                        best_message = m_enc
                        best_score = 0
                    # messages that match nothing keep the hash_to_cluster default
                if best_score == 3:
                    break
            if best_score == 3:
                break
        # print(f"best_cluster: {best_cluster}, m_new: {m_new}, best_score: {best_score}")
        # print(self.clusters)

        if best_message:
            best_message = decode_message(best_message)
        return best_cluster, best_message, best_score
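# A minimal sanity check of the search window used above (the current_day
# value is hypothetical): score_matches only looks at today plus the three
# preceding days, matching the deepest uid rotation (3) that compare_uids
# is asked to bridge.
current_day = 12
assert list(range(current_day - 3, current_day + 1)) == [9, 10, 11, 12]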
Example #2
def messages_to_np(human):
    ms_enc = []
    for day, clusters in human.clusters.clusters_by_day.items():
        for cluster_id, messages in clusters.items():
            # TODO: take an average over the risks for that day
            if not messages:
                continue
            # each row: [cluster_id, risk of the first message, message count, day]
            ms_enc.append([
                cluster_id,
                decode_message(messages[0]).risk,
                len(messages), day
            ])
    return np.array(ms_enc)
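# A minimal sketch of the layout messages_to_np returns, one row per
# (day, cluster) pair; the values below are hypothetical:
import numpy as np

example = np.array([
    [0, 3, 2, 10],  # cluster 0: first-message risk 3, 2 messages, day 10
    [1, 7, 1, 10],  # cluster 1: first-message risk 7, 1 message,  day 10
    [0, 4, 1, 11],  # cluster 0 again on day 11
])
assert example.shape == (3, 4)  # columns: cluster_id, risk, count, day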
    def score_two_messages(self, update_message, risk_message):
        """ This function takes in two messages and scores how well they match"""
        # score 3: exact uid/day/risk match; scores 2..0: uid compatible after
        # 1..3 daily rotations with the matching day offset and equal risk;
        # score -1: no plausible match
        obs_uid, risk, day, unobs_uid = decode_message(risk_message)
        if update_message.uid == obs_uid and update_message.day == day and update_message.risk == risk:
            score = 3
        elif compare_uids(update_message.uid, obs_uid, 1) and update_message.day - 1 == day and update_message.risk == risk:
            score = 2
        elif compare_uids(update_message.uid, obs_uid, 2) and update_message.day - 2 == day and update_message.risk == risk:
            score = 1
        elif compare_uids(update_message.uid, obs_uid, 3) and update_message.day - 3 == day and update_message.risk == risk:
            score = 0
        else:
            score = -1
        return score
    def add_messages(self, messages, current_day, rng=None):
        """ This function clusters new messages by scoring them against old messages in a sort of naive nearest neighbors approach"""
        for message in messages:
            m_dec = decode_message(message)
            # score the new message against previous messages to find its cluster
            best_cluster, best_message, best_score = self.score_matches(m_dec, current_day, rng=rng)
            if best_score >= 0:
                cluster_id = best_cluster
            else:
                cluster_id = hash_to_cluster(m_dec)

            self.all_messages.append(message)
            self.clusters[cluster_id].append(message)
            self.add_to_clusters_by_day(cluster_id, m_dec.day, message)
    @classmethod
    def update_risk_encounters(cls, human, messages):
        """ This function updates an individual's risk based on the receipt of a new message"""
        for message in messages:
            # if you already have a positive test result, risk stays pinned at the maximum (log(1.) == 0)
            if human.risk == np.log(1.):
                human.risk = np.log(1.)
                return

            # if the encounter message indicates they had a positive test result, increment counter
            message = decode_message(message)
            if message.risk == 15:
                human.tested_positive_contact_count += 1

            # both branches below compute log(p0 + (1 - p0) * (1 - (1 - q)**k))
            # in a numerically stable way, where p0 is the population-level prior,
            # q the transmission probability, and k the positive-contact count
            init_population_level_risk = 0.01
            expo = (1 - RISK_TRANSMISSION_PROBA) ** human.tested_positive_contact_count
            tmp = (1. - init_population_level_risk) * (1. - expo)
            mask = tmp < init_population_level_risk

            if mask:
                # expand around the (larger) prior term
                human.risk = np.log(init_population_level_risk) + np.log1p(tmp / init_population_level_risk)
            else:
                # expand around the (larger) encounter term
                human.risk = np.log(1. - init_population_level_risk) + np.log1p(-expo) + np.log1p(init_population_level_risk / tmp)
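# A quick numerical check of the log-space identity used by
# update_risk_encounters: both branches compute
# log(p0 + (1 - p0) * (1 - (1 - q)**k)) stably. The p0, q, and k values
# below are hypothetical stand-ins.
import numpy as np

p0 = 0.01                  # init_population_level_risk
q = 0.05                   # stand-in for RISK_TRANSMISSION_PROBA
for k in (0, 1, 5, 20):    # tested_positive_contact_count
    expo = (1 - q) ** k
    tmp = (1 - p0) * (1 - expo)
    direct = np.log(p0 + tmp)
    if tmp < p0:
        stable = np.log(p0) + np.log1p(tmp / p0)
    else:
        stable = np.log(1 - p0) + np.log1p(-expo) + np.log1p(p0 / tmp)
    assert np.isclose(direct, stable)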
    def update_records(self, update_messages, human):
        if not update_messages:
            return self
        grouped_update_messages = self.group_by_received_at(update_messages)
        for received_at, update_messages in grouped_update_messages.items():

            # num days x num clusters
            cluster_cards = np.zeros((max(self.clusters_by_day.keys())+1,  max(self.clusters.keys())+1))
            update_cards = np.zeros((max(self.clusters_by_day.keys())+1, 1))

            # figure out the cardinality of each day's message set
            for day, clusters in self.clusters_by_day.items():
                for cluster_id, messages in clusters.items():
                    cluster_cards[day][cluster_id] = len(messages)

            for update_message in update_messages:
                update_cards[update_message.day] += 1

            # find the nearest cardinality cluster
            perfect_signatures = np.where((cluster_cards == update_cards).all(axis=0))[0]
            if len(perfect_signatures) == 0:
                # calculate the wasserstein distance between every signature
                scores = []
                for cluster_idx in range(cluster_cards.shape[1]):
                    scores.append(dist(cluster_cards[:, cluster_idx], update_cards.reshape(-1)))
                best_cluster = int(np.argmin(scores))

                # for each day
                for day in range(len(update_cards)):
                    cur_cardinality = int(cluster_cards[day, best_cluster])
                    target_cardinality = int(update_cards[day])

                    # if (and while) the cardinality is not what it should be, as determined by the update_messages
                    while cur_cardinality - target_cardinality != 0:
                        # print(f"day: {day}, cur_cardinality: {cur_cardinality}, target_cardinality: {target_cardinality}")
                        # if we need to remove messages from this cluster on this day,
                        if cur_cardinality > target_cardinality:
                            best_score = -1
                            best_message = None
                            new_cluster_id = None

                            # then for each message in that day/cluster,
                            for message in self.clusters_by_day[day][best_cluster]:
                                for cluster_id, messages in self.clusters_by_day[day].items():
                                    if cluster_id == best_cluster:
                                        continue

                                    # and for each alternative cluster on that day
                                    for candidate_cluster_message in messages:
                                        # check if it's a good cluster to move this message to
                                        score = self.score_two_messages(decode_message(candidate_cluster_message), message)
                                        if score > best_score or not best_message:
                                            best_score = score
                                            best_message = message
                                            new_cluster_id = cluster_id

                            # if there are no other clusters on that day make a new cluster
                            if not best_message:
                                best_message = message
                                message = decode_message(message)
                                new_cluster_id = hash_to_cluster(message)
                            best_message = decode_message(best_message)

                            # for the message which best fits another cluster, move it there
                            self.update_record(best_cluster, new_cluster_id, best_message, best_message)
                            cur_cardinality -= 1
                            # print(f"removing from cluster {best_cluster} to cluster {new_cluster_id} on day {day}")

                        #otherwise we need to add messages to this cluster/day
                        else:
                            # so look for messages which closely match our update messages, and add them
                            for update_message in update_messages:
                                if update_message.day == day:
                                    break
                            best_score = -2
                            best_message = None
                            old_cluster_id = None
                            for cluster_id, messages in self.clusters_by_day[day].items():
                                for message in messages:
                                    score = self.score_two_messages(update_message, message)
                                    if score > best_score and cluster_id != best_cluster:
                                        best_score = score
                                        best_message = message
                                        old_cluster_id = cluster_id

                            best_message = decode_message(best_message)
                            updated_message = Message(best_message.uid, update_message.new_risk, best_message.day, best_message.unobs_id)
                            # print(f"adding from cluster {old_cluster_id} to cluster {best_cluster} on day {day}")
                            self.update_record(old_cluster_id, best_cluster, best_message, updated_message)
                            cur_cardinality += 1
            else:
                best_cluster = self.score_clusters(update_messages, perfect_signatures)
            for update_message in update_messages:
                best_score = -1
                best_message = self.clusters_by_day[update_message.day][best_cluster][0]
                for risk_message in self.clusters_by_day[update_message.day][best_cluster]:
                    score = self.score_two_messages(update_message, risk_message)
                    if score > best_score:
                        best_score = score
                        best_message = risk_message
                best_message = decode_message(best_message)
                updated_message = Message(best_message.uid, update_message.new_risk, best_message.day, best_message.unobs_id)
                self.update_record(best_cluster, best_cluster, best_message, updated_message)
        return self
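# A toy illustration of the "perfect signature" test in update_records:
# columns are clusters, rows are days, and a cluster matches when its
# per-day message counts equal those of the update messages. Values are
# hypothetical.
import numpy as np

cluster_cards = np.array([[2, 0],
                          [1, 3]])  # shape: (num days, num clusters)
update_cards = np.array([[2],
                         [1]])      # shape: (num days, 1)
perfect = np.where((cluster_cards == update_cards).all(axis=0))[0]
assert list(perfect) == [0]         # only cluster 0's signature matches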
Example #7
def main(args=None):
    if args is None:
        args = parser.parse_args()
    rng = np.random.RandomState(args.seed)

    # create the plot dir if it doesn't already exist
    if args.plot_path and not os.path.isdir(args.plot_path):
        os.mkdir(args.plot_path)

    # joblib sometimes takes a string and sometimes an int
    if args.mp_batchsize == -1:
        mp_batchsize = "auto"
    else:
        mp_batchsize = args.mp_batchsize

    # iterate the logs and init people
    with zipfile.ZipFile(args.data_path, 'r') as zf:
        start_logs = pickle.load(zf.open(zf.namelist()[0], 'r'))
        end_logs = pickle.load(zf.open(zf.namelist()[-1], 'r'))
        start = start_logs[0]['time']
        end = end_logs[-1]['time']
        total_days = (end - start).days
        all_params = []
        for idx, pkl in enumerate(zf.namelist()):
            if idx > args.max_pickles:
                break
            all_params.append({
                "pkl_name": pkl,
                "start": start,
                "data_path": args.data_path
            })

    print("initializing humans from logs.")
    with Parallel(n_jobs=args.n_jobs,
                  batch_size=mp_batchsize,
                  backend=args.mp_backend,
                  verbose=10) as parallel:
        results = parallel(
            (delayed(init_humans)(params) for params in all_params))

    humans = defaultdict(list)
    all_possible_symptoms = set()
    for result in results:
        for human in result[0]:
            humans[human['name']].append(human)
        for symp in result[1]:
            all_possible_symptoms.add(symp)

    hd = {}
    for hid, human_splits in humans.items():
        merged_human = DummyHuman(name=human_splits[0]['name'])
        for human in human_splits:
            merged_human.merge(human)
        merged_human.uid = create_new_uid(rng)
        hd[hid] = merged_human

    # select the risk prediction model to embed in messaging protocol

    RiskModel = pick_risk_model(args.risk_model)

    with zipfile.ZipFile(args.data_path, 'r') as zf:
        start_pkl = zf.namelist()[0]
    for current_day in range(total_days):
        if args.max_num_days <= current_day:
            break

        print(f"day {current_day} of {total_days}")
        days_logs, start_pkl = get_days_worth_of_logs(args.data_path, start,
                                                      start_pkl, current_day)

        all_params = []
        for human in hd.values():
            encounters = days_logs[human.name]
            log_path = f'{os.path.dirname(args.data_path)}/daily_outputs/{current_day}/{human.name[6:]}/'
            all_params.append({
                "start": start,
                "current_day": current_day,
                "encounters": encounters,
                "rng": rng,
                "all_possible_symptoms": all_possible_symptoms,
                "human": human.__dict__,
                "save_training_data": args.save_training_data,
                "log_path": log_path,
                "random_clusters": args.random_clusters
            })
            # go about your day accruing encounters and clustering them
            for encounter in encounters:
                encounter_time = encounter['time']
                unobs = encounter['payload']['unobserved']
                encountered_human = hd[unobs['human2']['human_id']]
                message = encode_message(
                    encountered_human.cur_message(current_day, RiskModel))
                encountered_human.sent_messages[
                    str(unobs['human1']['human_id']) + "_" +
                    str(encounter_time)] = message
                human.messages.append(message)

                got_exposed = encounter['payload']['unobserved']['human1'][
                    'got_exposed']
                if got_exposed:
                    human.exposure_message = message

            # if the quantized risk changed today, send an update message for
            # every encounter from the last 14 days
            if RiskModel.quantize_risk(
                    human.start_risk) != RiskModel.quantize_risk(human.risk):
                sent_at = start + datetime.timedelta(
                    days=current_day, minutes=rng.randint(low=0, high=1440))
                for k, m in human.sent_messages.items():
                    message = decode_message(m)
                    if current_day - message.day < 14:
                        # add the update message to the receiver's inbox
                        update_message = encode_update_message(
                            human.cur_message_risk_update(
                                message.day, message.risk, sent_at, RiskModel))
                        hd[k.split("_")[0]].update_messages.append(
                            update_message)
                human.sent_messages = {}
            human.uid = update_uid(human.uid, rng)

        with Parallel(n_jobs=args.n_jobs,
                      batch_size=mp_batchsize,
                      backend=args.mp_backend,
                      verbose=10) as parallel:
            human_dicts = parallel(
                (delayed(proc_human)(params) for params in all_params))

        for human_dict in human_dicts:
            human = DummyHuman(name=human_dict['name']).merge(human_dict)
            hd[human.name] = human
        if args.plot_daily:
            now = start + datetime.timedelta(days=current_day)
            daily_risks = [(np.e**human.risk,
                            human.is_infectious(now)[0],
                            human.name) for human in hd.values()]
            hist_plot(daily_risks,
                      f"{args.plot_path}day_{str(current_day).zfill(3)}.png")

    # print out the clusters
    clusters = []
    for human in hd.values():
        clusters.append(dict(human.clusters.clusters))
    json.dump(clusters, open(args.cluster_path, 'w'))
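# The dumped JSON (consumed by the analysis snippet below) is a list with one
# entry per human, mapping cluster id (stringified by json.dump) to a list of
# encoded messages. The encoded strings here are hypothetical placeholders:
example_clusters = [
    {"0": ["enc_msg_a", "enc_msg_b"], "3": ["enc_msg_c"]},
    {"1": ["enc_msg_d"]},
]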
Example #8
# create the per-individual cluster output directory if it doesn't already exist
if not os.path.isdir(INDIVIDUAL_CLUSTER_PATH):
    os.mkdir(INDIVIDUAL_CLUSTER_PATH)

# load the cluster data
everyones_clustered_messages = json.load(open(CLUSTER_PATH, 'r'))

# gather some high level statistics about the clusters (how many groups, total and unique contacts)
all_groups = []
all_total_num_contacts = []
all_unique_people_contacted = []
for someones_clustered_messages in everyones_clustered_messages:
    groups = defaultdict(list)
    unique_people_contacted = set()
    total_num_contacts = 0
    for assignment, m_encs in someones_clustered_messages.items():
        for m_enc in m_encs:
            obs_uid, obs_risk, m_sent, unobs_uid = decode_message(m_enc)
            groups[assignment].append(unobs_uid)
            unique_people_contacted.add(unobs_uid)
            total_num_contacts += 1
    all_groups.append(dict(groups))
    all_unique_people_contacted.append(unique_people_contacted)
    all_total_num_contacts.append(total_num_contacts)

# count the number of people in each group
all_count_people_in_group = []
all_number_of_groups = [len(groups) for groups in all_groups]
for group in all_groups:
    count_people_in_group = []
    for g, ps in group.items():
        cnt = Counter()
        num_people_in_group = len(ps)