Example #1
    def diff(self):
        self.outside_table = self.batch_outside(self.inside_table,
                                                self.crf_scores)

        counts = self.inside_table[1] + self.outside_table[1]
        pseudo_count = torch.DoubleTensor(self.batch_size,
                                          self.sentence_length,
                                          self.sentence_length, self.tag_num,
                                          self.tag_num)
        pseudo_count.fill_(LOGZERO)
        if torch.cuda.is_available():
            pseudo_count = pseudo_count.cuda()
        span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
            self.sentence_length, False)

        for l in range(self.sentence_length):
            for r in range(self.sentence_length):
                for dir in range(2):
                    span_id = span_2_id.get((l, r, dir))
                    if span_id is not None:
                        if dir == 0:
                            pseudo_count[:, r,
                                         l, :, :] = counts[:,
                                                           span_id, :, :].permute(
                                                               0, 2, 1)
                        else:
                            pseudo_count[:, l, r, :, :] = counts[:,
                                                                 span_id, :, :]

        mius = pseudo_count - self.partition_score.contiguous().view(
            self.batch_size, 1, 1, 1, 1)
        diff = torch.exp(mius)
        # if self.mask is not None:
        #     diff = diff.masked_fill_(self.mask, 0.0)
        return diff
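Note: assuming inside_table[1] and outside_table[1] hold the log inside and log outside scores of the incomplete spans, counts - partition_score is the log posterior of each head-child arc and tag pair, so the returned diff contains the arc marginals exp(inside + outside - log Z).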
Example #2
def batch_inside(batch_scores, batch_decision_score, valency_num, cvalency_num):
    batch_size, sentence_length, _, tag_num, _, _ = batch_scores.shape
    inside_complete_table = np.zeros((batch_size, sentence_length * sentence_length * 2, tag_num, valency_num))
    inside_incomplete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num, valency_num))
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length,
                                                                                             False)
    inside_complete_table.fill(-np.inf)
    inside_incomplete_table.fill(-np.inf)

    for ii in basic_span:
        (i, i, dir) = id_2_span[ii]
        inside_complete_table[:, ii, :, :] = batch_decision_score[:, i, :, dir, :, 0]

    for ij in ijss:
        (l, r, dir) = id_2_span[ij]
        # two complete span to form an incomplete span
        num_ki = len(ikis[ij])
        inside_ik_ci = inside_complete_table[:, ikis[ij], :, :].reshape(batch_size, num_ki, tag_num, 1, valency_num)
        inside_kj_ci = inside_complete_table[:, kjis[ij], :, :].reshape(batch_size, num_ki, 1, tag_num, valency_num)
        if dir == 0:
            span_inside_i = inside_ik_ci[:, :, :, :, 0].reshape(batch_size, num_ki, tag_num, 1, 1) \
                            + inside_kj_ci[:, :, :, :, 1].reshape(batch_size, num_ki, 1, tag_num, 1) \
                            + batch_scores[:, r, l, :, :, :].swapaxes(2, 1).reshape(batch_size, 1, tag_num, tag_num,
                                                                                    cvalency_num) \
                            + batch_decision_score[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)

            # swap head-child to left-right position
        else:
            span_inside_i = inside_ik_ci[:, :, :, :, 1].reshape(batch_size, num_ki, tag_num, 1, 1) \
                            + inside_kj_ci[:, :, :, :, 0].reshape(batch_size, num_ki, 1, tag_num, 1) \
                            + batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) \
                            + batch_decision_score[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)

        inside_incomplete_table[:, ij, :, :, :] = logsumexp(span_inside_i, axis=1)

        # one complete span and one incomplete span to form bigger complete span
        num_kc = len(ikcs[ij])
        if dir == 0:
            inside_ik_cc = inside_complete_table[:, ikcs[ij], :, :].reshape(batch_size, num_kc, tag_num, 1, valency_num)
            inside_kj_ic = inside_incomplete_table[:, kjcs[ij], :, :, :].reshape(batch_size, num_kc, tag_num, tag_num,
                                                                                 valency_num)
            span_inside_c = inside_ik_cc[:, :, :, :, 0].reshape(batch_size, num_kc, tag_num, 1, 1) + inside_kj_ic
            span_inside_c = span_inside_c.reshape(batch_size, num_kc * tag_num, tag_num, valency_num)
            inside_complete_table[:, ij, :, :] = logsumexp(span_inside_c, axis=1)
        else:
            inside_ik_ic = inside_incomplete_table[:, ikcs[ij], :, :, :].reshape(batch_size, num_kc, tag_num, tag_num,
                                                                                 valency_num)
            inside_kj_cc = inside_complete_table[:, kjcs[ij], :, :].reshape(batch_size, num_kc, 1, tag_num, valency_num)
            span_inside_c = inside_ik_ic + inside_kj_cc[:, :, :, :, 0].reshape(batch_size, num_kc, 1, tag_num, 1)
            span_inside_c = span_inside_c.swapaxes(3, 2).reshape(batch_size, num_kc * tag_num, tag_num, valency_num)
            # swap the left-right position since the left tags are to be indexed
            inside_complete_table[:, ij, :, :] = logsumexp(span_inside_c, axis=1)

    final_id = span_2_id[(0, sentence_length - 1, 1)]
    partition_score = inside_complete_table[:, final_id, 0, 0]

    return inside_complete_table, inside_incomplete_table, partition_score
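The NumPy examples here and below (Examples #5 and #6) call a logsumexp helper with an integer or tuple axis; the hosting module presumably imports one (e.g. scipy.special.logsumexp) or defines its own. A minimal, numerically stable sketch compatible with those calls:

import numpy as np

def logsumexp(a, axis=None):
    # Stable log(sum(exp(a))) along axis (an int or a tuple of ints).
    a_max = np.max(a, axis=axis, keepdims=True)
    # Guard slices that are entirely -inf so the subtraction does not yield NaN.
    a_max = np.where(np.isfinite(a_max), a_max, 0.0)
    return np.log(np.sum(np.exp(a - a_max), axis=axis)) + np.squeeze(a_max, axis=axis)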
Example #3
def batch_parse(batch_scores):
    batch_size, sentence_length, _ = batch_scores.shape
    # CYK table
    complete_table = torch.zeros(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.float)
    incomplete_table = torch.zeros(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.float)
    # backtrack table
    complete_backtrack = -torch.ones(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.int)
    incomplete_backtrack = -torch.ones(
        (batch_size, sentence_length * sentence_length * 2), dtype=torch.int)
    # span index table, to avoid redundant iterations
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        sentence_length, False)
    # initial basic complete spans
    for ii in basic_span:
        complete_table[:, ii] = 0.0

    for ij in ijss:
        (l, r, dir) = id_2_span[ij]
        num_ki = len(ikis[ij])
        ik_ci = complete_table[:, ikis[ij]].reshape(batch_size, num_ki)
        kj_ci = complete_table[:, kjis[ij]].reshape(batch_size, num_ki)
        # construct incomplete spans
        if dir == 0:
            span_i = ik_ci + kj_ci + batch_scores[:, r, l].reshape(
                batch_size, 1)
        else:
            span_i = ik_ci + kj_ci + batch_scores[:, l, r].reshape(
                batch_size, 1)

        incomplete_table[:, ij] = torch.max(span_i, dim=1)[0]
        max_idx = torch.max(span_i, dim=1)[1]
        incomplete_backtrack[:, ij] = max_idx

        num_kc = len(ikcs[ij])
        if dir == 0:
            ik_cc = complete_table[:, ikcs[ij]].reshape(batch_size, num_kc)
            kj_ic = incomplete_table[:, kjcs[ij]].reshape(batch_size, num_kc)
            span_c = ik_cc + kj_ic
        else:
            ik_ic = incomplete_table[:, ikcs[ij]].reshape(batch_size, num_kc)
            kj_cc = complete_table[:, kjcs[ij]].reshape(batch_size, num_kc)
            span_c = ik_ic + kj_cc
        complete_table[:, ij] = torch.max(span_c, dim=1)[0]
        max_idx = torch.max(span_c, dim=1)[1]
        complete_backtrack[:, ij] = max_idx

    heads = -torch.ones((batch_size, sentence_length), dtype=torch.long)
    root_id = span_2_id[(0, sentence_length - 1, 1)]
    for s in range(batch_size):
        batch_backtracking(incomplete_backtrack, complete_backtrack, root_id,
                           1, heads, ikcs, ikis, kjcs, kjis, id_2_span,
                           span_2_id, s)

    return heads
Example #4
def diff(outside_table, inside_table, batch_size, sentence_length, tag_num, partition_score):

    counts = inside_table[1] + outside_table[1]
    pseudo_count = torch.DoubleTensor(batch_size, sentence_length, sentence_length, tag_num,
                                      tag_num)
    pseudo_count.fill_(0.0)
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length)

    for l in range(sentence_length):
        for r in range(sentence_length):
            for dir in range(2):
                span_id = span_2_id.get((l, r, dir))
                if span_id is not None:
                    if dir == 0:
                        pseudo_count[:, r, l, :, :] = counts[:, span_id, :, :]
                    else:
                        pseudo_count[:, l, r, :, :] = counts[:, span_id, :, :]
    mius = pseudo_count - partition_score.contiguous().view(batch_size, 1, 1, 1, 1)
    diff = torch.exp(mius)
    #if mask is not None:
        #diff = diff.masked_fill_(mask, 0.0)
    return diff
Example #5
def batch_parse(batch_scores, batch_decision_score, valency_num, cvalency_num):
    batch_size, sentence_length, _, tag_num, _, _ = batch_scores.shape
    # CYK table
    complete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num,
         valency_num))
    incomplete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num,
         valency_num))
    complete_table.fill(-np.inf)
    incomplete_table.fill(-np.inf)
    # backtrack table
    complete_backtrack = -np.ones(
        (batch_size, sentence_length * sentence_length * 2, tag_num,
         valency_num),
        dtype=int)
    incomplete_backtrack = -np.ones(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num,
         valency_num),
        dtype=int)
    # span index table, to avoid redundant iterations
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        sentence_length, False)
    # initial basic complete spans
    for ii in basic_span:
        (i, i, dir) = id_2_span[ii]
        complete_table[:, ii, :, :] = batch_decision_score[:, i, :, dir, :, 0]
    for ij in ijss:
        (l, r, dir) = id_2_span[ij]
        num_ki = len(ikis[ij])
        ik_ci = complete_table[:,
                               ikis[ij], :, :].reshape(batch_size, num_ki,
                                                       tag_num, 1, valency_num)
        kj_ci = complete_table[:,
                               kjis[ij], :, :].reshape(batch_size, num_ki, 1,
                                                       tag_num, valency_num)
        # construct incomplete spans
        if dir == 0:
            span_i = ik_ci[:, :, :, :, 0].reshape(batch_size, num_ki, tag_num, 1, 1) \
                     + kj_ci[:, :, :, :, 1].reshape(batch_size, num_ki, 1, tag_num, 1) + \
                     batch_scores[:, r, l, :, :, :].swapaxes(1, 2).reshape(batch_size, 1, tag_num, tag_num,
                                                                           cvalency_num) \
                     + batch_decision_score[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)
        else:
            span_i = ik_ci[:, :, :, :, 1].reshape(batch_size, num_ki, tag_num, 1, 1) \
                     + kj_ci[:, :, :, :, 0].reshape(batch_size, num_ki, 1, tag_num, 1) + \
                     batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) \
                     + batch_decision_score[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)
        incomplete_table[:, ij, :, :, :] = np.max(span_i, axis=1)
        max_idx = np.argmax(span_i, axis=1)
        incomplete_backtrack[:, ij, :, :, :] = max_idx
        # construct complete spans
        num_kc = len(ikcs[ij])
        if dir == 0:
            ik_cc = complete_table[:, ikcs[ij], :, :].reshape(
                batch_size, num_kc, tag_num, 1, valency_num)
            kj_ic = incomplete_table[:, kjcs[ij], :, :, :].reshape(
                batch_size, num_kc, tag_num, tag_num, valency_num)
            span_c = ik_cc[:, :, :, :, 0].reshape(batch_size, num_kc, tag_num,
                                                  1, 1) + kj_ic
            span_c = span_c.reshape(batch_size, num_kc * tag_num, tag_num,
                                    valency_num)
        else:
            ik_ic = incomplete_table[:, ikcs[ij], :, :, :].reshape(
                batch_size, num_kc, tag_num, tag_num, valency_num)
            kj_cc = complete_table[:, kjcs[ij], :, :].reshape(
                batch_size, num_kc, 1, tag_num, valency_num)
            span_c = ik_ic + kj_cc[:, :, :, :, 0].reshape(
                batch_size, num_kc, 1, tag_num, 1)
            span_c = span_c.swapaxes(2,
                                     3).reshape(batch_size, num_kc * tag_num,
                                                tag_num, valency_num)
        complete_table[:, ij, :, :] = np.max(span_c, axis=1)
        max_idx = np.argmax(span_c, axis=1)
        complete_backtrack[:, ij, :, :] = max_idx

    tags = np.zeros((batch_size, sentence_length)).astype(int)
    heads = -np.ones((batch_size, sentence_length))
    head_valences = np.zeros((batch_size, sentence_length))
    valences = np.zeros((batch_size, sentence_length, 2))
    root_id = span_2_id[(0, sentence_length - 1, 1)]
    for s in range(batch_size):
        batch_backtracking(incomplete_backtrack, complete_backtrack, root_id,
                           0, 0, 0, 1, tags, heads, head_valences, valences,
                           ikcs, ikis, kjcs, kjis, id_2_span, span_2_id,
                           tag_num, s)

    return (heads, tags, head_valences, valences)
Example #6
def batch_outside(inside_complete_table, inside_incomplete_table, batch_scores,
                  batch_decision_scores, valency_num, cvalency_num):
    batch_size, sentence_length, _, tag_num, _, _ = batch_scores.shape
    outside_complete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num,
         valency_num))
    outside_incomplete_table = np.zeros(
        (batch_size, sentence_length * sentence_length * 2, tag_num, tag_num,
         valency_num))
    span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
        sentence_length, False)
    outside_complete_table.fill(-np.inf)
    outside_incomplete_table.fill(-np.inf)

    root_id = span_2_id.get((0, sentence_length - 1, 1))
    outside_complete_table[:, root_id, 0, 0] = 0.0

    complete_span_used_0 = set()
    complete_span_used_1 = set()
    incomplete_span_used = set()
    complete_span_used_0.add(root_id)

    for ij in reversed(ijss):
        (l, r, dir) = id_2_span[ij]
        # complete span consists of one incomplete span and one complete span
        num_kc = len(ikcs[ij])
        if dir == 0:
            outside_ij_cc = outside_complete_table[:, ij, :, :].reshape(
                batch_size, 1, 1, tag_num, valency_num)
            inside_kj_ic = inside_incomplete_table[:,
                                                   kjcs[ij], :, :, :].reshape(
                                                       batch_size, num_kc,
                                                       tag_num, tag_num,
                                                       valency_num)
            inside_ik_cc = inside_complete_table[:, ikcs[ij], :, :].reshape(
                batch_size, num_kc, tag_num, 1, valency_num)
            outside_ik_cc = (outside_ij_cc + inside_kj_ic).swapaxes(2, 3)
            # swap left-right position since right tags are to be indexed
            outside_kj_ic = outside_ij_cc + inside_ik_cc[:, :, :, :,
                                                         0].reshape(
                                                             batch_size,
                                                             num_kc, tag_num,
                                                             1, 1)
            for i in range(num_kc):
                ik = ikcs[ij][i]
                kj = kjcs[ij][i]
                outside_ik_cc_i = logsumexp(outside_ik_cc[:, i, :, :, :],
                                            axis=(1, 3))
                if ik in complete_span_used_0:
                    outside_complete_table[:, ik, :, 0] = np.logaddexp(
                        outside_complete_table[:, ik, :, 0], outside_ik_cc_i)
                else:
                    outside_complete_table[:, ik, :,
                                           0] = np.copy(outside_ik_cc_i)
                    complete_span_used_0.add(ik)

                if kj in incomplete_span_used:
                    outside_incomplete_table[:, kj, :, :, :] = np.logaddexp(
                        outside_incomplete_table[:, kj, :, :, :],
                        outside_kj_ic[:, i, :, :, :])
                else:
                    outside_incomplete_table[:, kj, :, :, :] = np.copy(
                        outside_kj_ic[:, i, :, :, :])
                    incomplete_span_used.add(kj)
        else:
            outside_ij_cc = outside_complete_table[:, ij, :, :].reshape(
                batch_size, 1, tag_num, 1, valency_num)
            inside_ik_ic = inside_incomplete_table[:,
                                                   ikcs[ij], :, :, :].reshape(
                                                       batch_size, num_kc,
                                                       tag_num, tag_num,
                                                       valency_num)
            inside_kj_cc = inside_complete_table[:, kjcs[ij], :, :].reshape(
                batch_size, num_kc, 1, tag_num, valency_num)
            outside_kj_cc = outside_ij_cc + inside_ik_ic
            outside_ik_ic = outside_ij_cc + inside_kj_cc[:, :, :, :,
                                                         0].reshape(
                                                             batch_size,
                                                             num_kc, 1,
                                                             tag_num, 1)
            for i in range(num_kc):
                kj = kjcs[ij][i]
                ik = ikcs[ij][i]
                outside_kj_cc_i = logsumexp(outside_kj_cc[:, i, :, :, :],
                                            axis=(1, 3))
                if kj in complete_span_used_0:
                    outside_complete_table[:, kj, :, 0] = np.logaddexp(
                        outside_complete_table[:, kj, :, 0], outside_kj_cc_i)
                else:
                    outside_complete_table[:, kj, :,
                                           0] = np.copy(outside_kj_cc_i)
                    complete_span_used_0.add(kj)

                if ik in incomplete_span_used:
                    outside_incomplete_table[:, ik, :, :, :] = np.logaddexp(
                        outside_incomplete_table[:, ik, :, :, :],
                        outside_ik_ic[:, i, :, :, :])
                else:
                    outside_incomplete_table[:, ik, :, :, :] = np.copy(
                        outside_ik_ic[:, i, :, :, :])
                    incomplete_span_used.add(ik)

        # incomplete span consists of two complete spans
        num_ki = len(ikis[ij])

        outside_ij_ii = outside_incomplete_table[:, ij, :, :, :].reshape(
            batch_size, 1, tag_num, tag_num, valency_num)
        inside_ik_ci = inside_complete_table[:, ikis[ij], :, :].reshape(
            batch_size, num_ki, tag_num, 1, valency_num)
        inside_kj_ci = inside_complete_table[:, kjis[ij], :].reshape(
            batch_size, num_ki, 1, tag_num, valency_num)

        if dir == 0:
            outside_ik_ci_0 = outside_ij_ii + inside_kj_ci[:, :, :, :, 1].reshape(batch_size, num_ki, 1, tag_num, 1) + \
                              batch_scores[:, r, l, :, :, :].swapaxes(1, 2). \
                                  reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                              batch_decision_scores[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)

            outside_kj_ci_1 = outside_ij_ii + inside_ik_ci[:, :, :, :, 0].reshape(batch_size, num_ki, tag_num, 1, 1) + \
                              batch_scores[:, r, l, :, :, :].swapaxes(1, 2). \
                                  reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                              batch_decision_scores[:, r, :, dir, :, 1].reshape(batch_size, 1, 1, tag_num, valency_num)
        else:
            outside_ik_ci_1 = outside_ij_ii + inside_kj_ci[:, :, :, :, 0].reshape(batch_size, num_ki, 1, tag_num, 1) \
                              + batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                              batch_decision_scores[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)
            outside_kj_ci_0 = outside_ij_ii + inside_ik_ci[:, :, :, :, 1].reshape(batch_size, num_ki, tag_num, 1, 1) + \
                              batch_scores[:, l, r, :, :, :].reshape(batch_size, 1, tag_num, tag_num, cvalency_num) + \
                              batch_decision_scores[:, l, :, dir, :, 1].reshape(batch_size, 1, tag_num, 1, valency_num)

        for i in range(num_ki):
            ik = ikis[ij][i]
            kj = kjis[ij][i]
            if dir == 0:
                outside_ik_ci_i_0 = logsumexp(outside_ik_ci_0[:, i, :, :, :],
                                              axis=(2, 3))
                outside_kj_ci_i_1 = logsumexp(outside_kj_ci_1[:, i, :, :, :],
                                              axis=(1, 3))
            else:
                outside_ik_ci_i_1 = logsumexp(outside_ik_ci_1[:, i, :, :, :],
                                              axis=(2, 3))
                outside_kj_ci_i_0 = logsumexp(outside_kj_ci_0[:, i, :, :, :],
                                              axis=(1, 3))
            if dir == 0:
                if ik in complete_span_used_0:
                    outside_complete_table[:, ik, :, 0] = np.logaddexp(
                        outside_complete_table[:, ik, :, 0], outside_ik_ci_i_0)
                else:
                    outside_complete_table[:, ik, :,
                                           0] = np.copy(outside_ik_ci_i_0)
                    complete_span_used_0.add(ik)

                if kj in complete_span_used_1:
                    outside_complete_table[:, kj, :, 1] = np.logaddexp(
                        outside_complete_table[:, kj, :, 1], outside_kj_ci_i_1)
                else:
                    outside_complete_table[:, kj, :,
                                           1] = np.copy(outside_kj_ci_i_1)
                    complete_span_used_1.add(kj)

            else:
                if ik in complete_span_used_1:
                    outside_complete_table[:, ik, :, 1] = np.logaddexp(
                        outside_complete_table[:, ik, :, 1], outside_ik_ci_i_1)
                else:
                    outside_complete_table[:, ik, :,
                                           1] = np.copy(outside_ik_ci_i_1)
                    complete_span_used_1.add(ik)

                if kj in complete_span_used_0:
                    outside_complete_table[:, kj, :, 0] = np.logaddexp(
                        outside_complete_table[:, kj, :, 0], outside_kj_ci_i_0)
                else:
                    outside_complete_table[:, kj, :,
                                           0] = np.copy(outside_kj_ci_i_0)
                    complete_span_used_0.add(kj)

    return outside_complete_table, outside_incomplete_table
Example #7
    def update_pseudo_count(self, inside_incomplete_table, inside_complete_table, sentence_prob,
                            outside_incomplete_table, outside_complete_table, trans_counter, root_counter,
                            decision_counter, batch_pos, batch_sen, batch_lan):
        batch_likelihood = 0.0
        batch_size, sentence_length = batch_pos.shape
        span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length, False)
        for s in range(batch_size):
            pos_sentence = batch_pos[s]
            sentence_id = batch_sen[s]
            language_id = batch_lan[s]
            one_sentence_count = []
            one_sentence_decision_count = []
            for h in range(sentence_length):
                for m in range(sentence_length):
                    if m == 0:
                        continue
                    if h == m:
                        continue
                    if h > m:
                        dir = 0
                    else:
                        dir = 1
                    h_pos = pos_sentence[h]

                    m_pos = pos_sentence[m]
                    m_dec_pos = m_pos
                    if dir == 0:
                        span_id = span_2_id[(m, h, dir)]
                    else:
                        span_id = span_2_id[(h, m, dir)]
                    dep_count = inside_incomplete_table[s, span_id, :, :, :] + \
                                outside_incomplete_table[s, span_id, :, :, :] - sentence_prob[s]
                    if dir == 0:
                        dep_count = dep_count.swapaxes(1, 0)
                    if self.cvalency == 1:
                        if h == 0:
                            root_counter[m_pos] += np.sum(np.exp(dep_count))
                        else:
                            trans_counter[h_pos, m_pos, dir, 0] += np.sum(np.exp(dep_count))
                    else:
                        if h == 0:
                            root_counter[m_pos] += np.sum(np.exp(dep_count))
                        else:
                            trans_counter[h_pos, m_pos, dir, :] += np.exp(dep_count).reshape(self.cvalency)
                    if self.use_neural:
                        for v in range(self.cvalency):
                            count = np.exp(dep_count).reshape(self.cvalency)[v]
                            if not h == 0:
                                self.rule_samples.append(list([h_pos, m_pos, dir, v, sentence_id, language_id, count]))
                    if h > 0:
                        h_dec_pos = h_pos
                        decision_counter[h_dec_pos, dir, :, 1] += np.exp(dep_count).reshape(self.cvalency)
                        if self.use_neural:
                            reshaped_count = np.exp(dep_count).reshape(self.cvalency)
                            for v in range(self.dvalency):
                                count = reshaped_count[v]
                                if not self.unified_network:
                                    self.decision_samples.append(
                                        list([h_dec_pos, dir, v, sentence_id, language_id, 1, count]))
                                else:
                                    self.decision_samples.append(
                                        list([h_pos, dir, v, sentence_id, language_id, 1, count]))
            for m in range(1, sentence_length):
                m_pos = pos_sentence[m]
                m_dec_pos = m_pos
                for d in range(2):
                    m_span_id = span_2_id[(m, m, d)]
                    stop_count = inside_complete_table[s, m_span_id, :, :] + \
                                 outside_complete_table[s, m_span_id, :, :] - sentence_prob[s]
                    decision_counter[m_dec_pos, d, :, 0] += np.exp(stop_count).reshape(self.cvalency)
                    if self.use_neural:
                        for v in range(self.dvalency):
                            count = np.exp(stop_count).reshape(self.cvalency)[v]
                            if not self.unified_network:
                                self.decision_samples.append(
                                    list([m_dec_pos, d, v, sentence_id, language_id, 0, count]))
                            else:
                                self.decision_samples.append(list([m_pos, d, v, sentence_id, language_id, 0, count]))

            batch_likelihood += sentence_prob[s]
            self.sentence_counter[sentence_id] = one_sentence_count
            self.sentence_decision_counter[sentence_id] = one_sentence_decision_count
        return batch_likelihood
Example #8
    def update_pseudo_count(self, inside_incomplete_table, inside_complete_table, sentence_prob,
                            outside_incomplete_table, outside_complete_table, trans_counter,
                            decision_counter, lex_counter, batch_pos, batch_words):
        batch_likelihood = 0.0
        batch_size, sentence_length = batch_pos.shape
        span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(sentence_length, False)
        for sen_id in range(batch_size):
            pos_sentence = batch_pos[sen_id]
            word_sentence = batch_words[sen_id]
            for h in range(sentence_length):
                for m in range(sentence_length):
                    if m == 0:
                        continue
                    if h == m:
                        continue
                    if h > m:
                        dir = 0
                    else:
                        dir = 1
                    h_pos = pos_sentence[h]

                    m_pos = pos_sentence[m]
                    m_dec_pos = self.to_decision[m_pos]
                    m_word = word_sentence[m]
                    if dir == 0:
                        span_id = span_2_id[(m, h, dir)]
                    else:
                        span_id = span_2_id[(h, m, dir)]
                    dep_count = inside_incomplete_table[sen_id, span_id, :, :, :] + \
                                outside_incomplete_table[sen_id, span_id, :, :, :] - sentence_prob[sen_id]
                    if dir == 0:
                        dep_count = dep_count.swapaxes(1, 0)
                    if self.cvalency == 1:
                        trans_counter[h_pos, m_pos, :, :, dir, 0] += np.sum(np.exp(dep_count), axis=2)
                    else:
                        trans_counter[h_pos, m_pos, :, :, dir, :] += np.exp(dep_count)
                    if self.use_neural:
                        for h_tag_id in range(self.tag_num):
                            for m_tag_id in range(self.tag_num):
                                for v in range(self.cvalency):
                                    count = np.exp(dep_count[h_tag_id, m_tag_id, v])
                                    self.rule_samples.append(list([h_pos, m_pos, h_tag_id, m_tag_id, dir, v, count]))
                    if h > 0:
                        h_dec_pos = self.to_decision[h_pos]
                        decision_counter[h_dec_pos, :, dir, :, 1] += np.sum(np.exp(dep_count), axis=1)
                        if self.use_neural:
                            summed_count = np.sum(np.exp(dep_count), axis=1)
                            for h_tag_id in range(self.tag_num):
                                for v in range(self.dvalency):
                                    count = summed_count[h_tag_id, v]
                                    if not self.unified_network:
                                        self.decision_samples.append(list([h_dec_pos, h_tag_id, dir, v, 1, count]))
                                    else:
                                        self.decision_samples.append(list([h_pos, h_tag_id, dir, v, 1, count]))
                    if self.use_lex:
                        lex_counter[m_pos, :, m_word] += np.sum(np.exp(dep_count), axis=(0, 2))
            for m in range(1, sentence_length):
                m_pos = pos_sentence[m]
                m_dec_pos = self.to_decision[m_pos]
                m_word = word_sentence[m]
                for d in range(2):
                    m_span_id = span_2_id[(m, m, d)]
                    stop_count = inside_complete_table[sen_id, m_span_id, :, :] + \
                                 outside_complete_table[sen_id, m_span_id, :, :] - sentence_prob[sen_id]
                    decision_counter[m_dec_pos, :, d, :, 0] += np.exp(stop_count)
                    if self.use_neural:
                        for m_tag_id in range(self.tag_num):
                            for v in range(self.dvalency):
                                count = np.exp(stop_count[m_tag_id, v])
                                if not self.unified_network:
                                    self.decision_samples.append(list([m_dec_pos, m_tag_id, d, v, 0, count]))
                                else:
                                    self.decision_samples.append(list([m_pos, m_tag_id, d, v, 0, count]))

            batch_likelihood += sentence_prob[sen_id]
        return batch_likelihood
Example #9
 def update_pseudo_count(self, inside_incomplete_table,
                         inside_complete_table, sentence_prob,
                         outside_incomplete_table, outside_complete_table,
                         trans_counter, decision_counter, batch_pos,
                         batch_sen, batch_lan):
     batch_likelihood = 0.0
     en_like = 0.0
     batch_size, sentence_length = batch_pos.shape
     span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
         sentence_length, False)
     for s in range(batch_size):
         pos_sentence = batch_pos[s]
         sentence_id = batch_sen[s]
         lan_id = batch_lan[s]
         one_sentence_count = []
         one_sentence_decision_count = []
         for h in range(sentence_length):
             for m in range(sentence_length):
                 if m == 0:
                     continue
                 if h == m:
                     continue
                 if h > m:
                     dir = 0
                 else:
                     dir = 1
                 h_pos = pos_sentence[h]
                 m_pos = pos_sentence[m]
                 if dir == 0:
                     span_id = span_2_id[(m, h, dir)]
                 else:
                     span_id = span_2_id[(h, m, dir)]
                 # Pseudo count for one dependency arc
                 dep_count = inside_incomplete_table[s, span_id, :, :, :] + \
                             outside_incomplete_table[s, span_id, :, :, :] - sentence_prob[s]
                 if dir == 0:
                     dep_count = dep_count.swapaxes(1, 0)
                 if self.cvalency == 1:
                     trans_counter[h_pos, m_pos, dir, 0,
                                   lan_id] += np.sum(np.exp(dep_count))
                 else:
                     trans_counter[h_pos, m_pos, dir, :,
                                   lan_id] += np.exp(dep_count).reshape(
                                       self.dvalency)
                 # Add training samples for neural network
                 if self.use_neural:
                     for v in range(self.cvalency):
                         count = np.exp(dep_count).reshape(self.dvalency)[v]
                         self.rule_samples.append(
                             list([
                                 h_pos, m_pos, dir, v, sentence_id, lan_id,
                                 count
                             ]))
                 if h > 0:
                     # Add count for CONTINUE decision
                     decision_counter[h_pos, dir, :, 1,
                                      lan_id] += np.exp(dep_count).reshape(
                                          self.dvalency)
                     if self.use_neural:
                         reshaped_count = np.exp(dep_count).reshape(
                             self.dvalency)
                         for v in range(self.dvalency):
                             count = reshaped_count[v]
                             self.decision_samples.append(
                                 list([
                                     h_pos, dir, v, sentence_id, lan_id, 1,
                                     count
                                 ]))
         for m in range(1, sentence_length):
             m_pos = pos_sentence[m]
             for d in range(2):
                 m_span_id = span_2_id[(m, m, d)]
                 # Pseudo count for STOP decision
                 stop_count = inside_complete_table[s, m_span_id, :, :] + \
                              outside_complete_table[s, m_span_id, :, :] - sentence_prob[s]
                 decision_counter[m_pos, d, :, 0,
                                  lan_id] += np.exp(stop_count).reshape(
                                      self.dvalency)
                 if self.use_neural:
                     for v in range(self.dvalency):
                         count = np.exp(stop_count).reshape(
                             self.dvalency)[v]
                         self.decision_samples.append(
                             list([
                                 m_pos, d, v, sentence_id, lan_id, 0, count
                             ]))
         batch_likelihood += sentence_prob[s]
         if self.language_map[sentence_id] == 'en':
             en_like += sentence_prob[s]
         if self.sentence_predict:
             self.sentence_counter[sentence_id] = one_sentence_count
             self.sentence_decision_counter[
                 sentence_id] = one_sentence_decision_count
     return batch_likelihood, en_like
Example #10
    def batch_inside(self, crf_scores):
        inside_complete_table = torch.DoubleTensor(
            self.batch_size, self.sentence_length * self.sentence_length * 2,
            self.tag_num)
        inside_incomplete_table = torch.DoubleTensor(
            self.batch_size, self.sentence_length * self.sentence_length * 2,
            self.tag_num, self.tag_num)
        if torch.cuda.is_available():
            inside_complete_table = inside_complete_table.cuda()
            inside_incomplete_table = inside_incomplete_table.cuda()
        span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
            self.sentence_length, False)

        inside_complete_table.fill_(LOGZERO)
        inside_incomplete_table.fill_(LOGZERO)

        for ii in basic_span:
            inside_complete_table[:, ii, :] = 0.0

        for ij in ijss:
            (l, r, dir) = id_2_span[ij]
            # two complete span to form an incomplete span
            num_ki = len(ikis[ij])
            inside_ik_ci = inside_complete_table[:, ikis[ij], :].contiguous(
            ).view(self.batch_size, num_ki, self.tag_num, 1)
            inside_kj_ci = inside_complete_table[:, kjis[ij], :].contiguous(
            ).view(self.batch_size, num_ki, 1, self.tag_num)
            if dir == 0:
                span_inside_i = inside_ik_ci + inside_kj_ci + crf_scores[:, r, l, :, :] \
                    .permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)
                # swap head-child to left-right position
            else:
                span_inside_i = inside_ik_ci + inside_kj_ci + crf_scores[:, l, r, :, :].contiguous(
                ).view(self.batch_size, 1, self.tag_num, self.tag_num)
            inside_incomplete_table[:,
                                    ij, :, :] = utils.logsumexp(span_inside_i,
                                                                axis=1)

            # one complete span and one incomplete span to form bigger complete span
            num_kc = len(ikcs[ij])
            if dir == 0:
                inside_ik_cc = inside_complete_table[:,
                                                     ikcs[ij], :].contiguous(
                                                     ).view(
                                                         self.batch_size,
                                                         num_kc, self.tag_num,
                                                         1)
                inside_kj_ic = inside_incomplete_table[:, kjcs[
                    ij], :, :].contiguous().view(self.batch_size, num_kc,
                                                 self.tag_num, self.tag_num)
                span_inside_c = inside_ik_cc + inside_kj_ic
                span_inside_c = span_inside_c.contiguous().view(
                    self.batch_size, num_kc * self.tag_num, self.tag_num)
                inside_complete_table[:,
                                      ij, :] = utils.logsumexp(span_inside_c,
                                                               axis=1)
            else:
                inside_ik_ic = inside_incomplete_table[:, ikcs[
                    ij], :, :].contiguous().view(self.batch_size, num_kc,
                                                 self.tag_num, self.tag_num)
                inside_kj_cc = inside_complete_table[:,
                                                     kjcs[ij], :].contiguous(
                                                     ).view(
                                                         self.batch_size,
                                                         num_kc, 1,
                                                         self.tag_num)
                span_inside_c = inside_ik_ic + inside_kj_cc
                span_inside_c = span_inside_c.permute(
                    0, 1, 3, 2).contiguous().view(self.batch_size,
                                                  num_kc * self.tag_num,
                                                  self.tag_num)
                # swap the left-right position since the left tags are to be indexed
                inside_complete_table[:,
                                      ij, :] = utils.logsumexp(span_inside_c,
                                                               axis=1)

        final_id = span_2_id[(0, self.sentence_length - 1, 1)]
        partition_score = inside_complete_table[:, final_id, 0]

        return (inside_complete_table,
                inside_incomplete_table), partition_score
Example #11
    def batch_outside(self, inside_table, crf_score):
        inside_complete_table = inside_table[0]
        inside_incomplete_table = inside_table[1]
        outside_complete_table = torch.DoubleTensor(
            self.batch_size, self.sentence_length * self.sentence_length * 2,
            self.tag_num)
        outside_incomplete_table = torch.DoubleTensor(
            self.batch_size, self.sentence_length * self.sentence_length * 2,
            self.tag_num, self.tag_num)
        if torch.cuda.is_available():
            outside_complete_table = outside_complete_table.cuda()
            outside_incomplete_table = outside_incomplete_table.cuda()
        span_2_id, id_2_span, ijss, ikcs, ikis, kjcs, kjis, basic_span = utils.constituent_index(
            self.sentence_length, False)
        outside_complete_table.fill_(LOGZERO)
        outside_incomplete_table.fill_(LOGZERO)

        root_id = span_2_id.get((0, self.sentence_length - 1, 1))
        outside_complete_table[:, root_id, 0] = 0.0

        complete_span_used = set()
        incomplete_span_used = set()
        complete_span_used.add(root_id)

        for ij in reversed(ijss):
            (l, r, dir) = id_2_span[ij]
            # complete span consists of one incomplete span and one complete span
            num_kc = len(ikcs[ij])
            if dir == 0:
                outside_ij_cc = outside_complete_table[:, ij, :].contiguous(
                ).view(self.batch_size, 1, 1, self.tag_num)
                inside_kj_ic = inside_incomplete_table[:, kjcs[
                    ij], :, :].contiguous().view(self.batch_size, num_kc,
                                                 self.tag_num, self.tag_num)
                inside_ik_cc = inside_complete_table[:,
                                                     ikcs[ij], :].contiguous(
                                                     ).view(
                                                         self.batch_size,
                                                         num_kc, self.tag_num,
                                                         1)
                outside_ik_cc = (outside_ij_cc + inside_kj_ic).permute(
                    0, 1, 3, 2)
                # swap left-right position since right tags are to be indexed
                outside_kj_ic = outside_ij_cc + inside_ik_cc
                for i in range(num_kc):
                    ik = ikcs[ij][i]
                    kj = kjcs[ij][i]
                    outside_ik_cc_i = utils.logsumexp(outside_ik_cc[:,
                                                                    i, :, :],
                                                      axis=1)
                    if ik in complete_span_used:
                        outside_complete_table[:, ik, :] = utils.logaddexp(
                            outside_complete_table[:, ik, :], outside_ik_cc_i)
                    else:
                        outside_complete_table[:,
                                               ik, :] = outside_ik_cc_i.clone(
                                               )
                        complete_span_used.add(ik)

                    if kj in incomplete_span_used:
                        outside_incomplete_table[:, kj, :, :] = utils.logaddexp(
                            outside_incomplete_table[:, kj, :, :],
                            outside_kj_ic[:, i, :, :])
                    else:
                        outside_incomplete_table[:,
                                                 kj, :, :] = outside_kj_ic[:,
                                                                           i, :, :]
                        incomplete_span_used.add(kj)
            else:
                outside_ij_cc = outside_complete_table[:, ij, :].contiguous(
                ).view(self.batch_size, 1, self.tag_num, 1)
                inside_ik_ic = inside_incomplete_table[:, ikcs[
                    ij], :, :].contiguous().view(self.batch_size, num_kc,
                                                 self.tag_num, self.tag_num)
                inside_kj_cc = inside_complete_table[:,
                                                     kjcs[ij], :].contiguous(
                                                     ).view(
                                                         self.batch_size,
                                                         num_kc, 1,
                                                         self.tag_num)
                outside_kj_cc = outside_ij_cc + inside_ik_ic
                outside_ik_ic = outside_ij_cc + inside_kj_cc
                for i in range(num_kc):
                    kj = kjcs[ij][i]
                    ik = ikcs[ij][i]
                    outside_kj_cc_i = utils.logsumexp(outside_kj_cc[:,
                                                                    i, :, :],
                                                      axis=1)
                    if kj in complete_span_used:
                        outside_complete_table[:, kj, :] = utils.logaddexp(
                            outside_complete_table[:, kj, :], outside_kj_cc_i)
                    else:
                        outside_complete_table[:,
                                               kj, :] = outside_kj_cc_i.clone(
                                               )
                        complete_span_used.add(kj)

                    if ik in incomplete_span_used:
                        outside_incomplete_table[:, ik, :, :] = utils.logaddexp(
                            outside_incomplete_table[:, ik, :, :],
                            outside_ik_ic[:, i, :, :])
                    else:
                        outside_incomplete_table[:,
                                                 ik, :, :] = outside_ik_ic[:,
                                                                           i, :, :]
                        incomplete_span_used.add(ik)

            # incomplete span consists of two complete spans
            num_ki = len(ikis[ij])

            outside_ij_ii = outside_incomplete_table[:, ij, :, :].contiguous(
            ).view(self.batch_size, 1, self.tag_num, self.tag_num)
            inside_ik_ci = inside_complete_table[:, ikis[ij], :].contiguous(
            ).view(self.batch_size, num_ki, self.tag_num, 1)
            inside_kj_ci = inside_complete_table[:, kjis[ij], :].contiguous(
            ).view(self.batch_size, num_ki, 1, self.tag_num)

            if dir == 0:
                outside_ik_ci = outside_ij_ii + inside_kj_ci + crf_score[:, r, l, :, :]. \
                    permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)

                outside_kj_ci = outside_ij_ii + inside_ik_ci + crf_score[:, r, l, :, :]. \
                    permute(0, 2, 1).contiguous().view(self.batch_size, 1, self.tag_num, self.tag_num)
            else:
                outside_ik_ci = outside_ij_ii + inside_kj_ci + crf_score[:, l, r, :, :].contiguous(
                ).view(self.batch_size, 1, self.tag_num, self.tag_num)
                outside_kj_ci = outside_ij_ii + inside_ik_ci + crf_score[:, l, r, :, :].contiguous(
                ).view(self.batch_size, 1, self.tag_num, self.tag_num)

            for i in range(num_ki):
                ik = ikis[ij][i]
                kj = kjis[ij][i]

                outside_ik_ci_i = utils.logsumexp(outside_ik_ci[:, i, :, :],
                                                  axis=2)
                outside_kj_ci_i = utils.logsumexp(outside_kj_ci[:, i, :, :],
                                                  axis=1)
                if ik in complete_span_used:
                    outside_complete_table[:, ik, :] = utils.logaddexp(
                        outside_complete_table[:, ik, :], outside_ik_ci_i)
                else:
                    outside_complete_table[:, ik, :] = outside_ik_ci_i.clone()
                    complete_span_used.add(ik)
                if kj in complete_span_used:
                    outside_complete_table[:, kj, :] = utils.logaddexp(
                        outside_complete_table[:, kj, :], outside_kj_ci_i)
                else:
                    outside_complete_table[:, kj, :] = outside_kj_ci_i.clone()
                    complete_span_used.add(kj)

        return (outside_complete_table, outside_incomplete_table)
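The PyTorch examples (Examples #1, #10 and #11) additionally assume a LOGZERO constant and utils.logsumexp / utils.logaddexp helpers operating on log-space DoubleTensors; none of these are shown above. A minimal sketch under those assumptions (recent PyTorch also provides torch.logsumexp and torch.logaddexp):

import torch

# Assumption: LOGZERO is simply a very negative finite stand-in for log(0).
LOGZERO = -1e30

def logsumexp(tensor, axis=1):
    # Stable log(sum(exp(tensor))) along dimension axis.
    max_val, _ = torch.max(tensor, dim=axis, keepdim=True)
    max_val = torch.where(torch.isfinite(max_val), max_val, torch.zeros_like(max_val))
    return (tensor - max_val).exp().sum(dim=axis).log() + max_val.squeeze(axis)

def logaddexp(a, b):
    # Elementwise log(exp(a) + exp(b)), used when accumulating outside scores for reused spans.
    m = torch.max(a, b)
    m = torch.where(torch.isfinite(m), m, torch.zeros_like(m))
    return m + ((a - m).exp() + (b - m).exp()).log()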