Exemple #1
0
    def loss(self, tree):
        """Return the loss for ``tree``: root energy minus the gold-tree score.

        The energy is the ``logsumexp`` over the root inside mixture combined
        with the root transition mixture; the gold score is the same quantity
        restricted to the labels stored on the tree.
        """
        # Only the 'avg_h' mode hands an averaged hidden state to forward();
        # 'td_avg_h' runs calcualte_avg on the tree and passes None.
        avg_h = None
        mode = self.pred_mode
        if mode == 'td_avg_h':
            self.calcualte_avg(tree)
        elif mode == 'avg_h':
            avg_h = self.collect_avg_hidden(tree)

        in_weight, in_mu, in_var = self.forward(tree, avg_h)

        # Multiply the root inside Gaussians with the root transition Gaussians.
        root_energy, _, _ = self.gaussian_multi(
            in_mu.unsqueeze(1), in_var.unsqueeze(1),
            self.trans_root_mu.unsqueeze(2), self.trans_root_var.unsqueeze(2))
        # root_energy shape [num_label, comp, inside_comp]
        partition = logsumexp(root_energy + in_weight.unsqueeze(1) +
                              self.trans_root_weight.unsqueeze(2))

        gold_weight, gold_mu, gold_var = self.golden_score(tree)

        # Same product, but only for the root label recorded on the tree.
        gold_energy, _, _ = self.gaussian_multi(
            gold_mu.unsqueeze(1), gold_var.unsqueeze(1),
            self.trans_root_mu[tree.label].unsqueeze(0),
            self.trans_root_var[tree.label].unsqueeze(0))
        # gold_energy shape [inside_comp, comp]
        gold_total = logsumexp(
            gold_energy + gold_weight.unsqueeze(1) +
            self.trans_root_weight[tree.label].unsqueeze(0))

        return partition - gold_total
    def maxrule_parsing(self, tree):
        """Fill ``crf_cache`` bottom-up with the statistics used for max-rule decoding.

        Children are visited before their parent. A leaf only receives
        ``max_score`` (set to its cached ``state_weight``); every other node
        receives ``expected_count`` and ``max_labels``.

        Reads ``in_weight`` from both children and ``out_weight`` from
        ``tree``, so those cache entries must already exist.
        """
        for child in tree.children:
            self.maxrule_parsing(child)

        if tree.is_leaf():
            tree.crf_cache['max_score'] = tree.crf_cache['state_weight']
            return

        n_label, n_comp = self.num_label, self.comp
        left, right = tree.children[0], tree.children[1]

        # Put each score on its own pair of axes so the three tensors
        # broadcast against one another and against self.trans_matrix.
        left_in = left.crf_cache['in_weight'].reshape(
            1, n_label, 1, 1, n_comp, 1)
        right_in = right.crf_cache['in_weight'].reshape(
            1, 1, n_label, 1, 1, n_comp)
        parent_out = tree.crf_cache['out_weight'].reshape(
            n_label, 1, 1, n_comp, 1, 1)

        # Flatten the three trailing axes into one and sum them out in log space.
        rule_scores = left_in + right_in + parent_out + self.trans_matrix
        rule_scores = rule_scores.reshape(n_label, n_label, n_label, -1)
        rule_scores = logsumexp(rule_scores, dim=3)

        # Subtract, per first-axis label, the log-sum over the other two axes.
        per_label_total = logsumexp(rule_scores.reshape(n_label, -1), dim=1)
        rule_scores = rule_scores - per_label_total.reshape(n_label, 1, 1)

        # For each first-axis label, the flat index of the best remaining pair.
        best_flat = torch.argmax(rule_scores.reshape(n_label, -1), dim=1)

        tree.crf_cache['expected_count'] = rule_scores
        tree.crf_cache['max_labels'] = best_flat.cpu().numpy().astype(int)
    def loss(self, tree):
        """Return ``logsumexp(forward(tree)) - logsumexp(golden_score(tree))``."""
        # Score over everything forward() returns for the root.
        log_partition = logsumexp(self.forward(tree))
        # Score of the labelling stored on the tree itself.
        gold = logsumexp(self.golden_score(tree))
        return log_partition - gold
