def diff(self):
    """Compute posterior (expected) counts for every head->child arc and tag pair.

    Runs the outside pass, combines it with the cached inside pass, scatters the
    per-span log counts into a (batch, head, child, head_tag, child_tag) table,
    normalizes by the log partition function and exponentiates.

    Returns:
        torch.DoubleTensor of shape
        (batch_size, sentence_length, sentence_length, tag_num, tag_num)
        holding arc/tag posteriors.

    NOTE(review): assumes self.inside_table, self.crf_scores and
    self.partition_score were already populated by a prior inside pass --
    confirm caller ordering.
    """
    self.outside_table = self.batch_outside(self.inside_table, self.crf_scores)
    # Index 1 of each table tuple is the incomplete-span table; their sum is
    # the (log) unnormalized count of each dependency arc.
    counts = self.inside_table[1] + self.outside_table[1]
    pseudo_count = torch.DoubleTensor(self.batch_size, self.sentence_length,
                                      self.sentence_length, self.tag_num, self.tag_num)
    # log(0): positions with no corresponding span contribute zero after exp.
    pseudo_count.fill_(LOGZERO)
    if torch.cuda.is_available():
        pseudo_count = pseudo_count.cuda()
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        self.sentence_length, False)
    # Scatter span-indexed counts into (head, child) layout.
    for l in range(self.sentence_length):
        for r in range(self.sentence_length):
            for dir in range(2):
                span_id = span_2_id.get((l, r, dir))
                if span_id is not None:
                    if dir == 0:
                        # Left arc: head is r, child is l; permute swaps the two
                        # tag axes into head-then-child order.
                        pseudo_count[:, r, l, :, :] = counts[:, span_id, :, :].permute(
                            0, 2, 1)
                    else:
                        pseudo_count[:, l, r, :, :] = counts[:, span_id, :, :]
    # Normalize in log space by the per-sentence partition score, then exp.
    mius = pseudo_count - self.partition_score.contiguous().view(
        self.batch_size, 1, 1, 1, 1)
    diff = torch.exp(mius)
    # if self.mask is not None:
    #     diff = diff.masked_fill_(self.mask, 0.0)
    return diff
def batch_inside(batch_scores, batch_decision_score, valency_num, cvalency_num):
    """Batched inside pass of the tag/valency-augmented Eisner algorithm (log space).

    Args:
        batch_scores: arc scores; shape unpacked as
            (batch, sent_len, sent_len, tag_num, _, _) -- presumably
            (batch, head, child, head_tag, child_tag, cvalency); TODO confirm.
        batch_decision_score: stop/continue decision scores, indexed as
            [batch, position, tag, direction, valency, decision].
        valency_num: number of valency slots in the DP tables.
        cvalency_num: number of child-valency slots in the arc scores.

    Returns:
        (inside_complete_table, inside_incomplete_table, partition_score),
        where partition_score is the per-sentence log partition taken at the
        root span with tag 0 and valency 0.
    """
    batch_size, sentence_length, _, tag_num, _, _ = batch_scores.shape
    inside_complete_table = np.zeros((batch_size, sentence_length * sentence_length * 2, tag_num, valency_num))
    inside_incomplete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num, valency_num))
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length, False)
    # Initialize everything to log(0).
    inside_complete_table.fill(-np.inf)
    inside_incomplete_table.fill(-np.inf)
    # Basic single-word complete spans start from the STOP-decision score
    # (decision index 0) at valency 0.
    for ii in basic_span:
        (i, i, dir) = id_2_span[ii]
        inside_complete_table[:, ii, :, :] = batch_decision_score[:, i, :, dir, :, 0]
    for ij in ijss:
        (l, r, dir) = id_2_span[ij]
        # two complete span to form an incomplete span
        num_ki = len(ikis[ij])
        inside_ik_ci = inside_complete_table[:, ikis[ij], :, :].reshape(batch_size, num_ki, tag_num, 1, valency_num)
        inside_kj_ci = inside_complete_table[:, kjis[ij], :, :].reshape(batch_size, num_ki, 1, tag_num, valency_num)
        if dir == 0:
            # Head at r attaches child at l; arc score axes are swapped so the
            # tag axes follow left-right position rather than head-child.
            span_inside_i = inside_ik_ci[:, :, :, :, 0].reshape(batch_size, num_ki, tag_num, 1, 1) \
                + inside_kj_ci[:, :, :, :, 1].reshape(batch_size, num_ki, 1, tag_num, 1) \
                + batch_scores[:, r, l, :, :, :].swapaxes(2, 1).reshape(batch_size, 1, tag_num, tag_num, cvalency_num) \
                + batch_decision_score[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)
            # swap head-child to left-right position
        else:
            span_inside_i = inside_ik_ci[:, :, :, :, 1].reshape(batch_size, num_ki, tag_num, 1, 1) \
                + inside_kj_ci[:, :, :, :, 0].reshape(batch_size, num_ki, 1, tag_num, 1) \
                + batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) \
                + batch_decision_score[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)
        # Sum (in log space) over all split points k.
        inside_incomplete_table[:, ij, :, :, :] = logsumexp(span_inside_i, axis=1)
        # one complete span and one incomplete span to form bigger complete span
        num_kc = len(ikcs[ij])
        if dir == 0:
            inside_ik_cc = inside_complete_table[:, ikcs[ij], :, :].reshape(batch_size, num_kc, tag_num, 1, valency_num)
            inside_kj_ic = inside_incomplete_table[:, kjcs[ij], :, :, :].reshape(batch_size, num_kc, tag_num, tag_num, valency_num)
            span_inside_c = inside_ik_cc[:, :, :, :, 0].reshape(batch_size, num_kc, tag_num, 1, 1) + inside_kj_ic
            # Fold the split-point and inner-tag axes together before reducing.
            span_inside_c = span_inside_c.reshape(batch_size, num_kc * tag_num, tag_num, valency_num)
            inside_complete_table[:, ij, :, :] = logsumexp(span_inside_c, axis=1)
        else:
            inside_ik_ic = inside_incomplete_table[:, ikcs[ij], :, :, :].reshape(batch_size, num_kc, tag_num, tag_num, valency_num)
            inside_kj_cc = inside_complete_table[:, kjcs[ij], :, :].reshape(batch_size, num_kc, 1, tag_num, valency_num)
            span_inside_c = inside_ik_ic + inside_kj_cc[:, :, :, :, 0].reshape(batch_size, num_kc, 1, tag_num, 1)
            span_inside_c = span_inside_c.swapaxes(3, 2).reshape(batch_size, num_kc * tag_num, tag_num, valency_num)
            # swap the left-right position since the left tags are to be indexed
            inside_complete_table[:, ij, :, :] = logsumexp(span_inside_c, axis=1)
    # Root-covering right-directed complete span at tag 0, valency 0.
    final_id = span_2_id[(0, sentence_length - 1, 1)]
    partition_score = inside_complete_table[:, final_id, 0, 0]
    return inside_complete_table, inside_incomplete_table, partition_score
