def joint_log_prob(field_prior: tfp.distributions.Distribution,
                   matchup_prior: tfp.distributions.Distribution,
                   n_matches,
                   obs_match_counts,
                   obs_match_wins,
                   candidate_field,
                   candidate_matchups):
    """Joint log probability for a candidate field and matchup distribution.

    The result is the sum of four terms:
      * ``field_prior`` evaluated at ``candidate_field``;
      * ``matchup_prior`` evaluated at ``candidate_matchups`` (summed);
      * the log likelihood of the observed match counts given the field;
      * the log likelihood of the observed wins given counts and matchups.

    Args:
        field_prior: prior distribution over the field.
        matchup_prior: prior distribution over the free matchup parameters.
        n_matches: forwarded to ``generate.rv_match_counts``.
        obs_match_counts: observed match counts; its leading dimension is
            used as the number of archetypes.
        obs_match_wins: observed win counts, converted to outcomes with
            ``generate.wins_to_outcomes``.
        candidate_field: field value to score.
        candidate_matchups: free matchup parameters; expanded into the full
            matchup matrix with ``generate.build_matchup_matrix``.

    Returns:
        The summed log probability.
    """
    n_archetypes = obs_match_counts.shape[0]

    # Prior terms. The full matchup matrix has n*n entries but only
    # n*(n-1)/2 are independent (M[i,j] = 1 - M[j,i], so M[i,i] = 0.5);
    # the prior is therefore evaluated on the free parameters only.
    log_p_field = field_prior.log_prob(candidate_field)
    log_p_matchups = tf.reduce_sum(matchup_prior.log_prob(candidate_matchups))

    # Match counts given the field. The observed counts form a matrix but the
    # distribution is over a flat vector, so the data is flattened first.
    counts_flat = tf.reshape(obs_match_counts, [-1])
    counts_dist = generate.rv_match_counts(candidate_field, n_matches)
    log_p_counts = tf.reduce_sum(counts_dist.log_prob(counts_flat))

    # Outcomes given counts and matchups. Only wins are observed, but the
    # distribution is over (wins, losses), derived here from wins and counts.
    outcomes = generate.wins_to_outcomes(obs_match_wins, obs_match_counts)
    full_matrix = generate.build_matchup_matrix(candidate_matchups,
                                                n_archetypes)
    outcome_dist = generate.rv_outcomes(obs_match_counts, full_matrix)
    log_p_wins = tf.reduce_sum(outcome_dist.log_prob(outcomes))

    return log_p_field + log_p_matchups + log_p_counts + log_p_wins
def add_vars(self, var_list, kernel_results):
    """Record sampled state and derived quantities in ``self.vars``.

    ``var_list`` is read positionally: index 0 is stored as ``"field"`` and
    index 1 as ``"matchups_free"``. The full matchup matrix is built from
    the free parameters. ``"ev"`` is that matrix multiplied by the field
    (as a column vector), reshaped back to the field's shape. Finally,
    ``kernel_results`` is stored unchanged.
    """
    store = self.vars
    field = var_list[0]
    matchups_free = var_list[1]

    store["field"] = field
    store["matchups_free"] = matchups_free

    matrix = generate.build_matchup_matrix(matchups_free, self.n_archetypes)
    store["matchup_matrix"] = matrix

    field_column = tf.expand_dims(field, -1)
    store["ev"] = tf.reshape(tf.linalg.matmul(matrix, field_column),
                             field.shape)

    store["kernel_results"] = kernel_results