Exemple #4
0
    def golden_score(self, tree):
        """Return the ``(weight, mu, var)`` mixture scoring the gold labels under ``tree``.

        For a leaf this is the cached state mixture indexed by ``tree.label``.
        For an internal node the children's gold mixtures are combined with
        the transition parameters of the rule (parent, left, right) and with
        the node's own state mixture; the three results are flattened before
        being returned.
        """
        child_results = [self.golden_score(child) for child in tree.children]

        cache = tree.lveg_cache
        if tree.is_leaf():
            return (cache['state_weight'][tree.label],
                    cache['state_mu'][tree.label],
                    cache['state_var'][tree.label])

        # Index of the gold rule in every transition tensor.
        rule = (tree.label, tree.get_child(0).label, tree.get_child(1).label)

        rule_weight = self.trans_weight[rule]
        rule_mu_p = self.trans_mu_p[rule]
        rule_mu_lc = self.trans_mu_lc[rule]
        rule_mu_rc = self.trans_mu_rc[rule]
        rule_var_p = self.trans_var_p[rule]
        rule_var_lc = self.trans_var_lc[rule]
        rule_var_rc = self.trans_var_rc[rule]

        l_weight, l_mu, l_var = child_results[0]
        r_weight, r_mu, r_var = child_results[1]
        state_w = cache['state_weight'][tree.label]
        state_mu = cache['state_mu'][tree.label]
        state_var = cache['state_var'][tree.label]

        # fixme
        l_score, _, _ = self.gaussian_multi(
            l_mu.unsqueeze(1), l_var.unsqueeze(1),
            rule_mu_lc.unsqueeze(0), rule_var_lc.unsqueeze(0))
        r_score, _, _ = self.gaussian_multi(
            r_mu.unsqueeze(1), r_var.unsqueeze(1),
            rule_mu_rc.unsqueeze(0), rule_var_rc.unsqueeze(0))

        # l_score / r_score are [comp, comp]; sum out the child's axis (dim 0).
        l_score = logsumexp(l_score + l_weight.unsqueeze(1), dim=0)
        r_score = logsumexp(r_score + r_weight.unsqueeze(1), dim=0)

        rule_weight = rule_weight + l_score + r_score

        # Combine the node's own state mixture with the parent side of the rule.
        node_score, node_mu, node_var = self.gaussian_multi(
            state_mu.unsqueeze(1), state_var.unsqueeze(1),
            rule_mu_p.unsqueeze(0), rule_var_p.unsqueeze(0))
        node_score = node_score + rule_weight.unsqueeze(0) + state_w.unsqueeze(1)

        dim = self.gaussian_dim
        return (node_score.reshape(-1),
                node_mu.reshape(-1, dim),
                node_var.reshape(-1, dim))