def batch_parse(batch_scores):
    """Batched first-order Eisner (projective) Viterbi decoding.

    Args:
        batch_scores: tensor of shape (batch, sent_len, sent_len) where
            batch_scores[:, h, m] is the score of the arc head h -> child m.

    Returns:
        heads: LongTensor (batch, sent_len) with the predicted head index of
        each token (-1 where no head is assigned, e.g. the root position).
    """
    batch_size, sentence_length, _ = batch_scores.shape
    # CYK table
    complete_table = torch.zeros(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.float)
    incomplete_table = torch.zeros(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.float)
    # backtrack table
    complete_backtrack = -torch.ones(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.int)
    incomplete_backtrack = -torch.ones(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.int)
    # span index table, to avoid redundant iterations
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        sentence_length, False)
    # initial basic complete spans
    for ii in basic_span:
        complete_table[:, ii] = 0.0
    for ij in ijss:
        (l, r, dir) = id_2_span[ij]
        num_ki = len(ikis[ij])
        ik_ci = complete_table[:, ikis[ij]].reshape(batch_size, num_ki)
        kj_ci = complete_table[:, kjis[ij]].reshape(batch_size, num_ki)
        # construct incomplete spans: two complete spans plus the arc score
        if dir == 0:
            span_i = ik_ci + kj_ci + batch_scores[:, r, l].reshape(
                batch_size, 1)
        else:
            span_i = ik_ci + kj_ci + batch_scores[:, l, r].reshape(
                batch_size, 1)
        # torch.max returns (values, indices); a single call replaces the
        # original duplicated reduction over the split points.
        max_val, max_idx = torch.max(span_i, dim=1)
        incomplete_table[:, ij] = max_val
        incomplete_backtrack[:, ij] = max_idx
        # construct complete spans: one complete plus one incomplete span
        num_kc = len(ikcs[ij])
        if dir == 0:
            ik_cc = complete_table[:, ikcs[ij]].reshape(batch_size, num_kc)
            kj_ic = incomplete_table[:, kjcs[ij]].reshape(batch_size, num_kc)
            span_c = ik_cc + kj_ic
        else:
            ik_ic = incomplete_table[:, ikcs[ij]].reshape(batch_size, num_kc)
            kj_cc = complete_table[:, kjcs[ij]].reshape(batch_size, num_kc)
            span_c = ik_ic + kj_cc
        max_val, max_idx = torch.max(span_c, dim=1)
        complete_table[:, ij] = max_val
        complete_backtrack[:, ij] = max_idx
    heads = -torch.ones((batch_size, sentence_length), dtype=torch.long)
    root_id = span_2_id[(0, sentence_length - 1, 1)]
    # Recover the tree for each sentence from the backtrack tables.
    for s in range(batch_size):
        batch_backtracking(incomplete_backtrack, complete_backtrack, root_id,
                           1, heads, ikcs, ikis, kjcs, kjis, id_2_span,
                           span_2_id, s)
    return heads
def diff(outside_table, inside_table, batch_size, sentence_length, tag_num, partition_score):
    """Compute posterior (expected) counts for every head->child arc and tag pair.

    Combines the incomplete-span inside and outside tables (index 1 of each
    table tuple), scatters them into a (batch, head, child, tag, tag) layout,
    normalizes by the log partition score and exponentiates.

    Returns:
        torch.DoubleTensor of shape
        (batch_size, sentence_length, sentence_length, tag_num, tag_num).
    """
    counts = inside_table[1] + outside_table[1]
    pseudo_count = torch.DoubleTensor(batch_size, sentence_length, sentence_length, tag_num, tag_num)
    # NOTE(review): the sibling method version of this routine fills with
    # LOGZERO here so that unset entries vanish after exp; this version uses
    # 0.0 -- confirm the intended behavior for spans that do not exist.
    pseudo_count.fill_(0.0)
    # Fix: pass the same second argument (False) as every other
    # constituent_index call site in this module.
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length, False)
    # Scatter span-indexed counts into (head, child) layout.
    for l in range(sentence_length):
        for r in range(sentence_length):
            for dir in range(2):
                span_id = span_2_id.get((l, r, dir))
                if span_id is not None:
                    if dir == 0:
                        pseudo_count[:, r, l, :, :] = counts[:, span_id, :, :]
                    else:
                        pseudo_count[:, l, r, :, :] = counts[:, span_id, :, :]
    # Normalize in log space by the per-sentence partition score, then exp.
    mius = pseudo_count - partition_score.contiguous().view(batch_size, 1, 1, 1, 1)
    diff = torch.exp(mius)
    # if mask is not None:
    #     diff = diff.masked_fill_(mask, 0.0)
    return diff
