예제 #1
0
파일: hntm.py 프로젝트: misonuma/tsntm
 def __init__(self, config):
     """Keep ``config``, derive the tree lookup tables, then call ``build``.

     Args:
         config: configuration object; only ``config.tree_idxs`` is read
             directly in this constructor.
     """
     tree_idxs = config.tree_idxs

     self.config = config
     self.t_variables = {}
     self.tree_idxs = tree_idxs

     # Lookup tables derived from the tree by helpers defined elsewhere.
     self.topic_idxs = get_topic_idxs(tree_idxs)
     self.child_to_parent_idxs = get_child_to_parent_idxs(tree_idxs)

     depth_by_topic = get_depth(tree_idxs)
     self.tree_depth = depth_by_topic
     # Largest depth value found among all entries of the depth table.
     self.n_depth = max(depth_by_topic.values())

     self.build()
예제 #2
0
    def __init__(self, config):
        """Fill in missing config defaults, derive tree tables, call ``build``.

        Args:
            config: configuration object. ``cell`` and ``prod`` are set on it
                (to ``'rnn'`` and ``False``) when they are absent from
                ``vars(config)``; ``config.tree_idxs`` is read.
        """
        self.config = config

        # Defaults are written back onto the caller's config object.
        for attr_name, fallback in (('cell', 'rnn'), ('prod', False)):
            if attr_name not in vars(config):
                setattr(config, attr_name, fallback)

        self.t_variables = {}

        tree_idxs = config.tree_idxs
        self.tree_idxs = tree_idxs
        self.topic_idxs = get_topic_idxs(tree_idxs)
        self.child_to_parent_idxs = get_child_to_parent_idxs(tree_idxs)

        depth_by_topic = get_depth(tree_idxs)
        self.tree_depth = depth_by_topic
        self.n_depth = max(depth_by_topic.values())

        self.build()
예제 #3
0
    def __init__(self, config):
        """Keep ``config``, derive tree and level tables, then call ``build``.

        Args:
            config: configuration object; only ``config.tree_idxs`` is read
                directly in this constructor.
        """
        tree_idxs = config.tree_idxs

        self.config = config
        self.t_variables = {}
        self.tree_idxs = tree_idxs
        self.topic_idxs = get_topic_idxs(tree_idxs)
        self.child_to_parent_idxs = get_child_to_parent_idxs(tree_idxs)

        depth_by_topic = get_depth(tree_idxs)
        self.tree_depth = depth_by_topic
        self.n_depth = max(depth_by_topic.values())

        nodes_by_level = get_level_nodes(depth_by_topic)
        self.level_nodes = nodes_by_level
        self.leaf_parents = get_leaf_parents(tree_idxs, nodes_by_level,
                                             self.n_depth)

        # Invert the level -> nodes table into a node -> level lookup.
        self.node_level = {
            node: level
            for level, nodes in nodes_by_level.items()
            for node in nodes
        }
        self.build()
예제 #4
0
def update_config_tree(config, tree_idxs):
    """Store ``tree_idxs`` on ``config`` and refresh every tree-derived field.

    Sets ``tree_idxs``, ``topic_idxs``, ``n_topic``, ``child_to_parent_idxs``,
    ``all_child_idxs``, ``tree_depth``, ``depth_topic_idxs`` and ``n_depth``.
    When ``config.prior`` is truthy, ``probs_topic_prior`` is set as well.

    Args:
        config: mutable configuration object (modified in place).
        tree_idxs: mapping whose values are collections of child topic idxs.

    Returns:
        The same ``config`` object.
    """
    config.tree_idxs = tree_idxs
    config.topic_idxs = get_topic_idxs(tree_idxs)

    # Every child listed in the tree, plus one extra topic
    # (presumably the root, which is nobody's child -- TODO confirm).
    config.n_topic = sum(len(children) for children in tree_idxs.values()) + 1

    child_to_parent = get_child_to_parent_idxs(tree_idxs)
    config.child_to_parent_idxs = child_to_parent
    config.all_child_idxs = list(child_to_parent.keys())  # n_topic - 1

    depth_by_topic = get_depth(tree_idxs)
    config.tree_depth = depth_by_topic

    # Group topic idxs by their depth.
    topics_at_depth = defaultdict(list)
    for topic_idx, depth in depth_by_topic.items():
        topics_at_depth[depth].append(topic_idx)
    config.depth_topic_idxs = topics_at_depth
    config.n_depth = max(depth_by_topic.values())

    # NOTE: construction of config.mask_tree_reg / config.mask_tree
    # (selected by config.mask_tree_type) used to happen here and is disabled.
    if not config.prior:
        return config

    # Each depth receives a 1 / n_depth share, split evenly among the topics
    # found at that depth.
    prior_at_depth = {
        depth: 1. / config.n_depth / len(members)
        for depth, members in topics_at_depth.items()
    }
    per_topic_prior = [
        prior_at_depth[depth_by_topic[topic_idx]]
        for topic_idx in config.topic_idxs
    ]
    config.probs_topic_prior = np.array(per_topic_prior, dtype=np.float32)
    return config