Exemple #5
0
    def maxrule_parsing(self, tree):
        """Fill ``lveg_cache`` bottom-up with the statistics used for max-rule decoding.

        Children are visited before their parent. A leaf only receives
        ``max_score`` (set to its cached ``state_weight``); every other node
        receives ``expected_count`` and ``max_labels``.

        Reads the ``in_*`` entries of both children and the ``out_*`` entries
        of ``tree``, so those cache entries must already exist.
        """
        for child in tree.children:
            self.maxrule_parsing(child)

        if tree.is_leaf():
            tree.lveg_cache['max_score'] = tree.lveg_cache['state_weight']
            return

        n_label = self.num_label
        left_cache = tree.children[0].lveg_cache
        right_cache = tree.children[1].lveg_cache
        node_cache = tree.lveg_cache

        # Each mixture is combined with the matching side of the transition
        # Gaussians; s_shape places it on the left / right / parent label axis.
        from_left = self.gaussian_mutlisum(
            left_cache['in_weight'], left_cache['in_mu'], left_cache['in_var'],
            None, self.trans_mu_lc, self.trans_var_lc,
            [1, n_label, 1, 1, -1], 4, 4)

        from_right = self.gaussian_mutlisum(
            right_cache['in_weight'], right_cache['in_mu'],
            right_cache['in_var'],
            None, self.trans_mu_rc, self.trans_var_rc,
            [1, 1, n_label, 1, -1], 4, 4)

        from_parent = self.gaussian_mutlisum(
            node_cache['out_weight'], node_cache['out_mu'],
            node_cache['out_var'],
            None, self.trans_mu_p, self.trans_var_p,
            [n_label, 1, 1, 1, -1], 4, 4)

        rule_scores = logsumexp(
            from_left + from_right + from_parent + self.trans_weight, dim=3)

        # Subtract, per first-axis label, the log-sum over the other two axes.
        per_label_total = logsumexp(rule_scores.reshape(n_label, -1), dim=1)
        rule_scores = rule_scores - per_label_total.reshape(n_label, 1, 1)

        # For each first-axis label, the flat index of the best remaining pair.
        best_flat = torch.argmax(rule_scores.reshape(n_label, -1), dim=1)

        node_cache['expected_count'] = rule_scores
        node_cache['max_labels'] = best_flat.cpu().numpy().astype(int)
    def forward(self, tree, avg_h):
        """Recursively compute the inside score of ``tree``.

        A leaf returns whatever ``calcualte_emission_gm`` produces. An
        internal node combines its cached ``state_weight``, both children's
        cached ``in_weight`` and ``self.trans_matrix``, stores the result in
        ``tree.crf_cache['in_weight']`` and returns it.
        """
        child_results = [self.forward(child, avg_h) for child in tree.children]

        emission = self.calcualte_emission_gm(tree, avg_h)
        if tree.is_leaf():
            return emission

        assert len(child_results) == 2

        # NOTE(review): the recursive return values are only used for the
        # arity check above; the scores are read back from crf_cache, so
        # calcualte_emission_gm presumably fills 'state_weight' (and
        # 'in_weight' for leaves) -- confirm.
        n_label, n_comp = self.num_label, self.comp
        from_left = tree.children[0].crf_cache['in_weight'].reshape(
            1, n_label, 1, 1, n_comp, 1)
        from_right = tree.children[1].crf_cache['in_weight'].reshape(
            1, 1, n_label, 1, 1, n_comp)
        from_node = tree.crf_cache['state_weight'].reshape(
            n_label, 1, 1, n_comp, 1, 1)

        combined = from_node + from_left + from_right + self.trans_matrix
        # Bring axes 0 and 3 to the front, flatten the rest, and sum it out.
        combined = combined.permute(0, 3, 1, 2, 4, 5).reshape(
            n_label, n_comp, -1)
        combined = logsumexp(combined, dim=2)

        # shape [num_label, comp]
        tree.crf_cache['in_weight'] = combined
        return combined
Exemple #7
0
    def predict(self, tree):
        """Decode ``tree`` and return the value of ``collect_pred(tree, [])``.

        Steps: seed the root's outside mixture with the root transition
        parameters, run the inside (``forward``) and ``outside`` passes, run
        ``maxrule_parsing``, pick the best root label, then let
        ``get_max_tree`` and ``collect_pred`` build the prediction.
        """
        # alert should we calculate the outside score with the state score?
        root_cache = tree.lveg_cache
        root_cache['out_weight'] = self.trans_root_weight
        root_cache['out_mu'] = self.trans_root_mu
        root_cache['out_var'] = self.trans_root_var

        # Only the 'avg_h' mode hands an averaged hidden state to forward();
        # 'td_avg_h' runs calcualte_avg on the tree and passes None.
        avg_h = None
        mode = self.pred_mode
        if mode == 'td_avg_h':
            self.calcualte_avg(tree)
        elif mode == 'avg_h':
            avg_h = self.collect_avg_hidden(tree)

        in_weight, in_mu, in_var = self.forward(tree, avg_h)
        self.outside(tree)
        self.maxrule_parsing(tree)

        # Score each root label against the root transition mixture.
        root_scores = self.gaussian_mutlisum(
            in_weight, in_mu, in_var,
            self.trans_root_weight, self.trans_root_mu, self.trans_root_var,
            [self.num_label, 1, -1], 2, 1)
        root_scores = logsumexp(root_scores, dim=1)

        best_root = torch.argmax(root_scores)
        tree.lveg_cache['max_label'] = best_root.cpu().item()

        self.get_max_tree(tree)
        return self.collect_pred(tree, [])
    def golden_score(self, tree):
        """Return the per-component score of the gold labelling under ``tree``.

        A leaf returns its cached ``state_weight`` row for ``tree.label``
        (shape [comp]). An internal node adds that row to both children's
        gold scores and to the ``trans_matrix`` entry of the gold rule, then
        sums out the children's component axes.
        """
        child_results = [self.golden_score(child) for child in tree.children]

        if tree.is_leaf():
            # shape [comp]
            return tree.crf_cache['state_weight'][tree.label]

        assert len(child_results) == 2

        n_comp = self.comp
        rule = (tree.label, tree.children[0].label, tree.children[1].label)

        # One axis per participant: parent, left child, right child.
        from_left = child_results[0].reshape(1, n_comp, 1)
        from_right = child_results[1].reshape(1, 1, n_comp)
        from_node = tree.crf_cache['state_weight'][rule[0]].reshape(
            n_comp, 1, 1)

        combined = from_left + from_right + from_node + self.trans_matrix[rule]

        # Keep the parent's component axis, sum out the two child axes.
        return logsumexp(combined.reshape(n_comp, -1), dim=1)
    def loss(self, tree):
        """Return ``logsumexp(forward(tree, avg_h)) - logsumexp(golden_score(tree))``."""
        # Only the 'avg_h' mode hands an averaged hidden state to forward();
        # 'td_avg_h' runs calcualte_avg on the tree and passes None.
        avg_h = None
        mode = self.pred_mode
        if mode == 'td_avg_h':
            self.calcualte_avg(tree)
        elif mode == 'avg_h':
            avg_h = self.collect_avg_hidden(tree)

        log_partition = logsumexp(self.forward(tree, avg_h))
        gold = logsumexp(self.golden_score(tree))
        return log_partition - gold