def batch_parse(batch_scores, batch_decision_score, valency_num, cvalency_num):
    """Batched Viterbi decoding of the tag/valency-augmented Eisner algorithm.

    Args:
        batch_scores: arc scores, shape unpacked as
            (batch, sent_len, sent_len, tag_num, _, _) -- presumably
            (batch, head, child, head_tag, child_tag, cvalency); TODO confirm.
        batch_decision_score: stop/continue decision scores indexed as
            [batch, position, tag, direction, valency, decision].
        valency_num: number of valency slots in the DP tables.
        cvalency_num: number of child-valency slots in the arc scores.

    Returns:
        (heads, tags, head_valences, valences) arrays for the best parse of
        each sentence in the batch.
    """
    batch_size, sentence_length, _, tag_num, _, _ = batch_scores.shape
    # CYK table
    complete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, valency_num))
    incomplete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num, valency_num))
    complete_table.fill(-np.inf)
    incomplete_table.fill(-np.inf)
    # backtrack table
    complete_backtrack = -np.ones(
        (batch_size, sentence_length * sentence_length * 2, tag_num, valency_num), dtype=int)
    incomplete_backtrack = -np.ones(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num, valency_num), dtype=int)
    # span index table, to avoid redundant iterations
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        sentence_length, False)
    # initial basic complete spans (STOP decision scores, decision index 0)
    for ii in basic_span:
        (i, i, dir) = id_2_span[ii]
        complete_table[:, ii, :, :] = batch_decision_score[:, i, :, dir, :, 0]
    for ij in ijss:
        (l, r, dir) = id_2_span[ij]
        num_ki = len(ikis[ij])
        ik_ci = complete_table[:, ikis[ij], :, :].reshape(batch_size, num_ki, tag_num, 1, valency_num)
        kj_ci = complete_table[:, kjis[ij], :, :].reshape(batch_size, num_ki, 1, tag_num, valency_num)
        # construct incomplete spans
        if dir == 0:
            span_i = ik_ci[:, :, :, :, 0].reshape(batch_size, num_ki, tag_num, 1, 1) \
                + kj_ci[:, :, :, :, 1].reshape(batch_size, num_ki, 1, tag_num, 1) + \
                batch_scores[:, r, l, :, :, :].swapaxes(1, 2).reshape(batch_size, 1, tag_num, tag_num, cvalency_num) \
                + batch_decision_score[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)
        else:
            span_i = ik_ci[:, :, :, :, 1].reshape(batch_size, num_ki, tag_num, 1, 1) \
                + kj_ci[:, :, :, :, 0].reshape(batch_size, num_ki, 1, tag_num, 1) + \
                batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) \
                + batch_decision_score[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)
        # Fix: the original assigned an unused `max = np.max(span_i, axis=1)`
        # (shadowing the builtin) and then recomputed the same reduction;
        # compute the maximum over split points exactly once.
        span_i_max = np.max(span_i, axis=1)
        incomplete_table[:, ij, :, :, :] = span_i_max
        max_idx = np.argmax(span_i, axis=1)
        incomplete_backtrack[:, ij, :, :, :] = max_idx
        # construct complete spans
        num_kc = len(ikcs[ij])
        if dir == 0:
            ik_cc = complete_table[:, ikcs[ij], :, :].reshape(
                batch_size, num_kc, tag_num, 1, valency_num)
            kj_ic = incomplete_table[:, kjcs[ij], :, :, :].reshape(
                batch_size, num_kc, tag_num, tag_num, valency_num)
            span_c = ik_cc[:, :, :, :, 0].reshape(batch_size, num_kc, tag_num, 1, 1) + kj_ic
            span_c = span_c.reshape(batch_size, num_kc * tag_num, tag_num, valency_num)
        else:
            ik_ic = incomplete_table[:, ikcs[ij], :, :, :].reshape(
                batch_size, num_kc, tag_num, tag_num, valency_num)
            kj_cc = complete_table[:, kjcs[ij], :, :].reshape(
                batch_size, num_kc, 1, tag_num, valency_num)
            span_c = ik_ic + kj_cc[:, :, :, :, 0].reshape(
                batch_size, num_kc, 1, tag_num, 1)
            # swap left-right so the left tags are indexed after the reshape
            span_c = span_c.swapaxes(2, 3).reshape(batch_size, num_kc * tag_num, tag_num, valency_num)
        complete_table[:, ij, :, :] = np.max(span_c, axis=1)
        max_idx = np.argmax(span_c, axis=1)
        complete_backtrack[:, ij, :, :] = max_idx
    tags = np.zeros((batch_size, sentence_length)).astype(int)
    heads = -np.ones((batch_size, sentence_length))
    head_valences = np.zeros((batch_size, sentence_length))
    valences = np.zeros((batch_size, sentence_length, 2))
    root_id = span_2_id[(0, sentence_length - 1, 1)]
    # Recover the best tree/tags/valences for each sentence.
    for s in range(batch_size):
        batch_backtracking(incomplete_backtrack, complete_backtrack, root_id,
                           0, 0, 0, 1, tags, heads, head_valences, valences,
                           ikcs, ikis, kjcs, kjis, id_2_span, span_2_id,
                           tag_num, s)
    return (heads, tags, head_valences, valences)
