Ejemplo n.º 1
0
def inference(dataset, decays):
    """Fit an L1-penalized ``HawkesExpKern`` on ``dataset`` and pickle the result.

    The events returned by ``load(dataset)`` are converted with
    ``events_to_tick_events`` and used to fit the learner. The learner's
    score (``score()`` called without arguments) is printed, and the list
    ``[adjacency, baseline, decays, n]`` is pickled to
    ``./model/<dataset>.pickle``. The ``./model`` directory is created if it
    does not exist yet.

    Parameters
    ----------
    dataset : str
        Name passed to ``load``; also used as the stem of the output file.

    decays
        Forwarded unchanged as the ``decays`` argument of ``HawkesExpKern``
        and stored alongside the fitted parameters in the pickle.

    Returns
    -------
    None
    """
    # Local import: this block cannot see whether the module already
    # imports ``os`` at file level.
    import os

    # The first value returned by ``load`` is not used by this function.
    _, events, n = load(dataset)

    tick_events = events_to_tick_events(events, n)

    learner = HawkesExpKern(decays=decays, penalty="l1", solver="agd", C=1000, verbose=True)
    learner.fit(tick_events)
    influence_matrix = learner.adjacency
    baseline = learner.baseline
    print("score = {}".format(learner.score()))

    model_path = "./model/" + dataset + ".pickle"
    # Create the output directory on demand so that a completed fit is not
    # thrown away by a FileNotFoundError when the results are written.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    with open(model_path, "wb") as f:
        pickle.dump([influence_matrix, baseline, decays, n], f)
Ejemplo n.º 2
0
    def test_HawkesExpKern_score(self):
        """...Check HawkesExpKern.score before and after fitting

        ``score`` must raise ``ValueError`` when called with no events on a
        learner that was never fitted. After ``fit`` it is evaluated on the
        training events and on separate test events, each time with the
        fitted coefficients and with explicitly given ones.

        NOTE(review): the expected values are hard-coded, so this presumably
        relies on the NumPy random state being seeded outside this method --
        confirm. The order of the ``np.random`` draws below must not change.
        """
        n_nodes = 2
        n_realizations = 3

        def draw_events():
            # One list per realization, holding one increasing timestamp
            # array per node (node ``i`` gets ``4 + i`` timestamps).
            return [[np.cumsum(np.random.rand(4 + node))
                     for node in range(n_nodes)]
                    for _ in range(n_realizations)]

        train_events = draw_events()
        test_events = draw_events()

        learner = HawkesExpKern(self.decays)

        expected_error = \
            '^You must either call `fit` before `score` or provide events$'
        with self.assertRaisesRegex(ValueError, expected_error):
            learner.score()

        # Drawn before ``fit`` (baseline first, then adjacency) to keep the
        # sequence of random draws unchanged.
        given_coeffs = {
            "baseline": np.random.rand(n_nodes),
            "adjacency": np.random.rand(n_nodes, n_nodes),
        }

        learner.fit(train_events)

        # Training events, fitted coefficients.
        self.assertAlmostEqual(learner.score(), 2.0855840)

        # Training events, given coefficients.
        self.assertAlmostEqual(learner.score(**given_coeffs), 0.59502417)

        # Test events, fitted coefficients.
        self.assertAlmostEqual(learner.score(test_events), 1.6001762)

        # Test events, given coefficients.
        self.assertAlmostEqual(learner.score(test_events, **given_coeffs),
                               0.89322199)