Code example #1
    def update_record(self, old_cluster_id, new_cluster_id, message, updated_message):
        """Updates a message in all of the data structures and can move it to a different cluster."""
        old_m_enc = encode_message(message)
        new_m_enc = encode_message(updated_message)

        # remove the old encoding from the cluster, the flat message list, and the per-day index
        self.clusters[old_cluster_id].remove(old_m_enc)
        self.all_messages.remove(old_m_enc)
        self.clusters_by_day[message.day][old_cluster_id].remove(old_m_enc)

        # re-insert the updated encoding under the new cluster id
        self.clusters[new_cluster_id].append(new_m_enc)
        self.all_messages.append(new_m_enc)
        self.add_to_clusters_by_day(new_cluster_id, updated_message.day, new_m_enc)
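The examples in this collection lean on a small Message record plus encode/decode helpers defined elsewhere in the project. A minimal sketch of definitions consistent with how the examples call them; the field names and the tuple encoding here are assumptions for illustration, not the project's actual code:

from collections import namedtuple

# uid, risk, day, true sender id -- matching the positional arguments used throughout the tests
Message = namedtuple("Message", ["uid", "risk", "day", "unobs_id"])

def encode_message(message):
    # hypothetical encoding: a plain tuple is enough for the bookkeeping shown above
    return (message.uid, message.risk, message.day, message.unobs_id)

def decode_message(encoded):
    return Message(*encoded)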
Code example #2
    def test_add_message_to_cluster_same_cluster_run(self):
        """
        Tests that the add_message_to_cluster function adds messages with the same uid on the same day to the same cluster.
        """
        # seed the clusters with an existing message
        message = Message(0, 0, 0, "human:1")
        clusters = Clusters()
        clusters.add_messages([encode_message(message)], 0)

        # make new message
        new_message = Message(0, 0, 0, "human:1")
        # add message to clusters
        clusters.add_messages([encode_message(new_message)], 0)
        self.assertEqual(len(clusters), 1)
Code example #3
    def test_purge(self):
        """ Tests the purge functionality"""
        message1 = Message(0, 0, 0, "human:0")
        message2 = Message(15, 0, 1, "human:0")
        clusters = Clusters()
        clusters.add_messages([encode_message(message1)], 0)
        clusters.add_messages([encode_message(message2)], 0)

        clusters.purge(13)
        self.assertEqual(len(clusters), 2)
        clusters.purge(14)
        self.assertEqual(len(clusters), 1)
        clusters.purge(15)
        self.assertEqual(len(clusters), 0)
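The expected counts above are consistent with a 14-day retention window: the day-0 message survives purge(13), is dropped by purge(14), and the day-1 message is dropped by purge(15). A sketch of a purge that would satisfy this test, assuming len(Clusters) counts non-empty clusters; this is a reconstruction from the test, not the project's actual implementation, which would also have to keep all_messages and clusters_by_day in sync:

    def purge(self, current_day):
        # drop every stored message that is 14 or more days old, then delete empty clusters
        for cluster_id, messages in list(self.clusters.items()):
            kept = [m for m in messages if current_day - decode_message(m).day < 14]
            if kept:
                self.clusters[cluster_id] = kept
            else:
                del self.clusters[cluster_id]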
Code example #4
    def test_add_message_to_cluster_new_cluster_run(self):
        """
        Tests that messages with mutually exclusive uids on the same day end up in separate clusters
        """
        # seed the clusters with an existing message
        message = Message(0, 0, 0, "human:1")
        clusters = Clusters()
        clusters.add_messages([encode_message(message)], 0)

        # make new message
        new_message = Message(1, 0, 0, "human:1")
        # add message to clusters

        clusters.add_messages([encode_message(new_message)], 0)
        num_clusters = len(clusters)
        self.assertEqual(num_clusters, 2)
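Together with example #2, this pins down the routing behaviour of add_messages: a new message joins an existing cluster when it plausibly matches one, and opens a fresh cluster otherwise. A hypothetical sketch of that routing, reconstructed from these two tests, assuming self.clusters is a defaultdict(list) keyed by integer cluster ids and deferring the match decision to score_matches (whose expected behaviour is exercised in examples #5 to #8):

    def add_messages(self, encoded_messages, current_day):
        # hypothetical routing: join the best-scoring cluster, otherwise open a new one
        for encoded in encoded_messages:
            message = decode_message(encoded)
            best_cluster, best_score = None, -1
            if self.clusters:
                best_cluster, _, best_score = self.score_matches(message, current_day)
            if best_score <= 0:
                best_cluster = max(self.clusters, default=-1) + 1  # fresh cluster id
            self.clusters[best_cluster].append(encoded)
            self.all_messages.append(encoded)
            self.add_to_clusters_by_day(best_cluster, message.day, encoded)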
Code example #5
    def test_score_bad_match_same_day_run(self):
        """
        Tests messages with mutually exclusive uids on the same day are scored lowly
        """
        # uid, risk, day, true sender id
        current_day = 0
        message1 = Message(0, 0, current_day, "human:0")
        message2 = Message(1, 0, current_day, "human:1")
        clusters = Clusters()
        clusters.add_messages([encode_message(message1)], current_day)
        best_cluster, best_message, best_score = clusters.score_matches(
            message2, current_day)
        self.assertEqual(best_score, -1)
        self.assertEqual(best_message, message1)
Code example #6
    def test_score_bad_match_one_day_run(self):
        """
        Tests messages with mutually exclusive uids separated by a day are scored lowly
        """
        # uid, risk, day, true sender id
        message1 = Message(0, 0, 0, "human:1")
        message2 = Message(6, 0, 1, "human:1")
        clusters = Clusters()
        clusters.add_messages([encode_message(message1)], 0)
        best_cluster, best_message, best_score = clusters.score_matches(
            message2, 1)
        self.assertEqual(best_cluster, 0)
        self.assertEqual(best_message, message1)
        self.assertEqual(best_score, -1)
Code example #7
    def test_score_good_match_same_day_run(self):
        """
        Tests messages with the same uids on the same day are scored highly
        """
        # uid, risk, day, true sender id
        current_day = 0
        message1 = Message(0, 0, current_day, "human:1")
        message2 = Message(0, 0, current_day, "human:1")
        clusters = Clusters()
        clusters.add_messages([encode_message(message1)], current_day)
        best_cluster, best_message, best_score = clusters.score_matches(
            message2, current_day)
        self.assertEqual(best_cluster, 0)
        self.assertEqual(best_message, message1)
        self.assertEqual(best_score, 3)
Code example #8
    def test_score_good_match_one_day_run(self):
        """
        Tests that messages with similar uids on different days receive an intermediate score
        """
        # uid, risk, day, true sender id
        current_day = 0
        clusters = Clusters()
        message1 = Message(0, 0, 0, "human:1")
        clusters.add_messages([encode_message(message1)], current_day)
        message2 = Message(1, 0, 1, "human:1")

        best_cluster, best_message, best_score = clusters.score_matches(
            message2, 1)
        self.assertEqual(best_cluster, 0)
        self.assertEqual(best_message, message1)
        self.assertEqual(best_score, 2)
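Taken together, examples #5 to #8 pin down the scoring contract: an identical uid on the same day scores 3, a uid that could plausibly have evolved over one daily update scores 2, and incompatible uids score -1. A minimal sketch of a per-message scorer with that behaviour; the uid-evolution rule below (a 4-bit uid shifted left each day with a random bit appended) is an assumption reconstructed from the tests, not the project's actual rule:

def plausible_next_uids(uid):
    # assumed 4-bit rolling uid: shift left one bit and append either 0 or 1
    return {((uid << 1) & 0b1111) | bit for bit in (0, 1)}

def score_two_messages(candidate, stored):
    # hypothetical scoring rule reconstructed from the expected values above
    day_gap = candidate.day - stored.day
    if day_gap == 0 and candidate.uid == stored.uid:
        return 3    # identical uid on the same day: strong match
    if day_gap == 1 and candidate.uid in plausible_next_uids(stored.uid):
        return 2    # uid could have rolled forward by one day: medium match
    return -1       # mutually exclusive uids: reject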
Code example #9
def main(args=None):
    if args is None:
        args = parser.parse_args()
    rng = np.random.RandomState(args.seed)

    # check that the plot_dir exists:
    if args.plot_path and not os.path.isdir(args.plot_path):
        os.mkdir(args.plot_path)

    # joblib sometimes takes a string and sometimes an int
    if args.mp_batchsize == -1:
        mp_batchsize = "auto"
    else:
        mp_batchsize = args.mp_batchsize

    # iterate the logs and init people
    with zipfile.ZipFile(args.data_path, 'r') as zf:
        start_logs = pickle.load(zf.open(zf.namelist()[0], 'r'))
        end_logs = pickle.load(zf.open(zf.namelist()[-1], 'r'))
        start = start_logs[0]['time']
        end = end_logs[-1]['time']
        total_days = (end - start).days
        all_params = []
        for idx, pkl in enumerate(zf.namelist()):
            if idx > args.max_pickles:
                break
            all_params.append({
                "pkl_name": pkl,
                "start": start,
                "data_path": args.data_path
            })

    print("initializing humans from logs.")
    with Parallel(n_jobs=args.n_jobs,
                  batch_size=mp_batchsize,
                  backend=args.mp_backend,
                  verbose=10) as parallel:
        results = parallel(
            (delayed(init_humans)(params) for params in all_params))

    humans = defaultdict(list)
    all_possible_symptoms = set()
    for result in results:
        for human in result[0]:
            humans[human['name']].append(human)
        for symp in result[1]:
            all_possible_symptoms.add(symp)

    hd = {}
    for hid, human_splits in humans.items():
        merged_human = DummyHuman(name=human_splits[0]['name'])
        for human in human_splits:
            merged_human.merge(human)
        merged_human.uid = create_new_uid(rng)
        hd[hid] = merged_human

    # select the risk prediction model to embed in the messaging protocol
    RiskModel = pick_risk_model(args.risk_model)

    with zipfile.ZipFile(args.data_path, 'r') as zf:
        start_pkl = zf.namelist()[0]
    for current_day in range(total_days):
        if args.max_num_days <= current_day:
            break

        print(f"day {current_day} of {total_days}")
        days_logs, start_pkl = get_days_worth_of_logs(args.data_path, start,
                                                      start_pkl, current_day)

        all_params = []
        for human in hd.values():
            encounters = days_logs[human.name]
            log_path = f'{os.path.dirname(args.data_path)}/daily_outputs/{current_day}/{human.name[6:]}/'
            all_params.append({
                "start": start,
                "current_day": current_day,
                "encounters": encounters,
                "rng": rng,
                "all_possible_symptoms": all_possible_symptoms,
                "human": human.__dict__,
                "save_training_data": args.save_training_data,
                "log_path": log_path,
                "random_clusters": args.random_clusters
            })
            # go about your day accruing encounters and clustering them
            for encounter in encounters:
                encounter_time = encounter['time']
                unobs = encounter['payload']['unobserved']
                encountered_human = hd[unobs['human2']['human_id']]
                message = encode_message(
                    encountered_human.cur_message(current_day, RiskModel))
                encountered_human.sent_messages[
                    str(unobs['human1']['human_id']) + "_" +
                    str(encounter_time)] = message
                human.messages.append(message)

                got_exposed = encounter['payload']['unobserved']['human1'][
                    'got_exposed']
                if got_exposed:
                    human.exposure_message = message

            # if this human's quantized risk changed today, send risk update messages for encounters from the last 14 days
            if RiskModel.quantize_risk(
                    human.start_risk) != RiskModel.quantize_risk(human.risk):
                sent_at = start + datetime.timedelta(
                    days=current_day, minutes=rng.randint(low=0, high=1440))
                for k, m in human.sent_messages.items():
                    message = decode_message(m)
                    if current_day - message.day < 14:
                        # add the update message to the receiver's inbox
                        update_message = encode_update_message(
                            human.cur_message_risk_update(
                                message.day, message.risk, sent_at, RiskModel))
                        hd[k.split("_")[0]].update_messages.append(
                            update_message)
                human.sent_messages = {}
            human.uid = update_uid(human.uid, rng)

        with Parallel(n_jobs=args.n_jobs,
                      batch_size=mp_batchsize,
                      backend=args.mp_backend,
                      verbose=10) as parallel:
            human_dicts = parallel(
                (delayed(proc_human)(params) for params in all_params))

        for human_dict in human_dicts:
            human = DummyHuman(name=human_dict['name']).merge(human_dict)
            hd[human.name] = human
        if args.plot_daily:
            daily_risks = [(np.e**human.risk,
                            human.is_infectious(encounter['time'])[0],
                            human.name) for human in hd.values()]
            hist_plot(daily_risks,
                      f"{args.plot_path}day_{str(current_day).zfill(3)}.png")

    # collect each human's clusters and dump them to a JSON file
    clusters = []
    for human in hd.values():
        clusters.append(dict(human.clusters.clusters))
    json.dump(clusters, open(args.cluster_path, 'w'))
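main rolls each human's uid forward once per day via create_new_uid and update_uid. Under the same 4-bit rolling-uid assumption used in the scoring sketch above, those helpers could look like the following; this is a hypothetical sketch, not the project's actual implementation:

def create_new_uid(rng):
    # hypothetical: draw a fresh 4-bit uid
    return int(rng.randint(0, 16))

def update_uid(uid, rng):
    # hypothetical daily update: shift left within 4 bits and append one random bit
    return ((uid << 1) & 0b1111) | int(rng.randint(0, 2))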