def batch_outside(inside_complete_table, inside_incomplete_table, batch_scores, batch_decision_scores, valency_num,
                  cvalency_num):
    """Batched outside pass of the tag/valency-augmented Eisner algorithm (log space).

    Walks the spans in reverse topological order, propagating outside scores
    from each parent span to its complete/incomplete children. Because a child
    span can receive contributions from several parents, the first write is a
    copy and subsequent writes accumulate via np.logaddexp; the *_used sets
    track which (span, valency) table cells have been written.

    Returns:
        (outside_complete_table, outside_incomplete_table) with the same
        shapes as the corresponding inside tables.
    """
    batch_size, sentence_length, _, tag_num, _, _ = batch_scores.shape
    outside_complete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, valency_num))
    outside_incomplete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num, valency_num))
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        sentence_length, False)
    # Initialize everything to log(0).
    outside_complete_table.fill(-np.inf)
    outside_incomplete_table.fill(-np.inf)
    # The root span (tag 0, valency 0) has outside score log(1) = 0.
    root_id = span_2_id.get((0, sentence_length - 1, 1))
    outside_complete_table[:, root_id, 0, 0] = 0.0
    complete_span_used_0 = set()
    complete_span_used_1 = set()
    incomplete_span_used = set()
    complete_span_used_0.add(root_id)
    for ij in reversed(ijss):
        (l, r, dir) = id_2_span[ij]
        # complete span consists of one incomplete span and one complete span
        num_kc = len(ikcs[ij])
        if dir == 0:
            outside_ij_cc = outside_complete_table[:, ij, :, :].reshape(
                batch_size, 1, 1, tag_num, valency_num)
            inside_kj_ic = inside_incomplete_table[:, kjcs[ij], :, :, :].reshape(
                batch_size, num_kc, tag_num, tag_num, valency_num)
            inside_ik_cc = inside_complete_table[:, ikcs[ij], :, :].reshape(
                batch_size, num_kc, tag_num, 1, valency_num)
            outside_ik_cc = (outside_ij_cc + inside_kj_ic).swapaxes(2, 3)
            # swap left-right position since right tags are to be indexed
            outside_kj_ic = outside_ij_cc + inside_ik_cc[:, :, :, :, 0].reshape(
                batch_size, num_kc, tag_num, 1, 1)
            for i in range(num_kc):
                ik = ikcs[ij][i]
                kj = kjcs[ij][i]
                # Reduce over the sibling tag and valency axes.
                outside_ik_cc_i = logsumexp(outside_ik_cc[:, i, :, :, :], axis=(1, 3))
                if ik in complete_span_used_0:
                    outside_complete_table[:, ik, :, 0] = np.logaddexp(
                        outside_complete_table[:, ik, :, 0], outside_ik_cc_i)
                else:
                    # First contribution: copy instead of accumulating onto -inf.
                    outside_complete_table[:, ik, :, 0] = np.copy(outside_ik_cc_i)
                    complete_span_used_0.add(ik)
                if kj in incomplete_span_used:
                    outside_incomplete_table[:, kj, :, :, :] = np.logaddexp(
                        outside_incomplete_table[:, kj, :, :, :], outside_kj_ic[:, i, :, :, :])
                else:
                    outside_incomplete_table[:, kj, :, :, :] = np.copy(
                        outside_kj_ic[:, i, :, :, :])
                    incomplete_span_used.add(kj)
        else:
            outside_ij_cc = outside_complete_table[:, ij, :, :].reshape(
                batch_size, 1, tag_num, 1, valency_num)
            inside_ik_ic = inside_incomplete_table[:, ikcs[ij], :, :, :].reshape(
                batch_size, num_kc, tag_num, tag_num, valency_num)
            inside_kj_cc = inside_complete_table[:, kjcs[ij], :, :].reshape(
                batch_size, num_kc, 1, tag_num, valency_num)
            outside_kj_cc = outside_ij_cc + inside_ik_ic
            outside_ik_ic = outside_ij_cc + inside_kj_cc[:, :, :, :, 0].reshape(
                batch_size, num_kc, 1, tag_num, 1)
            for i in range(num_kc):
                kj = kjcs[ij][i]
                ik = ikcs[ij][i]
                outside_kj_cc_i = logsumexp(outside_kj_cc[:, i, :, :, :], axis=(1, 3))
                if kj in complete_span_used_0:
                    outside_complete_table[:, kj, :, 0] = np.logaddexp(
                        outside_complete_table[:, kj, :, 0], outside_kj_cc_i)
                else:
                    outside_complete_table[:, kj, :, 0] = np.copy(outside_kj_cc_i)
                    complete_span_used_0.add(kj)
                if ik in incomplete_span_used:
                    outside_incomplete_table[:, ik, :, :, :] = np.logaddexp(
                        outside_incomplete_table[:, ik, :, :, :], outside_ik_ic[:, i, :, :, :])
                else:
                    outside_incomplete_table[:, ik, :, :, :] = np.copy(
                        outside_ik_ic[:, i, :, :, :])
                    incomplete_span_used.add(ik)
        # incomplete span consists of two complete spans
        num_ki = len(ikis[ij])
        outside_ij_ii = outside_incomplete_table[:, ij, :, :, :].reshape(
            batch_size, 1, tag_num, tag_num, valency_num)
        inside_ik_ci = inside_complete_table[:, ikis[ij], :, :].reshape(
            batch_size, num_ki, tag_num, 1, valency_num)
        inside_kj_ci = inside_complete_table[:, kjis[ij], :].reshape(
            batch_size, num_ki, 1, tag_num, valency_num)
        # The trailing _0/_1 suffixes name the valency slot being written.
        if dir == 0:
            outside_ik_ci_0 = outside_ij_ii + inside_kj_ci[:, :, :, :, 1].reshape(batch_size, num_ki, 1, tag_num, 1) + \
                batch_scores[:, r, l, :, :, :].swapaxes(1, 2).reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                batch_decision_scores[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)
            outside_kj_ci_1 = outside_ij_ii + inside_ik_ci[:, :, :, :, 0].reshape(batch_size, num_ki, tag_num, 1, 1) + \
                batch_scores[:, r, l, :, :, :].swapaxes(1, 2).reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                batch_decision_scores[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)
        else:
            outside_ik_ci_1 = outside_ij_ii + inside_kj_ci[:, :, :, :, 0].reshape(batch_size, num_ki, 1, tag_num, 1) \
                + batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                batch_decision_scores[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)
            outside_kj_ci_0 = outside_ij_ii + inside_ik_ci[:, :, :, :, 1].reshape(batch_size, num_ki, tag_num, 1, 1) + \
                batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                batch_decision_scores[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)
        for i in range(num_ki):
            ik = ikis[ij][i]
            kj = kjis[ij][i]
            if dir == 0:
                outside_ik_ci_i_0 = logsumexp(outside_ik_ci_0[:, i, :, :, :], axis=(2, 3))
                outside_kj_ci_i_1 = logsumexp(outside_kj_ci_1[:, i, :, :, :], axis=(1, 3))
            else:
                outside_ik_ci_i_1 = logsumexp(outside_ik_ci_1[:, i, :, :, :], axis=(2, 3))
                outside_kj_ci_i_0 = logsumexp(outside_kj_ci_0[:, i, :, :, :], axis=(1, 3))
            if dir == 0:
                if ik in complete_span_used_0:
                    outside_complete_table[:, ik, :, 0] = np.logaddexp(
                        outside_complete_table[:, ik, :, 0], outside_ik_ci_i_0)
                else:
                    outside_complete_table[:, ik, :, 0] = np.copy(outside_ik_ci_i_0)
                    complete_span_used_0.add(ik)
                if kj in complete_span_used_1:
                    outside_complete_table[:, kj, :, 1] = np.logaddexp(
                        outside_complete_table[:, kj, :, 1], outside_kj_ci_i_1)
                else:
                    outside_complete_table[:, kj, :, 1] = np.copy(outside_kj_ci_i_1)
                    complete_span_used_1.add(kj)
            else:
                if ik in complete_span_used_1:
                    outside_complete_table[:, ik, :, 1] = np.logaddexp(
                        outside_complete_table[:, ik, :, 1], outside_ik_ci_i_1)
                else:
                    outside_complete_table[:, ik, :, 1] = np.copy(outside_ik_ci_i_1)
                    complete_span_used_1.add(ik)
                if kj in complete_span_used_0:
                    outside_complete_table[:, kj, :, 0] = np.logaddexp(
                        outside_complete_table[:, kj, :, 0], outside_kj_ci_i_0)
                else:
                    outside_complete_table[:, kj, :, 0] = np.copy(outside_kj_ci_i_0)
                    complete_span_used_0.add(kj)
    return outside_complete_table, outside_incomplete_table
