def compute_pragmatic_speaker(self, literal_matrix, rationality=1.0, speaker_prior=False,
                              lm_logprobsf=None, return_diagnostics=False):
    """
    Normalize the literal log-prob matrix into a pragmatic speaker distribution.

    :param literal_matrix: [I, C] (num_distractor_images + 1, num_captions)
                           or [I, Vocab] (num_images, vocab_size)
    :param rationality: temperature applied to the listener scores
    :param speaker_prior: if True, add a speaker prior; without an LM we
                          default to re-adding the literal matrix
    :param lm_logprobsf: [I, Vocab] language-model log-probs (a grammar
                         weighting over previous tokens); all rows are identical
    :return: a re-weighted matrix [I, C/Vocab]
    """
    # Step 1: keep the literal speaker S0 (used as the prior fallback).
    s0 = literal_matrix.clone()

    # Step 2: column-normalize over images -> pragmatic listener L1.
    norm_const = logsumexp(literal_matrix, dim=0, keepdim=True)
    l1 = literal_matrix.clone() - norm_const

    # Step 3: apply the rationality (temperature) parameter.
    l1 *= rationality

    # Step 4: optionally add a speaker prior.
    if speaker_prior:
        # The prior must be an LM that shares this vocabulary.
        if lm_logprobsf is not None:
            s1 = l1 + lm_logprobsf[0]
        else:
            s1 = l1 + s0
    else:
        s1 = l1

    # Step 5: row-normalize over captions/vocab -> pragmatic speaker S1.
    norm_const = logsumexp(s1, dim=1, keepdim=True)
    s1 = s1 - norm_const

    if return_diagnostics:
        return s1, l1, s0
    return s1
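# --- Illustrative sketch (not part of the model) ---------------------------
# A minimal standalone version of the no-prior path above, handy for sanity
# checking the normalization order; the function name and the toy shapes in
# the usage comment are assumptions for illustration only.
def _sketch_pragmatic_speaker(literal_matrix, rationality=1.0):
    import torch
    # Column-normalize over images -> listener L1, tempered by rationality.
    l1 = (literal_matrix - torch.logsumexp(literal_matrix, dim=0, keepdim=True)) * rationality
    # Row-normalize over captions -> speaker S1.
    return l1 - torch.logsumexp(l1, dim=1, keepdim=True)

# Usage (toy 3-image x 4-caption literal matrix):
#   toy_s0 = torch.log_softmax(torch.randn(3, 4), dim=1)
#   toy_s1 = _sketch_pragmatic_speaker(toy_s0)
#   torch.exp(torch.logsumexp(toy_s1, dim=1))  # each entry should be ~1.0
# ----------------------------------------------------------------------------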
def compute_pragmatic_speaker_w_similarity(self, literal_matrix, num_similar_images,
                                           rationality=1.0, speaker_prior=False,
                                           lm_logprobsf=None, entropy_penalty_alpha=0.0,
                                           return_diagnostics=False):
    s0_mat = literal_matrix
    prior = s0_mat.clone()[0]

    # Column-normalize over images -> pragmatic listener L1.
    l1_mat = s0_mat - logsumexp(s0_mat, dim=0, keepdim=True)

    # Restrict L1 to the QuD cell (target + similar images) and renormalize
    # within the cell.
    same_cell_prob_mat = l1_mat[:num_similar_images + 1] - \
                         logsumexp(l1_mat[:num_similar_images + 1], dim=0)
    l1_qud_mat = same_cell_prob_mat.clone()

    # Utility 2: entropy of the within-cell distribution, shape (1, |V|).
    entropy = self.compute_entropy(same_cell_prob_mat, 0, keepdim=True)
    utility_2 = entropy

    # Utility 1: listener mass on the cell, shape [1, |V|].
    utility_1 = logsumexp(l1_mat[:num_similar_images + 1], dim=0, keepdim=True)

    # Mix the two utilities.
    utility = (1 - entropy_penalty_alpha) * utility_1 + entropy_penalty_alpha * utility_2

    # Apply rationality.
    s1 = utility * rationality

    if speaker_prior:
        if lm_logprobsf is None:
            s1 += prior
        else:
            s1 += lm_logprobsf[0]  # LM rows are all identical

    if return_diagnostics:
        # Return the RSA terms only; on the outside (Debugger) we re-assemble
        # snapshots of the computation:
        #   s0,  L1,  u1,  L1*, H,   u2 (=H), u1+u2, s1
        #   mat, mat, vec, mat, vec, vec,     vec,   vec
        return (s0_mat, l1_mat, utility_1, l1_qud_mat, entropy, utility_2,
                utility, s1 - logsumexp(s1, dim=1, keepdim=True))

    # Row-normalize S1 over the vocabulary.
    return s1 - logsumexp(s1, dim=1, keepdim=True)
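# `compute_entropy` is not shown in this section. A minimal sketch consistent
# with how it is called above (log-prob input, reduction dim, keepdim flag) is
# the standard H(p) = -sum_i p_i * log p_i; treat this as an assumption, not
# the repo's actual definition:
def _sketch_compute_entropy(logprob_mat, dim, keepdim=False):
    import torch
    # exp() recovers the probabilities for the p_i factor; result stays in nats.
    return -(torch.exp(logprob_mat) * logprob_mat).sum(dim=dim, keepdim=keepdim)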
def check_s1_row_normalized(self, s1_mat):
    rand_time_idx = np.random.randint(len(s1_mat))
    print("S1 - The following value should be 1:",
          torch.exp(logsumexp(s1_mat[rand_time_idx][0], dim=0)))
def check_l1_qud_column_normalized(self, l1_qud_mat):
    rand_time_idx = np.random.randint(len(l1_qud_mat))
    print("L1 QuD - The following value should be 1:",
          torch.exp(logsumexp(l1_qud_mat[rand_time_idx][:, 0], dim=0)))
def check_l1_column_stochastic(self, l1_mat):
    rand_time_idx = np.random.randint(len(l1_mat))
    print("L1 - The following value should be 1:",
          torch.exp(logsumexp(l1_mat[rand_time_idx][:, 0], dim=0)))
def check_s0_row_stochastic(self, s0_mat):
    rand_time_idx = np.random.randint(len(s0_mat))
    print("S0 - The following value should be 1:",
          torch.exp(logsumexp(s0_mat[rand_time_idx][0], dim=0)))
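# Hypothetical standalone exercise of the checks above, assuming each *_mat
# argument is a stack of per-timestep matrices (the [rand_time_idx] indexing
# suggests a leading time dimension). Shapes are illustrative only.
def _sketch_exercise_checks():
    import torch
    T, I, C = 5, 3, 4  # timesteps, images, captions
    l1_stack = torch.log_softmax(torch.randn(T, I, C), dim=1)  # column-stochastic
    s0_stack = torch.log_softmax(torch.randn(T, I, C), dim=2)  # row-stochastic
    t = 0
    print(torch.exp(torch.logsumexp(l1_stack[t][:, 0], dim=0)))  # ~1.0 (column mass)
    print(torch.exp(torch.logsumexp(s0_stack[t][0], dim=0)))     # ~1.0 (row mass)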