예제 #5
0
def treesum(args):
    """Build summaries of increasing length from topic sentences via a top-k search.

    The search starts with topic index 0 as the only candidate. At every step
    each (kept partial summary, candidate sentence) pair is scored by mean
    ROUGE against the documents, up to ``topk`` distinct partial summaries are
    kept, and a topic that was added to a summary is replaced in that
    summary's candidate pool by its children from ``tree_idxs``.

    Relies on ``rouge_scorer``, ``rouge_name``, ``rouge_obj``,
    ``get_topic_idxs`` and ``get_sorted_topic_idxs`` being defined outside
    this function.

    Args:
        args: a single 8-tuple, unpacked in this order:
            tree_idxs: mapping from a topic idx to its child topic idxs.
            topic_sents: sentences indexed by topic index, i.e. by position
                in ``get_topic_idxs(tree_idxs)``.
            text: documents separated by ``'</DOC>'``; it must split into
                exactly 8 pieces.
            topk: maximum number of partial summaries kept after each step.
            threshold: a candidate sentence is appended only when its ROUGE
                precision against the current partial summary is
                ``<= threshold``.
            truncate: candidates carried to the next step must have more
                than this many whitespace-separated tokens.
            max_summary_l: number of steps; one result is stored per step.
            verbose: when truthy, print per-step diagnostics.

    Returns:
        dict mapping each step ``1..max_summary_l`` to
        ``{'sents': [...], 'indices': [...]}`` for the best-scoring partial
        summary of that step, ordered according to
        ``get_sorted_topic_idxs(tree_idxs)``.
    """
    tree_idxs, topic_sents, text, topk, threshold, truncate, max_summary_l, verbose = args

    docs = [doc.strip() for doc in text.split('</DOC>')]
    assert len(docs) == 8

    topic_idxs = get_topic_idxs(tree_idxs)
    summary_l_sents = {}

    # Parallel lists, one entry per kept partial summary: its sentences, its
    # topic indices, its ROUGE score, and the topic indices it may still add.
    topk_summary_sents_list = [[]]
    topk_summary_indices_list = [[]]
    topk_summary_rouge_list = [0.]
    candidate_topic_indices_list = [[0]]
    summary_sents = []

    def mean_rouge(docs, candidate_summary_sents):
        """Average the `rouge_obj` field of the `rouge_name` score over all docs."""
        return np.mean([
            getattr(
                rouge_scorer.score(
                    target=doc,
                    prediction=' '.join(candidate_summary_sents))[rouge_name],
                rouge_obj) for doc in docs
        ])

    def rouge_precision(topk_summary_sents, topic_sent):
        """Return the `rouge_name` precision of `topic_sent` against the joined summary."""
        return getattr(
            rouge_scorer.score(target=' '.join(topk_summary_sents),
                               prediction=topic_sent)[rouge_name], 'precision')

    for summary_l in range(1, max_summary_l + 1):
        # No candidates left in any kept summary: reuse the result of the last
        # step that ran. The first step always has candidate [0], so the
        # sorted_* names are already bound whenever this branch is taken.
        if sum([
                len(candidate_topic_indices)
                for candidate_topic_indices in candidate_topic_indices_list
        ]) == 0:
            summary_l_sents[summary_l] = {
                'sents': sorted_summary_sents,
                'indices': sorted_summary_indices
            }
            continue
        # compute rouge for each candidate summary
        candidate_topic_sents_list = [[topic_sents[topic_index] for topic_index in candidate_topic_indices] \
                                                                              for candidate_topic_indices in candidate_topic_indices_list]
        candidate_summaries_rouge_list = [[mean_rouge(docs, candidate_summary_sents=topk_summary_sents + [topic_sent]) \
                                                       for topic_sent in candidate_topic_sents] \
                                                       for topk_summary_sents, candidate_topic_sents in zip(topk_summary_sents_list, candidate_topic_sents_list)]
        assert len(topk_summary_sents_list) == len(
            candidate_topic_sents_list
        ) == len(topk_summary_indices_list) == len(
            candidate_topic_indices_list) == len(
                topk_summary_rouge_list) == len(candidate_summaries_rouge_list)

        # One record per (kept summary, candidate sentence). The candidate is
        # appended only if its precision against the current summary is within
        # `threshold`; otherwise the record repeats the unchanged summary with
        # its previous score and 'topk_topic_index' set to None.
        candidate_summaries_sents_indices_list = [[
            {'sents': topk_summary_sents + [topic_sent], 'indices': topk_summary_indices + [topic_index], 'topk_topic_index': topic_index, 'rouge': candidate_summary_rouge}
            if rouge_precision(topk_summary_sents, topic_sent) <= threshold
            else {'sents': topk_summary_sents, 'indices': topk_summary_indices, 'topk_topic_index': None, 'rouge': topk_summary_rouge}
            for topic_sent, topic_index, candidate_summary_rouge in zip(candidate_topic_sents, candidate_topic_indices, candidate_summaries_rouge)] \
        for topk_summary_sents, candidate_topic_sents, topk_summary_indices, candidate_topic_indices, topk_summary_rouge, candidate_summaries_rouge \
        in zip(topk_summary_sents_list, candidate_topic_sents_list, topk_summary_indices_list, candidate_topic_indices_list, topk_summary_rouge_list, candidate_summaries_rouge_list)]

        # Split the records back into parallel nested lists, one per field.
        candidate_summaries_sents_list = [[candidate_summary_sents_indices['sents']\
                                                      for candidate_summary_sents_indices in candidate_summaries_sents_indices]\
                                                      for candidate_summaries_sents_indices in candidate_summaries_sents_indices_list]
        candidate_summaries_indices_list = [[candidate_summary_sents_indices['indices']\
                                                      for candidate_summary_sents_indices in candidate_summaries_sents_indices]\
                                                      for candidate_summaries_sents_indices in candidate_summaries_sents_indices_list]
        candidate_summaries_rouge_list = [[candidate_summary_sents_indices['rouge']\
                                                      for candidate_summary_sents_indices in candidate_summaries_sents_indices]\
                                                      for candidate_summaries_sents_indices in candidate_summaries_sents_indices_list]

        topk_topic_indices_list = [[candidate_summary_sents_indices['topk_topic_index']\
                                                      for candidate_summary_sents_indices in candidate_summaries_sents_indices]\
                                                      for candidate_summaries_sents_indices in candidate_summaries_sents_indices_list]
        # Flattened scores, in the same (i, j) order as candidate_topic_args below.
        candidate_rouges = [candidate_summary_rouge
                                                      for candidate_summaries_rouge in candidate_summaries_rouge_list\
                                                      for candidate_summary_rouge in candidate_summaries_rouge]

        # identify top k rouge of candidate summaries
        # (i, j) = (index of kept summary, index of candidate within it),
        # reordered so the highest ROUGE comes first.
        candidate_topic_args = np.array([[i, j] for i in range(len(candidate_summaries_sents_list)) \
                                                         for j in range(len(candidate_summaries_sents_list[i]))])
        assert len(candidate_topic_args) == len(candidate_rouges)
        topk_topic_args = candidate_topic_args[np.argsort(candidate_rouges)
                                               [::-1]]

        # walk candidates best-first, keeping up to topk distinct summaries
        topk_summary_sents_list = []
        topk_summary_indices_list = []
        topk_summary_rouge_list = []
        new_candidate_topic_indices_list = []
        for topk_topic_arg in topk_topic_args:
            topk_summary_indices = candidate_summaries_indices_list[
                topk_topic_arg[0]][topk_topic_arg[1]]
            # Skip summaries whose set of topic indices was already kept.
            if set(topk_summary_indices) in [
                    set(indices) for indices in topk_summary_indices_list
            ]:
                continue

            topk_summary_indices_list += [topk_summary_indices]
            topk_summary_sents_list += [
                candidate_summaries_sents_list[topk_topic_arg[0]][
                    topk_topic_arg[1]]
            ]
            topk_summary_rouge_list += [
                candidate_summaries_rouge_list[topk_topic_arg[0]][
                    topk_topic_arg[1]]
            ]
            topk_topic_index = topk_topic_indices_list[topk_topic_arg[0]][
                topk_topic_arg[1]]

            # Copy so that sibling records from the same parent summary do not
            # share (and mutate) one candidate list.
            candidate_topic_indices = list(
                candidate_topic_indices_list[topk_topic_arg[0]])
            if topk_topic_index is not None:
                # The topic was added: drop it from the pool and, if it has
                # children in the tree, make them available instead.
                candidate_topic_indices.remove(topk_topic_index)
                topk_topic_idx = topic_idxs[topk_topic_index]
                if topk_topic_idx in tree_idxs:
                    child_indices = [
                        topic_idxs.index(child_idx)
                        for child_idx in tree_idxs[topk_topic_idx]
                    ]
                    new_candidate_topic_indices_list += [
                        candidate_topic_indices + child_indices
                    ]
                else:
                    new_candidate_topic_indices_list += [
                        candidate_topic_indices
                    ]
            else:
                # The candidate was rejected by the threshold: pool unchanged.
                new_candidate_topic_indices_list += [candidate_topic_indices]

            if len(topk_summary_indices_list) >= topk: break

        # Carry over only candidates whose sentence has more than `truncate`
        # whitespace-separated tokens (the initial candidate [0] is never
        # filtered this way).
        candidate_topic_indices_list = [[
            i for i in new_candidate_topic_indices
            if len(topic_sents[i].split()) > truncate
        ] for new_candidate_topic_indices in new_candidate_topic_indices_list]

        # Entry 0 is the highest-scoring kept summary of this step.
        summary_sents = topk_summary_sents_list[0]
        summary_indices = topk_summary_indices_list[0]

        sorted_topic_idxs = get_sorted_topic_idxs(tree_idxs)
        sorted_topic_indices = [
            topic_idxs.index(topic_idx) for topic_idx in sorted_topic_idxs
        ]

        # Re-emit the best summary in the order given by get_sorted_topic_idxs
        # instead of the order in which its sentences were selected.
        sorted_summary_indices = []
        sorted_summary_sents = []
        for sorted_topic_index in sorted_topic_indices:
            if sorted_topic_index in summary_indices:
                sorted_summary_indices.append(sorted_topic_index)
                sorted_summary_sent = summary_sents[summary_indices.index(
                    sorted_topic_index)]
                sorted_summary_sents.append(sorted_summary_sent)

        summary_l_sents[summary_l] = {
            'sents': sorted_summary_sents,
            'indices': sorted_summary_indices
        }

        if verbose:
            print(summary_l)
            print('sents:', summary_l_sents[summary_l]['sents'])
            print('idxs:', [
                topic_idxs[topic_index]
                for topic_index in summary_l_sents[summary_l]['indices']
            ])
            print('rouge:', np.sort(candidate_rouges)[::-1][0])
            print('topk idxs:', [[
                topic_idxs[summary_index] for summary_index in summary_indices
            ] for summary_indices in topk_summary_indices_list])
            print('topk rouges:', topk_summary_rouge_list)

    return summary_l_sents