def update_pseudo_count(self, inside_incomplete_table, inside_complete_table, sentence_prob, outside_incomplete_table,
                        outside_complete_table, trans_counter, root_counter, decision_counter, batch_pos, batch_sen,
                        batch_lan):
    """Accumulate expected (pseudo) counts from inside/outside tables for EM.

    For every head/modifier pair, adds the arc posterior exp(inside + outside
    - log Z) to trans_counter (or root_counter for arcs from position 0) and
    to the CONTINUE slot of decision_counter; for every token, adds the STOP
    posterior from the basic complete spans. When self.use_neural is set, the
    same counts are appended to self.rule_samples / self.decision_samples as
    training examples.

    Returns:
        The summed log-likelihood (sum of sentence_prob) over the batch.

    NOTE(review): one_sentence_count / one_sentence_decision_count are stored
    into self.sentence_counter / self.sentence_decision_counter but are never
    appended to in this version -- confirm whether that is intentional.
    """
    batch_likelihood = 0.0
    batch_size, sentence_length = batch_pos.shape
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length, False)
    for s in range(batch_size):
        pos_sentence = batch_pos[s]
        sentence_id = batch_sen[s]
        language_id = batch_lan[s]
        one_sentence_count = []
        one_sentence_decision_count = []
        for h in range(sentence_length):
            for m in range(sentence_length):
                if m == 0:
                    continue  # position 0 (root) never takes a head
                if h == m:
                    continue
                if h > m:
                    dir = 0  # left arc
                else:
                    dir = 1  # right arc
                h_pos = pos_sentence[h]
                m_pos = pos_sentence[m]
                m_dec_pos = m_pos
                # Spans are keyed (left, right, dir).
                if dir == 0:
                    span_id = span_2_id[(m, h, dir)]
                else:
                    span_id = span_2_id[(h, m, dir)]
                # Log posterior of this arc: inside + outside - log Z.
                dep_count = inside_incomplete_table[s, span_id, :, :, :] + \
                    outside_incomplete_table[s, span_id, :, :, :] - sentence_prob[s]
                if dir == 0:
                    # swap axes back to head-then-modifier order
                    dep_count = dep_count.swapaxes(1, 0)
                if self.cvalency == 1:
                    if h == 0:
                        root_counter[m_pos] += np.sum(np.exp(dep_count))
                    else:
                        trans_counter[h_pos, m_pos, dir, 0] += np.sum(np.exp(dep_count))
                else:
                    if h == 0:
                        root_counter[m_pos] += np.sum(np.exp(dep_count))
                    else:
                        trans_counter[h_pos, m_pos, dir, :] += np.exp(dep_count).reshape(self.cvalency)
                if self.use_neural:
                    for v in range(self.cvalency):
                        count = np.exp(dep_count).reshape(self.cvalency)[v]
                        if not h == 0:
                            self.rule_samples.append(list([h_pos, m_pos, dir, v, sentence_id, language_id, count]))
                if h > 0:
                    h_dec_pos = h_pos
                    # CONTINUE decision (decision index 1) for the head.
                    # NOTE(review): the reshape uses self.cvalency while the
                    # loop below iterates self.dvalency -- confirm these are
                    # meant to coincide here.
                    decision_counter[h_dec_pos, dir, :, 1] += np.exp(dep_count).reshape(self.cvalency)
                    if self.use_neural:
                        reshaped_count = np.exp(dep_count).reshape(self.cvalency)
                        for v in range(self.dvalency):
                            count = reshaped_count[v]
                            if not self.unified_network:
                                self.decision_samples.append(
                                    list([h_dec_pos, dir, v, sentence_id, language_id, 1, count]))
                            else:
                                self.decision_samples.append(
                                    list([h_pos, dir, v, sentence_id, language_id, 1, count]))
        for m in range(1, sentence_length):
            m_pos = pos_sentence[m]
            m_dec_pos = m_pos
            for d in range(2):
                m_span_id = span_2_id[(m, m, d)]
                # STOP decision posterior from the basic complete span.
                stop_count = inside_complete_table[s, m_span_id, :, :] + \
                    outside_complete_table[s, m_span_id, :, :] - sentence_prob[s]
                decision_counter[m_dec_pos, d, :, 0] += np.exp(stop_count).reshape(self.cvalency)
                if self.use_neural:
                    for v in range(self.dvalency):
                        count = np.exp(stop_count).reshape(self.cvalency)[v]
                        if not self.unified_network:
                            self.decision_samples.append(
                                list([m_dec_pos, d, v, sentence_id, language_id, 0, count]))
                        else:
                            self.decision_samples.append(list([m_pos, d, v, 0, sentence_id, language_id, count]))
        batch_likelihood += sentence_prob[s]
        self.sentence_counter[sentence_id] = one_sentence_count
        self.sentence_decision_counter[sentence_id] = one_sentence_decision_count
    return batch_likelihood