Exemple #10
0
    def gaussian_mutlisum(self, s_weight, s_mu, s_var, t_weight, t_mu, t_var,
                          s_shape, t_dim, sum_dim):
        """Multiply a score mixture with a transition mixture and sum one axis out.

        All arguments except ``sum_dim`` are forwarded unchanged to
        ``self.inout_multi``. Only the first of its three return values (the
        score) is kept; it is reduced with ``logsumexp`` over ``sum_dim``.
        """
        # The second and third return values are not needed here.
        product_score, _unused_mu, _unused_var = self.inout_multi(
            s_weight, s_mu, s_var, t_weight, t_mu, t_var, s_shape, t_dim)
        return logsumexp(product_score, dim=sum_dim)
    def predict(self, tree):
        """Decode ``tree`` and return the value of ``collect_pred(tree, [])``.

        Runs the inside (``forward``) and ``outside`` passes and
        ``maxrule_parsing``, stores the best root label in
        ``tree.crf_cache['max_label']``, then lets ``get_max_tree`` and
        ``collect_pred`` build the prediction.
        """
        root_inside = self.forward(tree)
        self.outside(tree)
        self.maxrule_parsing(tree)

        # Sum out dim 1 to get one score per root label, then take the best.
        per_label = logsumexp(root_inside, dim=1)
        best_root = torch.argmax(per_label)
        tree.crf_cache['max_label'] = best_root.cpu().item()

        self.get_max_tree(tree)
        return self.collect_pred(tree, [])
