def score_matches(self, m_new, current_day, rng=None):
    """ This function checks a new risk message against all previous messages, and assigns it to the closest one in a brute force manner"""
    best_cluster = hash_to_cluster(m_new)
    best_message = None
    best_score = -1
    # look back over the last three days of clustered messages
    for i in range(current_day - 3, current_day + 1):
        for cluster_id, messages in self.clusters_by_day[i].items():
            for m_enc in messages:
                obs_uid, risk, day, unobs_uid = decode_message(m_enc)
                if m_new.uid == obs_uid and m_new.day == day:
                    # exact uid/day match: take it and stop searching
                    best_cluster = cluster_id
                    best_message = m_enc
                    best_score = 3
                    break
                elif compare_uids(m_new.uid, obs_uid, 1) and m_new.day - 1 == day and m_new.risk == risk:
                    best_cluster = cluster_id
                    best_message = m_enc
                    best_score = 2
                elif compare_uids(m_new.uid, obs_uid, 2) and m_new.day - 2 == day and best_score < 1:
                    best_cluster = cluster_id
                    best_message = m_enc
                    best_score = 1
                elif compare_uids(m_new.uid, obs_uid, 3) and m_new.day - 3 == day and best_score < 0:
                    best_cluster = cluster_id
                    best_message = m_enc
                    best_score = 0
                elif best_score < 0:
                    # no match yet: remember this message as a weak fallback, without
                    # clobbering any better match found earlier
                    best_cluster = cluster_id
                    best_message = m_enc
                    best_score = -1
            if best_score == 3:
                break
        if best_score == 3:
            break
    # print(f"best_cluster: {best_cluster}, m_new: {m_new}, best_score: {best_score}")
    # print(self.clusters)
    if best_message:
        best_message = decode_message(best_message)
    return best_cluster, best_message, best_score
def messages_to_np(human):
    ms_enc = []
    for day, clusters in human.clusters.clusters_by_day.items():
        for cluster_id, messages in clusters.items():
            # TODO: take an average over the risks for that day
            if not any(messages):
                continue
            ms_enc.append([cluster_id, decode_message(messages[0]).risk, len(messages), day])
    return np.array(ms_enc)
def score_two_messages(self, update_message, risk_message):
    """ This function takes in two messages and scores how well they match"""
    obs_uid, risk, day, unobs_uid = decode_message(risk_message)
    if update_message.uid == obs_uid and update_message.day == day and update_message.risk == risk:
        score = 3
    elif compare_uids(update_message.uid, obs_uid, 1) and update_message.day - 1 == day and update_message.risk == risk:
        score = 2
    elif compare_uids(update_message.uid, obs_uid, 2) and update_message.day - 2 == day and update_message.risk == risk:
        score = 1
    elif compare_uids(update_message.uid, obs_uid, 3) and update_message.day - 3 == day and update_message.risk == risk:
        score = 0
    else:
        score = -1
    return score
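# The tiers in score_two_messages can be summarized in isolation. The helper below is a
# standalone illustration, not part of the module: it assumes compare_uids(a, b, k) succeeds
# exactly when the observed uid could have evolved from the update message's uid over k
# daily rotations, which is how it is used above.
def sketch_match_tier(uid_rotations_match, day_gap, risks_equal):
    """Illustrative only: reproduce the scoring tiers of score_two_messages.

    uid_rotations_match: whether compare_uids(update.uid, obs_uid, day_gap) holds
                         (for day_gap == 0, whether the uids are identical)
    day_gap:             update_message.day - risk_message.day
    risks_equal:         update_message.risk == risk_message.risk
    """
    if not uid_rotations_match or not risks_equal:
        return -1           # no usable match
    if day_gap in (0, 1, 2, 3):
        return 3 - day_gap  # an exact match scores 3; older matches score 2, 1, 0
    return -1               # outside the three-day uid rotation window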
def add_messages(self, messages, current_day, rng=None):
    """ This function clusters new messages by scoring them against old messages in a sort of naive nearest neighbors approach"""
    for message in messages:
        m_dec = decode_message(message)
        # score the new message against previous messages and reuse the best cluster,
        # falling back to a fresh hashed cluster when nothing matches
        best_cluster, best_message, best_score = self.score_matches(m_dec, current_day, rng=rng)
        if best_score >= 0:
            cluster_id = best_cluster
        else:
            cluster_id = hash_to_cluster(m_dec)

        self.all_messages.append(message)
        self.clusters[cluster_id].append(message)
        self.add_to_clusters_by_day(cluster_id, m_dec.day, message)
def update_risk_encounters(cls, human, messages):
    """ This function updates an individual's risk based on the receipt of a new message"""
    for message in messages:
        # if you already have a positive test result, ya risky.
        if human.risk == np.log(1.):
            human.risk = np.log(1.)
            return

        # if the encounter message indicates they had a positive test result, increment counter
        message = decode_message(message)
        if message.risk == 15:
            human.tested_positive_contact_count += 1

    # combine the baseline population-level risk with the risk implied by the number of
    # positive-test contacts, working in log space for numerical stability
    init_population_level_risk = 0.01
    expo = (1 - RISK_TRANSMISSION_PROBA) ** human.tested_positive_contact_count
    tmp = (1. - init_population_level_risk) * (1. - expo)
    mask = tmp < init_population_level_risk
    if mask:
        human.risk = np.log(init_population_level_risk) + np.log1p(tmp / init_population_level_risk)
    else:
        human.risk = np.log(1. - init_population_level_risk) + np.log1p(-expo) + np.log1p(init_population_level_risk / tmp)
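# Both branches above are rearrangements of the same quantity,
# log(p0 + (1 - p0) * (1 - (1 - RISK_TRANSMISSION_PROBA) ** n)), split so that the log of
# the smaller term is taken directly and the larger one goes through log1p. A quick
# standalone check of that algebra, using an illustrative transmission probability rather
# than the repo's configured constant:
import numpy as np

_RISK_TRANSMISSION_PROBA = 0.05   # illustrative value only
_p0 = 0.01                        # init_population_level_risk

for _n in (0, 1, 3, 10):
    _expo = (1. - _RISK_TRANSMISSION_PROBA) ** _n
    _tmp = (1. - _p0) * (1. - _expo)
    _direct = np.log(_p0 + _tmp)
    if _tmp < _p0:
        _stable = np.log(_p0) + np.log1p(_tmp / _p0)
    else:
        _stable = np.log(1. - _p0) + np.log1p(-_expo) + np.log1p(_p0 / _tmp)
    assert np.isclose(_direct, _stable)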
def update_records(self, update_messages, human):
    if not update_messages:
        return self
    grouped_update_messages = self.group_by_received_at(update_messages)
    for received_at, update_messages in grouped_update_messages.items():

        # num days x num clusters
        cluster_cards = np.zeros((max(self.clusters_by_day.keys()) + 1, max(self.clusters.keys()) + 1))
        update_cards = np.zeros((max(self.clusters_by_day.keys()) + 1, 1))

        # figure out the cardinality of each day's message set
        for day, clusters in self.clusters_by_day.items():
            for cluster_id, messages in clusters.items():
                cluster_cards[day][cluster_id] = len(messages)
        for update_message in update_messages:
            update_cards[update_message.day] += 1

        # find the nearest cardinality cluster
        # (length check rather than any(): cluster id 0 is a valid perfect signature)
        perfect_signatures = np.where((cluster_cards == update_cards).all(axis=0))[0]
        if len(perfect_signatures) == 0:
            # calculate the wasserstein distance between every signature
            scores = []
            for cluster_idx in range(cluster_cards.shape[1]):
                scores.append(dist(cluster_cards[:, cluster_idx], update_cards.reshape(-1)))
            best_cluster = int(np.argmin(scores))

            # for each day
            for day in range(len(update_cards)):
                cur_cardinality = int(cluster_cards[day, best_cluster])
                target_cardinality = int(update_cards[day])

                # if (and while) the cardinality is not what it should be, as determined by the update_messages
                while cur_cardinality - target_cardinality != 0:
                    # print(f"day: {day}, cur_cardinality: {cur_cardinality}, target_cardinality: {target_cardinality}")
                    # if we need to remove messages from this cluster on this day,
                    if cur_cardinality > target_cardinality:
                        best_score = -1
                        best_message = None
                        new_cluster_id = None

                        # then for each message in that day/cluster,
                        for message in self.clusters_by_day[day][best_cluster]:
                            for cluster_id, messages in self.clusters_by_day[day].items():
                                if cluster_id == best_cluster:
                                    continue

                                # and for each alternative cluster on that day
                                for candidate_cluster_message in messages:
                                    # check if it's a good cluster to move this message to
                                    score = self.score_two_messages(decode_message(candidate_cluster_message), message)
                                    if score > best_score or not best_message:
                                        best_score = score
                                        best_message = message
                                        new_cluster_id = cluster_id

                        # if there are no other clusters on that day make a new cluster
                        if not best_message:
                            best_message = message
                            message = decode_message(message)
                            new_cluster_id = hash_to_cluster(message)

                        best_message = decode_message(best_message)

                        # for the message which best fits another cluster, move it there
                        self.update_record(best_cluster, new_cluster_id, best_message, best_message)
                        cur_cardinality -= 1
                        # print(f"removing from cluster {best_cluster} to cluster {new_cluster_id} on day {day}")

                    # otherwise we need to add messages to this cluster/day
                    else:
                        # so look for messages which closely match our update messages, and add them
                        for update_message in update_messages:
                            if update_message.day == day:
                                break
                        best_score = -2
                        best_message = None
                        old_cluster_id = None
                        for cluster_id, messages in self.clusters_by_day[day].items():
                            for message in messages:
                                score = self.score_two_messages(update_message, message)
                                if score > best_score and cluster_id != best_cluster:
                                    best_score = score
                                    best_message = message
                                    old_cluster_id = cluster_id
                        best_message = decode_message(best_message)
                        updated_message = Message(best_message.uid, update_message.new_risk, best_message.day, best_message.unobs_id)
                        # print(f"adding from cluster {old_cluster_id} to cluster {best_cluster} on day {day}")
                        self.update_record(old_cluster_id, best_cluster, best_message, updated_message)
                        cur_cardinality += 1
        else:
            best_cluster = self.score_clusters(update_messages, perfect_signatures)

        for update_message in update_messages:
            best_score = -1
            best_message = self.clusters_by_day[update_message.day][best_cluster][0]
            for risk_message in self.clusters_by_day[update_message.day][best_cluster]:
                score = self.score_two_messages(update_message, risk_message)
                if score > best_score:
                    best_message = risk_message
            best_message = decode_message(best_message)
            updated_message = Message(best_message.uid, update_message.new_risk, best_message.day, best_message.unobs_id)
            self.update_record(best_cluster, best_cluster, best_message, updated_message)
    return self
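# The cardinality matching in update_records pairs a batch of update messages with the
# cluster whose per-day message counts look most like the update batch. A standalone toy
# example of that idea; it assumes `dist` above is scipy's 1-D Wasserstein distance, which
# is an assumption about the helper, not something defined in this file:
import numpy as np
from scipy.stats import wasserstein_distance

# rows are days, columns are clusters; entries are message counts per day/cluster
cluster_cards = np.array([[2, 0, 1],
                          [1, 3, 0],
                          [0, 1, 5]])
# per-day counts of the incoming update messages
update_cards = np.array([2, 1, 0]).reshape(-1, 1)

# a "perfect signature" is a cluster whose counts match the update counts on every day
perfect_signatures = np.where((cluster_cards == update_cards).all(axis=0))[0]
print(perfect_signatures)  # -> [0]

# with no perfect match, fall back to the cluster with the closest count profile
scores = [wasserstein_distance(cluster_cards[:, c], update_cards.reshape(-1))
          for c in range(cluster_cards.shape[1])]
print(int(np.argmin(scores)))  # -> 0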
def main(args=None):
    if args is None:
        args = parser.parse_args()
    rng = np.random.RandomState(args.seed)

    # check that the plot_dir exists:
    if args.plot_path and not os.path.isdir(args.plot_path):
        os.mkdir(args.plot_path)

    # joblib sometimes takes a string and sometimes an int
    if args.mp_batchsize == -1:
        mp_batchsize = "auto"
    else:
        mp_batchsize = args.mp_batchsize

    # iterate the logs and init people
    with zipfile.ZipFile(args.data_path, 'r') as zf:
        start_logs = pickle.load(zf.open(zf.namelist()[0], 'r'))
        end_logs = pickle.load(zf.open(zf.namelist()[-1], 'r'))
        start = start_logs[0]['time']
        end = end_logs[-1]['time']
        total_days = (end - start).days
        all_params = []
        for idx, pkl in enumerate(zf.namelist()):
            if idx > args.max_pickles:
                break
            all_params.append({"pkl_name": pkl, "start": start, "data_path": args.data_path})

    print("initializing humans from logs.")
    with Parallel(n_jobs=args.n_jobs, batch_size=mp_batchsize, backend=args.mp_backend, verbose=10) as parallel:
        results = parallel((delayed(init_humans)(params) for params in all_params))

    humans = defaultdict(list)
    all_possible_symptoms = set()
    for result in results:
        for human in result[0]:
            humans[human['name']].append(human)
        for symp in result[1]:
            all_possible_symptoms.add(symp)

    hd = {}
    for hid, human_splits in humans.items():
        merged_human = DummyHuman(name=human_splits[0]['name'])
        for human in human_splits:
            merged_human.merge(human)
        merged_human.uid = create_new_uid(rng)
        hd[hid] = merged_human

    # select the risk prediction model to embed in messaging protocol
    RiskModel = pick_risk_model(args.risk_model)

    with zipfile.ZipFile(args.data_path, 'r') as zf:
        start_pkl = zf.namelist()[0]
        for current_day in range(total_days):
            if args.max_num_days <= current_day:
                break

            print(f"day {current_day} of {total_days}")
            days_logs, start_pkl = get_days_worth_of_logs(args.data_path, start, start_pkl, current_day)

            all_params = []
            for human in hd.values():
                encounters = days_logs[human.name]
                log_path = f'{os.path.dirname(args.data_path)}/daily_outputs/{current_day}/{human.name[6:]}/'
                all_params.append({
                    "start": start,
                    "current_day": current_day,
                    "encounters": encounters,
                    "rng": rng,
                    "all_possible_symptoms": all_possible_symptoms,
                    "human": human.__dict__,
                    "save_training_data": args.save_training_data,
                    "log_path": log_path,
                    "random_clusters": args.random_clusters,
                })

                # go about your day accruing encounters and clustering them
                for encounter in encounters:
                    encounter_time = encounter['time']
                    unobs = encounter['payload']['unobserved']
                    encountered_human = hd[unobs['human2']['human_id']]
                    message = encode_message(encountered_human.cur_message(current_day, RiskModel))
                    encountered_human.sent_messages[str(unobs['human1']['human_id']) + "_" + str(encounter_time)] = message
                    human.messages.append(message)

                    got_exposed = encounter['payload']['unobserved']['human1']['got_exposed']
                    if got_exposed:
                        human.exposure_message = message

                # if today's quantized risk differs from the risk at the start of the day,
                # send updated risk messages for encounters within the last 14 days
                if RiskModel.quantize_risk(human.start_risk) != RiskModel.quantize_risk(human.risk):
                    sent_at = start + datetime.timedelta(days=current_day, minutes=rng.randint(low=0, high=1440))
                    for k, m in human.sent_messages.items():
                        message = decode_message(m)
                        if current_day - message.day < 14:
                            # add the update message to the receiver's inbox
                            update_message = encode_update_message(human.cur_message_risk_update(message.day, message.risk, sent_at, RiskModel))
                            hd[k.split("_")[0]].update_messages.append(update_message)
                    human.sent_messages = {}

                # update this person's uid for the next day
                human.uid = update_uid(human.uid, rng)

            with Parallel(n_jobs=args.n_jobs, batch_size=mp_batchsize, backend=args.mp_backend, verbose=10) as parallel:
                human_dicts = parallel((delayed(proc_human)(params) for params in all_params))

            for human_dict in human_dicts:
                human = DummyHuman(name=human_dict['name']).merge(human_dict)
                hd[human.name] = human
            if args.plot_daily:
                daily_risks = [(np.e ** human.risk, human.is_infectious(encounter['time'])[0], human.name) for human in hd.values()]
                hist_plot(daily_risks, f"{args.plot_path}day_{str(current_day).zfill(3)}.png")

    # print out the clusters
    clusters = []
    for human in hd.values():
        clusters.append(dict(human.clusters.clusters))
    json.dump(clusters, open(args.cluster_path, 'w'))
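# In the daily loop above, each sent message is stored under the key
# "<receiver_id>_<encounter_time>", and when the sender's quantized risk changes the
# receiver is recovered with key.split("_")[0]. A toy check of that convention; the id and
# timestamp below are made up, and it assumes receiver ids contain no underscore:
import datetime

_receiver_id = "human:412"  # hypothetical id
_encounter_time = datetime.datetime(2020, 3, 1, 10, 30)
_key = str(_receiver_id) + "_" + str(_encounter_time)
assert _key.split("_")[0] == _receiver_id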
os.mkdir(INDIVIDUAL_CLUSTER_PATH)

# load the cluster data
everyones_clustered_messages = json.load(open(CLUSTER_PATH, 'r'))

# gather some high level statistics about the clusters (how many groups, total and unique contacts)
all_groups = []
all_total_num_contacts = []
all_unique_people_contacted = []
for someones_clustered_messages in everyones_clustered_messages:
    groups = defaultdict(list)
    unique_people_contacted = set()
    total_num_contacts = 0
    for assignment, m_encs in someones_clustered_messages.items():
        for m_enc in m_encs:
            obs_uid, obs_risk, m_sent, unobs_uid = decode_message(m_enc)
            groups[assignment].append(unobs_uid)
            unique_people_contacted.add(unobs_uid)
            total_num_contacts += 1
    all_groups.append(dict(groups))
    all_unique_people_contacted.append(unique_people_contacted)
    all_total_num_contacts.append(total_num_contacts)

# count the number of people in each group
all_count_people_in_group = []
all_number_of_groups = [len(groups) for groups in all_groups]
for group in all_groups:
    count_people_in_group = []
    for g, ps in group.items():
        cnt = Counter()
        num_people_in_group = len(ps)