def update_pseudo_count(self, inside_incomplete_table, inside_complete_table, sentence_prob, outside_incomplete_table,
                        outside_complete_table, trans_counter, decision_counter, lex_counter, batch_pos, batch_words):
    """Accumulate expected (pseudo) counts for the tagged, lexicalized model.

    Arc posteriors exp(inside + outside - log Z) are accumulated per
    (head_pos, mod_pos, head_tag, mod_tag, dir, valency) into trans_counter,
    summed appropriately into decision_counter (CONTINUE and STOP slots) and,
    when self.use_lex is set, into lex_counter per modifier word. With
    self.use_neural, per-count training samples are appended to
    self.rule_samples / self.decision_samples.

    Returns:
        The summed log-likelihood (sum of sentence_prob) over the batch.
    """
    batch_likelihood = 0.0
    batch_size, sentence_length = batch_pos.shape
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length, False)
    for sen_id in range(batch_size):
        pos_sentence = batch_pos[sen_id]
        word_sentence = batch_words[sen_id]
        for h in range(sentence_length):
            for m in range(sentence_length):
                if m == 0:
                    continue  # position 0 (root) never takes a head
                if h == m:
                    continue
                if h > m:
                    dir = 0  # left arc
                else:
                    dir = 1  # right arc
                h_pos = pos_sentence[h]
                m_pos = pos_sentence[m]
                m_dec_pos = self.to_decision[m_pos]
                m_word = word_sentence[m]
                # Spans are keyed (left, right, dir).
                if dir == 0:
                    span_id = span_2_id[(m, h, dir)]
                else:
                    span_id = span_2_id[(h, m, dir)]
                # Log posterior of this arc over (head_tag, mod_tag, valency).
                dep_count = inside_incomplete_table[sen_id, span_id, :, :, :] + \
                    outside_incomplete_table[sen_id, span_id, :, :, :] - sentence_prob[sen_id]
                if dir == 0:
                    # swap axes back to head-then-modifier order
                    dep_count = dep_count.swapaxes(1, 0)
                if self.cvalency == 1:
                    trans_counter[h_pos, m_pos, :, :, dir, 0] += np.sum(np.exp(dep_count), axis=2)
                else:
                    trans_counter[h_pos, m_pos, :, :, dir, :] += np.exp(dep_count)
                if self.use_neural:
                    for h_tag_id in range(self.tag_num):
                        for m_tag_id in range(self.tag_num):
                            for v in range(self.cvalency):
                                count = np.exp(dep_count[h_tag_id, m_tag_id, v])
                                self.rule_samples.append(
                                    list([h_pos, m_pos, h_tag_id, m_tag_id, dir, v, count]))
                if h > 0:
                    h_dec_pos = self.to_decision[h_pos]
                    # CONTINUE decision (decision index 1): sum over modifier tags.
                    decision_counter[h_dec_pos, :, dir, :, 1] += np.sum(np.exp(dep_count), axis=1)
                    if self.use_neural:
                        summed_count = np.sum(np.exp(dep_count), axis=1)
                        for h_tag_id in range(self.tag_num):
                            for v in range(self.dvalency):
                                count = summed_count[h_tag_id, v]
                                if not self.unified_network:
                                    self.decision_samples.append(list([h_dec_pos, h_tag_id, dir, v, 1, count]))
                                else:
                                    self.decision_samples.append(list([h_pos, h_tag_id, dir, v, 1, count]))
                if self.use_lex:
                    # Lexical count per modifier tag: sum over head tags and valency.
                    lex_counter[m_pos, :, m_word] += np.sum(np.exp(dep_count), axis=(0, 2))
        for m in range(1, sentence_length):
            m_pos = pos_sentence[m]
            m_dec_pos = self.to_decision[m_pos]
            m_word = word_sentence[m]
            for d in range(2):
                m_span_id = span_2_id[(m, m, d)]
                # STOP decision posterior from the basic complete span.
                stop_count = inside_complete_table[sen_id, m_span_id, :, :] + \
                    outside_complete_table[sen_id, m_span_id, :, :] - sentence_prob[sen_id]
                decision_counter[m_dec_pos, :, d, :, 0] += np.exp(stop_count)
                if self.use_neural:
                    for m_tag_id in range(self.tag_num):
                        for v in range(self.dvalency):
                            count = np.exp(stop_count[m_tag_id, v])
                            if not self.unified_network:
                                self.decision_samples.append(list([m_dec_pos, m_tag_id, d, v, 0, count]))
                            else:
                                self.decision_samples.append(list([m_pos, m_tag_id, d, v, 0, count]))
        batch_likelihood += sentence_prob[sen_id]
    return batch_likelihood
def update_pseudo_count(self, inside_incomplete_table, inside_complete_table, sentence_prob, outside_incomplete_table,
                        outside_complete_table, trans_counter, decision_counter, batch_pos, batch_sen, batch_lan):
    """Accumulate expected (pseudo) counts for the per-language model.

    Like the other update_pseudo_count variants, but counters carry a trailing
    language-id axis. Also tracks the log-likelihood of English sentences
    separately (via self.language_map).

    Returns:
        (batch_likelihood, en_like): total log-likelihood of the batch, and
        the portion contributed by sentences mapped to 'en'.
    """
    batch_likelihood = 0.0
    en_like = 0.0
    batch_size, sentence_length = batch_pos.shape
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        sentence_length, False)
    for s in range(batch_size):
        pos_sentence = batch_pos[s]
        sentence_id = batch_sen[s]
        lan_id = batch_lan[s]
        # NOTE(review): these lists are stored below (when sentence_predict)
        # but never appended to in this version -- confirm intent.
        one_sentence_count = []
        one_sentence_decision_count = []
        for h in range(sentence_length):
            for m in range(sentence_length):
                if m == 0:
                    continue  # position 0 (root) never takes a head
                if h == m:
                    continue
                if h > m:
                    dir = 0  # left arc
                else:
                    dir = 1  # right arc
                h_pos = pos_sentence[h]
                m_pos = pos_sentence[m]
                # Spans are keyed (left, right, dir).
                if dir == 0:
                    span_id = span_2_id[(m, h, dir)]
                else:
                    span_id = span_2_id[(h, m, dir)]
                # Pseudo count for one dependency arc
                dep_count = inside_incomplete_table[s, span_id, :, :, :] + \
                    outside_incomplete_table[s, span_id, :, :, :] - sentence_prob[s]
                if dir == 0:
                    # swap axes back to head-then-modifier order
                    dep_count = dep_count.swapaxes(1, 0)
                if self.cvalency == 1:
                    trans_counter[h_pos, m_pos, dir, 0, lan_id] += np.sum(np.exp(dep_count))
                else:
                    # NOTE(review): reshapes with self.dvalency inside the
                    # cvalency > 1 branch (other variants use self.cvalency
                    # here) -- confirm dvalency == cvalency in this model.
                    trans_counter[h_pos, m_pos, dir, :, lan_id] += np.exp(dep_count).reshape(
                        self.dvalency)
                # Add training samples for neural network
                if self.use_neural:
                    for v in range(self.cvalency):
                        count = np.exp(dep_count).reshape(self.dvalency)[v]
                        self.rule_samples.append(
                            list([h_pos, m_pos, dir, v, sentence_id, lan_id, count]))
                if h > 0:
                    # Add count for CONTINUE decision
                    decision_counter[h_pos, dir, :, 1, lan_id] += np.exp(dep_count).reshape(
                        self.dvalency)
                    if self.use_neural:
                        reshaped_count = np.exp(dep_count).reshape(self.dvalency)
                        for v in range(self.dvalency):
                            count = reshaped_count[v]
                            self.decision_samples.append(
                                list([h_pos, dir, v, sentence_id, lan_id, 1, count]))
        for m in range(1, sentence_length):
            m_pos = pos_sentence[m]
            for d in range(2):
                m_span_id = span_2_id[(m, m, d)]
                # Pseudo count for STOP decision
                stop_count = inside_complete_table[s, m_span_id, :, :] + \
                    outside_complete_table[s, m_span_id, :, :] - sentence_prob[s]
                decision_counter[m_pos, d, :, 0, lan_id] += np.exp(stop_count).reshape(
                    self.dvalency)
                if self.use_neural:
                    for v in range(self.dvalency):
                        count = np.exp(stop_count).reshape(self.dvalency)[v]
                        self.decision_samples.append(
                            list([m_pos, d, v, sentence_id, lan_id, 0, count]))
        batch_likelihood += sentence_prob[s]
        if self.language_map[sentence_id] == 'en':
            en_like += sentence_prob[s]
        if self.sentence_predict:
            self.sentence_counter[sentence_id] = one_sentence_count
            self.sentence_decision_counter[sentence_id] = one_sentence_decision_count
    return batch_likelihood, en_like