def log_prob(candidate_field, candidate_matchups, candidate_wait_time):
    """Joint log probability for a candidate field distribution, matchup
    distribution, and wait time.

    Hypotheses are batched: the leading dimension of ``candidate_field`` is
    taken as the number of hypotheses. The total is the sum of:
      * priors on field, matchups and wait time;
      * pairing and record log likelihoods from the approximate
        deck-given-record distributions;
      * independent per-archetype win/loss/deck observations;
      * independent per-pairing binomial matchup observations.

    NOTE(review): ``self``, ``pairing_counts``, ``record_counts``,
    ``matchup_counts``, ``matchup_wins``, ``win_counts``, ``loss_counts`` and
    ``deck_counts`` are free variables, presumably captured from an enclosing
    function such as ``log_prob_fn`` — confirm.

    Returns:
        Summed log probability, presumably one value per hypothesis.
    """
    n_hyp = candidate_field.shape[0]

    # Prior terms.
    prior_field = self.unknown[0].prior().log_prob(candidate_field)
    prior_matchups = tf.reduce_sum(
        self.unknown[1].prior().log_prob(candidate_matchups), axis=1)
    prior_wait = tf.reduce_sum(
        self.unknown[2].prior().log_prob(candidate_wait_time), axis=1)
    log_prior = prior_field + prior_matchups + prior_wait

    # Full matchup matrices from the free parameters, and per-hypothesis
    # score matrices from the wait times.
    matchup_matrix = generate.build_matchup_matrix(
        candidate_matchups, self.n_archetypes)
    wait_rows = tf.reshape(candidate_wait_time, [n_hyp, -1])
    raw_scores = generate.build_score_matrix(
        wait_rows, self.p_find_base, self.p_score, self.n_rounds)
    score_matrix = tf.reshape(raw_scores,
                              (n_hyp, self.n_scores, self.n_scores))

    # Approximate deck distribution at each record, and of opponents paired
    # against each record, then score the observed pairings and records.
    p_deck_at_record, p_opp_deck_at_record = approximate_p_deck(
        self.n_rounds, self.n_archetypes, candidate_field, matchup_matrix,
        score_matrix)
    log_lik_pairings = pairings_log_prob(pairing_counts, p_opp_deck_at_record)
    log_lik_records = pairings_log_prob(record_counts, p_deck_at_record)

    # Per-archetype expected win rate against the field (c * d * 1), and the
    # resulting (c * d) log probabilities of winning, losing, and the deck.
    ev = tf.linalg.matmul(matchup_matrix, tf.expand_dims(candidate_field, -1))
    cd_shape = (n_hyp, self.n_archetypes)
    log_win = tf.reshape(tf.math.log(ev), cd_shape)
    log_loss = tf.reshape(tf.math.log(1.0 - ev), cd_shape)
    log_deck = tf.math.log(candidate_field)

    # Independent matchup observations: one binomial per pairing (c * d * d),
    # summed over both archetype axes.
    pairing_dist = tfp.distributions.Binomial(
        probs=matchup_matrix, total_count=matchup_counts)
    per_pairing = pairing_dist.log_prob(matchup_wins)
    per_row = tf.reduce_sum(per_pairing, axis=-1)
    log_lik_ind_match = tf.reduce_sum(per_row, axis=-1)

    # Independent record observations: weight the (c * d) log probabilities
    # by the observed (d) counts, then sum over archetypes.
    per_archetype = (log_win * win_counts
                     + log_loss * loss_counts
                     + log_deck * deck_counts)
    log_lik_independent = tf.reduce_sum(per_archetype, axis=-1)

    return (log_lik_pairings + log_lik_records + log_prior
            + log_lik_independent + log_lik_ind_match)