Exemple #12
0
    def loss(self, tree):
        """Return the loss for ``tree``: root energy minus the gold-tree score.

        The energy is the ``logsumexp`` over the root inside mixture combined
        with the root transition mixture; the gold score is the same quantity
        restricted to the labels stored on the tree.
        """
        in_weight, in_mu, in_var = self.forward(tree)

        # Multiply the root inside Gaussians with the root transition Gaussians.
        root_energy, _, _ = self.gaussian_multi(
            in_mu.unsqueeze(1), in_var.unsqueeze(1),
            self.trans_root_mu.unsqueeze(2), self.trans_root_var.unsqueeze(2))
        # root_energy shape [num_label, comp, inside_comp]
        partition = logsumexp(root_energy + in_weight.unsqueeze(1) +
                              self.trans_root_weight.unsqueeze(2))

        gold_weight, gold_mu, gold_var = self.golden_score(tree)

        # Same product, but only for the root label recorded on the tree.
        gold_energy, _, _ = self.gaussian_multi(
            gold_mu.unsqueeze(1), gold_var.unsqueeze(1),
            self.trans_root_mu[tree.label].unsqueeze(0),
            self.trans_root_var[tree.label].unsqueeze(0))
        # gold_energy shape [inside_comp, comp]
        gold_total = logsumexp(
            gold_energy + gold_weight.unsqueeze(1) +
            self.trans_root_weight[tree.label].unsqueeze(0))

        return partition - gold_total
    def outside(self, tree):
        """Propagate outside scores top-down, writing ``out_weight`` into each child's cache.

        Leaves are skipped. At the root (``tree.parent is None``) the
        outside weight is initialised to zeros of shape [num_label, comp].
        Each child's outside score is built from the parent's outside score,
        the sibling's ``in_weight`` and ``self.trans_matrix``; the method
        then recurses into both children.
        """
        if tree.is_leaf():
            return

        n_label, n_comp = self.num_label, self.comp

        # root part
        if tree.parent is None:
            tree.crf_cache['out_weight'] = tree.crf_cache['in_weight'].new(
                n_label, n_comp).fill_(0.0)

        left, right = tree.children[0], tree.children[1]

        # Put each score on its own pair of axes so it broadcasts against
        # self.trans_matrix.
        parent_out = tree.crf_cache['out_weight'].reshape(
            n_label, 1, 1, n_comp, 1, 1)
        right_in = right.crf_cache['in_weight'].reshape(
            1, 1, n_label, 1, 1, n_comp)
        left_in = left.crf_cache['in_weight'].reshape(
            1, n_label, 1, 1, n_comp, 1)

        # left part: bring axes 1 and 4 to the front and sum out the rest.
        # shape before permute [num_label, num_label, num_label, comp, comp, comp]
        left_out = right_in + parent_out + self.trans_matrix
        left_out = left_out.permute(1, 4, 0, 2, 3, 5).reshape(
            n_label, n_comp, -1)
        left.crf_cache['out_weight'] = logsumexp(left_out, dim=2)

        # right part: bring axes 2 and 5 to the front and sum out the rest.
        right_out = left_in + parent_out + self.trans_matrix
        right_out = right_out.permute(2, 5, 0, 1, 3, 4).reshape(
            n_label, n_comp, -1)
        right.crf_cache['out_weight'] = logsumexp(right_out, dim=2)

        for child in tree.children:
            self.outside(child)
Exemple #14
0
    def forward(self, tree, avg_h):
        """Recursively compute the inside score of ``tree``.

        A leaf returns its emission score from ``calcualte_emission_score``.
        An internal node adds its emission score, both children's inside
        scores and ``self.trans_matrix``, then sums out the two child axes.
        """
        child_scores = [self.forward(child, avg_h) for child in tree.children]

        emission = self.calcualte_emission_score(tree, avg_h)
        if tree.is_leaf():
            return emission

        assert len(child_scores) == 2

        n = self.num_labels
        # One axis per participant: parent, left child, right child.
        from_left = child_scores[0].reshape(1, n, 1)
        from_right = child_scores[1].reshape(1, 1, n)
        from_node = emission.reshape([n, 1, 1])

        combined = from_node + from_right + from_left + self.trans_matrix
        return logsumexp(combined.reshape(n, -1), dim=1)
Exemple #15
0
    def forward(self, tree):
        """Recursively compute the inside score of ``tree``.

        The emission score is read from ``tree.crf_cache['be_hidden']``. A
        leaf returns it as is; an internal node adds it to both children's
        inside scores and ``self.trans_matrix``, then sums out the two child
        axes.
        """
        child_scores = [self.forward(child) for child in tree.children]

        emission = tree.crf_cache['be_hidden']
        if tree.is_leaf():
            return emission

        assert len(child_scores) == 2

        n = self.num_labels
        # One axis per participant: parent, left child, right child.
        from_left = child_scores[0].reshape(1, n, 1)
        from_right = child_scores[1].reshape(1, 1, n)
        from_node = emission.reshape([n, 1, 1])

        combined = from_node + from_right + from_left + self.trans_matrix
        return logsumexp(combined.reshape(n, -1), dim=1)
    def predict(self, tree):
        """Decode ``tree`` and return the value of ``collect_pred(tree, [])``.

        Runs the inside (``forward``) and ``outside`` passes and
        ``maxrule_parsing``, stores the best root label in
        ``tree.crf_cache['max_label']``, then lets ``get_max_tree`` and
        ``collect_pred`` build the prediction.
        """
        # Only the 'avg_h' mode hands an averaged hidden state to forward();
        # 'td_avg_h' runs calcualte_avg on the tree and passes None.
        avg_h = None
        mode = self.pred_mode
        if mode == 'td_avg_h':
            self.calcualte_avg(tree)
        elif mode == 'avg_h':
            avg_h = self.collect_avg_hidden(tree)

        root_inside = self.forward(tree, avg_h)
        self.outside(tree)
        self.maxrule_parsing(tree)

        # Sum out dim 1 to get one score per root label, then take the best.
        per_label = logsumexp(root_inside, dim=1)
        best_root = torch.argmax(per_label)
        tree.crf_cache['max_label'] = best_root.cpu().item()

        self.get_max_tree(tree)
        return self.collect_pred(tree, [])