def batch_inside(self, crf_scores):
    """Batched inside pass of the tag-augmented Eisner algorithm (log space, torch).

    Args:
        crf_scores: arc/tag scores indexed [batch, head, child, head_tag,
            child_tag] (the head/child order is inferred from the permute in
            the dir == 0 branch -- TODO confirm).

    Returns:
        ((inside_complete_table, inside_incomplete_table), partition_score)
        where partition_score is the log partition at the root span, tag 0.
    """
    inside_complete_table = torch.DoubleTensor(
        self.batch_size, self.sentence_length * self.sentence_length * 2,
        self.tag_num)
    inside_incomplete_table = torch.DoubleTensor(
        self.batch_size, self.sentence_length * self.sentence_length * 2,
        self.tag_num, self.tag_num)
    if torch.cuda.is_available():
        inside_complete_table = inside_complete_table.cuda()
        inside_incomplete_table = inside_incomplete_table.cuda()
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        self.sentence_length, False)
    # Initialize everything to log(0).
    inside_complete_table.fill_(LOGZERO)
    inside_incomplete_table.fill_(LOGZERO)
    # Basic single-word complete spans have inside score log(1) = 0.
    for ii in basic_span:
        inside_complete_table[:, ii, :] = 0.0
    for ij in ijss:
        (l, r, dir) = id_2_span[ij]
        # two complete span to form an incomplete span
        num_ki = len(ikis[ij])
        inside_ik_ci = inside_complete_table[:, ikis[ij], :].contiguous(
        ).view(self.batch_size, num_ki, self.tag_num, 1)
        inside_kj_ci = inside_complete_table[:, kjis[ij], :].contiguous(
        ).view(self.batch_size, num_ki, 1, self.tag_num)
        if dir == 0:
            span_inside_i = inside_ik_ci + inside_kj_ci + crf_scores[:, r, l, :, :] \
                .permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)
            # swap head-child to left-right position
        else:
            span_inside_i = inside_ik_ci + inside_kj_ci + crf_scores[:, l, r, :, :].contiguous(
            ).view(self.batch_size, 1, self.tag_num, self.tag_num)
        # Sum (in log space) over all split points k.
        inside_incomplete_table[:, ij, :, :] = utils.logsumexp(span_inside_i, axis=1)
        # one complete span and one incomplete span to form bigger complete span
        num_kc = len(ikcs[ij])
        if dir == 0:
            inside_ik_cc = inside_complete_table[:, ikcs[ij], :].contiguous(
            ).view(
                self.batch_size, num_kc, self.tag_num, 1)
            inside_kj_ic = inside_incomplete_table[:, kjcs[
                ij], :, :].contiguous().view(self.batch_size, num_kc,
                                             self.tag_num, self.tag_num)
            span_inside_c = inside_ik_cc + inside_kj_ic
            # Fold split-point and inner-tag axes together before reducing.
            span_inside_c = span_inside_c.contiguous().view(
                self.batch_size, num_kc * self.tag_num, self.tag_num)
            inside_complete_table[:, ij, :] = utils.logsumexp(span_inside_c, axis=1)
        else:
            inside_ik_ic = inside_incomplete_table[:, ikcs[
                ij], :, :].contiguous().view(self.batch_size, num_kc,
                                             self.tag_num, self.tag_num)
            inside_kj_cc = inside_complete_table[:, kjcs[ij], :].contiguous(
            ).view(
                self.batch_size, num_kc, 1, self.tag_num)
            span_inside_c = inside_ik_ic + inside_kj_cc
            span_inside_c = span_inside_c.permute(
                0, 1, 3, 2).contiguous().view(self.batch_size,
                                              num_kc * self.tag_num,
                                              self.tag_num)
            # swap the left-right position since the left tags are to be indexed
            inside_complete_table[:, ij, :] = utils.logsumexp(span_inside_c, axis=1)
    # Root-covering right-directed complete span at tag 0.
    final_id = span_2_id[(0, self.sentence_length - 1, 1)]
    partition_score = inside_complete_table[:, final_id, 0]
    return (inside_complete_table, inside_incomplete_table), partition_score