def test_ll(session):
    """Probe the league log-likelihood around a reference field and matchups.

    Builds the log-probability graph of ``models.LeagueModel(3, 2)`` on
    placeholder inputs and evaluates it at a reference field ``f`` and
    matchup matrix ``m``.

    It then runs two probes:
      * Matchups fixed: samples 1000 random fields from Dirichlet(1, 1, 1).
      * Field fixed: scores one hand-picked free-matchup vector plus 1000
        Beta(12, 12) samples.

    Any sample that beats the reference is printed. For each probe, the
    change in log-likelihood is scatter-plotted against one of:
      * cosine similarity to ``f`` (field probe);
      * Euclidean distance to the reference free matchup parameters
        (matchup probe).

    Args:
        session: TF1-style session used to evaluate the graph via
            ``session.run``.
    """
    field = tf.placeholder(dtype=tf.float32, shape=(3, ))
    matchups = tf.placeholder(dtype=tf.float32, shape=(3, 3))
    scores = tf.placeholder(dtype=tf.float32, shape=(5, 5))
    pairing_counts = tf.placeholder(dtype=tf.float32, shape=(3, 6))
    record_counts = tf.placeholder(dtype=tf.float32, shape=(3, 6))
    wait = tf.constant(100.0)
    league = models.LeagueModel(3, 2)
    ll = league.log_prob_fn(
        tf.transpose(pairing_counts), tf.transpose(record_counts))(
            field, generate.get_matchup_parameters(matchups), wait)

    # Reference inputs.
    m = np.array([[.5, .7, .2], [.3, .5, .6], [.8, .4, .5]])
    f = np.array([.4, .3, .3])
    s = np.eye(5)
    n = np.array([[40, 38, 43, 36, 42, 41],
                  [30, 27, 33, 28, 29, 34],
                  [30, 36, 25, 35, 30, 24]])
    p = np.zeros((3, 6))

    def eval_ll(matchup_matrix, field_value):
        # Evaluate ``ll`` varying only field and matchups; scores and the
        # observed counts stay at the reference values.
        return session.run(ll, {
            matchups: matchup_matrix,
            field: field_value,
            scores: s,
            pairing_counts: p,
            record_counts: n,
        })

    def plot_deltas(x_values, deltas, x_ref):
        # Scatter deltas vs. x with dashed reference lines. x/y must be
        # keywords: seaborn >= 0.12 rejects positional data arguments.
        ax = sns.scatterplot(x=x_values, y=deltas)
        ax.hlines(0.0, xmin=-1.0, xmax=1.0, linestyle='dashed')
        ax.vlines(x_ref, ymin=-1.0, ymax=1.0, linestyle='dashed')
        plt.show()

    ll_true = eval_ll(m, f)
    f_normalized = f / np.linalg.norm(f)
    mf = session.run(generate.get_matchup_parameters(matchups), {matchups: m})
    print(mf)
    print(ll_true)

    print("Hold matchups constant:")
    ll_modified_field = [ll_true]
    # The reference itself is included, with cosine similarity 1 to itself.
    sim_field = [f_normalized.dot(f_normalized)]
    for _ in range(1000):
        f_sample = np.random.dirichlet([1.0, 1.0, 1.0])
        ll_sample = eval_ll(m, f_sample)
        ll_modified_field.append(ll_sample)
        if ll_sample > ll_true:
            print(ll_sample, f_sample)
        sim_field.append(
            (f_sample / np.linalg.norm(f_sample)).dot(f_normalized))
    plot_deltas(np.array(sim_field),
                np.array(ll_modified_field) - ll_true, 1.0)

    print("Hold field constant:")
    m_free = tf.placeholder(dtype=tf.float32, shape=(3, ))
    mm = generate.build_matchup_matrix(m_free, 3)
    dist_matchup = []
    ll_modified_matchup = []
    results = []

    def try_matchups(free):
        # Expand ``free`` to a matchup matrix, score it with the reference
        # field, and record (ll, matrix, distance to mf).
        matrix = session.run(mm, {m_free: free})
        ll_value = eval_ll(matrix, f)
        dist = np.linalg.norm(free - mf)
        dist_matchup.append(dist)
        ll_modified_matchup.append(ll_value)
        results.append((ll_value, matrix, dist))
        return ll_value, matrix

    # Hand-picked free parameters, noted as "should be optimal".
    try_matchups(np.array([0.4, 0.8, 0.3]))
    for _ in range(1000):
        ll_value, matrix = try_matchups(np.random.beta(12, 12, 3))
        if ll_value > ll_true:
            print(ll_value, matrix)

    # Show the 10 lowest- and 10 highest-likelihood candidates.
    results.sort(key=lambda r: r[0])
    for r in results[:10]:
        print(r)
    print('...')
    for r in results[-10:]:
        print(r)
    plot_deltas(np.array(dist_matchup),
                np.array(ll_modified_matchup) - ll_true, 0.0)