Exemple #17
0
    def predict(self, tree):
        """Decode ``tree`` and return the value of ``collect_pred(tree, [])``.

        Steps: seed the root's outside mixture with the root transition
        parameters, run the inside (``forward``) and ``outside`` passes, run
        ``maxrule_parsing``, pick the best root label, then let
        ``get_max_tree`` and ``collect_pred`` build the prediction.
        """
        root_cache = tree.lveg_cache
        root_cache['out_weight'] = self.trans_root_weight
        root_cache['out_mu'] = self.trans_root_mu
        root_cache['out_var'] = self.trans_root_var

        in_weight, in_mu, in_var = self.forward(tree)
        self.outside(tree)
        self.maxrule_parsing(tree)

        # Score each root label against the root transition mixture.
        root_scores = self.gaussian_mutlisum(
            in_weight, in_mu, in_var,
            self.trans_root_weight, self.trans_root_mu, self.trans_root_var,
            [self.num_label, 1, -1], 2, 1)
        root_scores = logsumexp(root_scores, dim=1)

        best_root = torch.argmax(root_scores)
        tree.lveg_cache['max_label'] = best_root.cpu().item()

        self.get_max_tree(tree)
        return self.collect_pred(tree, [])
Exemple #18
0
    def forward(self, tree, avg_h):
        """Recursively compute the inside score of ``tree``.

        A leaf returns its emission score from ``calcualte_emission_score``.
        An internal node adds ``self.trans_matrix`` to each child's inside
        score, places the two results on separate axes, adds the node's own
        emission score and sums out the two child axes.
        """
        child_scores = [self.forward(child, avg_h) for child in tree.children]

        emission = self.calcualte_emission_score(tree, avg_h)
        if tree.is_leaf():
            return emission

        # alert binary tree so only has 2 child
        assert len(child_scores) == 2

        # Each child gets its own copy of the transition term; the left child
        # lives on axis 1 and the right child on axis 2 after unsqueezing.
        left_term = (child_scores[0].unsqueeze(0) +
                     self.trans_matrix).unsqueeze(2)
        right_term = (child_scores[1].unsqueeze(0) +
                      self.trans_matrix).unsqueeze(1)

        n = self.num_labels
        total = emission.reshape([n, 1, 1])
        total = left_term + total
        total = right_term + total

        return logsumexp(total.reshape(n, -1), dim=1)
Exemple #19
0
    def forward(self, tree):
        """Recursively compute the inside score of ``tree``.

        The emission score is read from ``tree.crf_cache['be_hidden']``. A
        leaf returns it as is. An internal node adds ``self.trans_matrix`` to
        each child's inside score, places the two results on separate axes,
        adds the node's own emission score and sums out the two child axes.
        """
        child_scores = [self.forward(child) for child in tree.children]

        emission = tree.crf_cache['be_hidden']
        if tree.is_leaf():
            return emission

        assert len(child_scores) == 2

        # Each child gets its own copy of the transition term; the left child
        # lives on axis 1 and the right child on axis 2 after unsqueezing.
        left_term = (child_scores[0].unsqueeze(0) +
                     self.trans_matrix).unsqueeze(2)
        right_term = (child_scores[1].unsqueeze(0) +
                      self.trans_matrix).unsqueeze(1)

        n = self.num_labels
        total = emission.reshape([n, 1, 1])
        total = left_term + total
        total = right_term + total

        return logsumexp(total.reshape(n, -1), dim=1)
Exemple #20
0
    def loss(self, tree):
        """Return ``logsumexp(forward(tree), dim=0) - golden_score(tree)``."""
        # Sum the root inside score over dim 0 to get the energy term.
        log_partition = logsumexp(self.forward(tree), dim=0)
        # Unlike the energy, the gold score is used as returned (no reduction).
        gold = self.golden_score(tree)
        return log_partition - gold