def batch_outside(self, inside_table, crf_score):
    """Batched outside pass of the tag-augmented Eisner algorithm (log space, torch).

    Args:
        inside_table: (inside_complete_table, inside_incomplete_table) as
            returned by batch_inside.
        crf_score: arc/tag scores indexed [batch, head, child, head_tag,
            child_tag] (inferred from the permute in the dir == 0 branch --
            TODO confirm).

    Walks spans in reverse topological order, propagating outside scores to
    child spans. A child span's first contribution is copied; later ones are
    accumulated with utils.logaddexp, tracked via the *_used sets.

    Returns:
        (outside_complete_table, outside_incomplete_table).
    """
    inside_complete_table = inside_table[0]
    inside_incomplete_table = inside_table[1]
    outside_complete_table = torch.DoubleTensor(
        self.batch_size, self.sentence_length * self.sentence_length * 2,
        self.tag_num)
    outside_incomplete_table = torch.DoubleTensor(
        self.batch_size, self.sentence_length * self.sentence_length * 2,
        self.tag_num, self.tag_num)
    if torch.cuda.is_available():
        outside_complete_table = outside_complete_table.cuda()
        outside_incomplete_table = outside_incomplete_table.cuda()
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        self.sentence_length, False)
    # Initialize everything to log(0).
    outside_complete_table.fill_(LOGZERO)
    outside_incomplete_table.fill_(LOGZERO)
    # The root span at tag 0 has outside score log(1) = 0.
    root_id = span_2_id.get((0, self.sentence_length - 1, 1))
    outside_complete_table[:, root_id, 0] = 0.0
    complete_span_used = set()
    incomplete_span_used = set()
    complete_span_used.add(root_id)
    for ij in reversed(ijss):
        (l, r, dir) = id_2_span[ij]
        # complete span consists of one incomplete span and one complete span
        num_kc = len(ikcs[ij])
        if dir == 0:
            outside_ij_cc = outside_complete_table[:, ij, :].contiguous(
            ).view(self.batch_size, 1, 1, self.tag_num)
            inside_kj_ic = inside_incomplete_table[:, kjcs[
                ij], :, :].contiguous().view(self.batch_size, num_kc,
                                             self.tag_num, self.tag_num)
            inside_ik_cc = inside_complete_table[:, ikcs[ij], :].contiguous(
            ).view(
                self.batch_size, num_kc, self.tag_num, 1)
            outside_ik_cc = (outside_ij_cc + inside_kj_ic).permute(
                0, 1, 3, 2)
            # swap left-right position since right tags are to be indexed
            outside_kj_ic = outside_ij_cc + inside_ik_cc
            for i in range(num_kc):
                ik = ikcs[ij][i]
                kj = kjcs[ij][i]
                # Reduce over the sibling tag axis.
                outside_ik_cc_i = utils.logsumexp(outside_ik_cc[:, i, :, :], axis=1)
                if ik in complete_span_used:
                    outside_complete_table[:, ik, :] = utils.logaddexp(
                        outside_complete_table[:, ik, :], outside_ik_cc_i)
                else:
                    # First contribution: copy instead of accumulating onto LOGZERO.
                    outside_complete_table[:, ik, :] = outside_ik_cc_i.clone(
                    )
                    complete_span_used.add(ik)
                if kj in incomplete_span_used:
                    outside_incomplete_table[:, kj, :, :] = utils.logaddexp(
                        outside_incomplete_table[:, kj, :, :],
                        outside_kj_ic[:, i, :, :])
                else:
                    outside_incomplete_table[:, kj, :, :] = outside_kj_ic[:, i, :, :]
                    incomplete_span_used.add(kj)
        else:
            outside_ij_cc = outside_complete_table[:, ij, :].contiguous(
            ).view(self.batch_size, 1, self.tag_num, 1)
            inside_ik_ic = inside_incomplete_table[:, ikcs[
                ij], :, :].contiguous().view(self.batch_size, num_kc,
                                             self.tag_num, self.tag_num)
            inside_kj_cc = inside_complete_table[:, kjcs[ij], :].contiguous(
            ).view(
                self.batch_size, num_kc, 1, self.tag_num)
            outside_kj_cc = outside_ij_cc + inside_ik_ic
            outside_ik_ic = outside_ij_cc + inside_kj_cc
            for i in range(num_kc):
                kj = kjcs[ij][i]
                ik = ikcs[ij][i]
                outside_kj_cc_i = utils.logsumexp(outside_kj_cc[:, i, :, :], axis=1)
                if kj in complete_span_used:
                    outside_complete_table[:, kj, :] = utils.logaddexp(
                        outside_complete_table[:, kj, :], outside_kj_cc_i)
                else:
                    outside_complete_table[:, kj, :] = outside_kj_cc_i.clone(
                    )
                    complete_span_used.add(kj)
                if ik in incomplete_span_used:
                    outside_incomplete_table[:, ik, :, :] = utils.logaddexp(
                        outside_incomplete_table[:, ik, :, :],
                        outside_ik_ic[:, i, :, :])
                else:
                    outside_incomplete_table[:, ik, :, :] = outside_ik_ic[:, i, :, :]
                    incomplete_span_used.add(ik)
        # incomplete span consists of two complete spans
        num_ki = len(ikis[ij])
        outside_ij_ii = outside_incomplete_table[:, ij, :, :].contiguous(
        ).view(self.batch_size, 1, self.tag_num, self.tag_num)
        inside_ik_ci = inside_complete_table[:, ikis[ij], :].contiguous(
        ).view(self.batch_size, num_ki, self.tag_num, 1)
        inside_kj_ci = inside_complete_table[:, kjis[ij], :].contiguous(
        ).view(self.batch_size, num_ki, 1, self.tag_num)
        if dir == 0:
            # Arc scores permuted so tag axes follow left-right position.
            outside_ik_ci = outside_ij_ii + inside_kj_ci + crf_score[:, r, l, :, :]. \
                permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)
            outside_kj_ci = outside_ij_ii + inside_ik_ci + crf_score[:, r, l, :, :]. \
                permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)
        else:
            outside_ik_ci = outside_ij_ii + inside_kj_ci + crf_score[:, l, r, :, :].contiguous(
            ).view(self.batch_size, 1, self.tag_num, self.tag_num)
            outside_kj_ci = outside_ij_ii + inside_ik_ci + crf_score[:, l, r, :, :].contiguous(
            ).view(self.batch_size, 1, self.tag_num, self.tag_num)
        for i in range(num_ki):
            ik = ikis[ij][i]
            kj = kjis[ij][i]
            # Reduce over the opposite endpoint's tag axis.
            outside_ik_ci_i = utils.logsumexp(outside_ik_ci[:, i, :, :], axis=2)
            outside_kj_ci_i = utils.logsumexp(outside_kj_ci[:, i, :, :], axis=1)
            if ik in complete_span_used:
                outside_complete_table[:, ik, :] = utils.logaddexp(
                    outside_complete_table[:, ik, :], outside_ik_ci_i)
            else:
                outside_complete_table[:, ik, :] = outside_ik_ci_i.clone()
                complete_span_used.add(ik)
            if kj in complete_span_used:
                outside_complete_table[:, kj, :] = utils.logaddexp(
                    outside_complete_table[:, kj, :], outside_kj_ci_i)
            else:
                outside_complete_table[:, kj, :] = outside_kj_ci_i.clone()
                complete_span_used.add(kj)
    return (outside_complete_table, outside_incomplete_table)