示例#1
0
文件: rank.py 项目: fooyou/tgen
    def _init_training(self, das_file, ttree_file, data_portion):
        """Prepare everything needed before the training passes start.

        Loads DAs and t-trees, truncates them to the requested portion of the
        data and, when configured, sets up the candidate generator with its
        sampling planner and the A*-search planner.

        @param das_file: path to the file with training DAs
        @param ttree_file: path to the file with training t-trees
        @param data_portion: fraction of the loaded data to keep; the number of
            instances is rounded from the number of t-trees
        """
        log_info('Reading DAs from ' + das_file + '...')
        all_das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        doc = read_ttrees(ttree_file)
        lang, sel = self.language, self.selector
        all_sents = sentences_from_doc(doc, lang, sel)
        all_trees = trees_from_doc(doc, lang, sel)

        # keep only the leading part of the data (size derived from the trees)
        n_train = int(round(data_portion * len(all_trees)))
        self.train_trees = all_trees[:n_train]
        self.train_das = all_das[:n_train]
        self.train_sents = all_sents[:n_train]
        self.train_order = range(len(self.train_trees))
        log_info('Using %d training instances.' % n_train)

        # candidate generator and sampling planner exist only with a candgen model
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(self.candgen_model)
            self.sampling_planner = SamplingPlanner(
                dict(language=lang, selector=sel, candgen=self.candgen))

        # the A*-search planner is only needed for the 'gen_cur_weights' strategy
        if 'gen_cur_weights' not in self.rival_gen_strategy:
            return
        assert self.candgen is not None
        self.asearch_planner = ASearchPlanner(
            dict(candgen=self.candgen, language=lang, selector=sel, ranker=self))
示例#2
0
    def _init_training(self, das_file, ttree_file, data_portion):
        """Initialize training (read input data, fix size, initialize candidate generator
        and planner).

        @param das_file: file with training DAs
        @param ttree_file: file with training t-trees
        @param data_portion: fraction of the loaded data to use for training
            (the number of instances is rounded from the number of t-trees)
        """
        # read input
        log_info('Reading DAs from ' + das_file + '...')
        das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        ttree_doc = read_ttrees(ttree_file)
        sents = sentences_from_doc(ttree_doc, self.language, self.selector)
        trees = trees_from_doc(ttree_doc, self.language, self.selector)

        # make training data smaller if necessary
        train_size = int(round(data_portion * len(trees)))
        self.train_trees = trees[:train_size]
        self.train_das = das[:train_size]
        self.train_sents = sents[:train_size]
        self.train_order = range(len(self.train_trees))
        # report the real number of instances (train_size may exceed the data
        # size when data_portion > 1, since slicing then stops at the end)
        log_info('Using %d training instances.' % len(self.train_trees))

        # initialize candidate generator
        # NOTE: no sampling planner is created here (its initialization is disabled)
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(
                self.candgen_model)

        # check if A*search planner is needed (i.e., any rival generation strategy
        # requires it) and initialize it
        strategies = self.rival_gen_strategy
        if strategies and isinstance(strategies[0], tuple):
            # items are pairs whose second member is a list of strategy names
            # (the first member is not used here)
            used_strategies = (s for _, ss in strategies for s in ss)
        else:
            used_strategies = strategies
        # generator (not list) so any() can stop at the first match
        asearch_needed = any(s in ('gen_cur_weights', 'gen_update')
                             for s in used_strategies)
        if asearch_needed:
            assert self.candgen is not None, \
                'A*search planner requires a candidate generator'
            self.asearch_planner = ASearchPlanner({
                'candgen': self.candgen,
                'language': self.language,
                'selector': self.selector,
                'ranker': self,
            })
示例#3
0
文件: rank.py 项目: UFAL-DSG/tgen
    def _init_training(self, das_file, ttree_file, data_portion):
        """Initialize training (read input data, fix size, initialize candidate generator
        and planner).

        @param das_file: file with training DAs
        @param ttree_file: file with training t-trees
        @param data_portion: fraction of the loaded data to use for training
            (the number of instances is rounded from the number of t-trees)
        """
        # read input
        log_info('Reading DAs from ' + das_file + '...')
        das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        ttree_doc = read_ttrees(ttree_file)
        sents = sentences_from_doc(ttree_doc, self.language, self.selector)
        trees = trees_from_doc(ttree_doc, self.language, self.selector)

        # make training data smaller if necessary
        train_size = int(round(data_portion * len(trees)))
        self.train_trees = trees[:train_size]
        self.train_das = das[:train_size]
        self.train_sents = sents[:train_size]
        self.train_order = range(len(self.train_trees))
        # report the real number of instances (train_size may exceed the data
        # size when data_portion > 1, since slicing then stops at the end)
        log_info('Using %d training instances.' % len(self.train_trees))

        # initialize candidate generator
        # NOTE: no sampling planner is created here (its initialization is disabled)
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(self.candgen_model)

        # check if A*search planner is needed (i.e., any rival generation strategy
        # requires it) and initialize it; an empty strategy list needs no planner
        strategies = self.rival_gen_strategy
        if strategies and isinstance(strategies[0], tuple):
            # items are pairs whose second member is a list of strategy names
            used_strategies = (s for _, ss in strategies for s in ss)
        else:
            used_strategies = strategies
        asearch_needed = any(s in ('gen_cur_weights', 'gen_update')
                             for s in used_strategies)
        if asearch_needed:
            assert self.candgen is not None, \
                'A*search planner requires a candidate generator'
            self.asearch_planner = ASearchPlanner({'candgen': self.candgen,
                                                   'language': self.language,
                                                   'selector': self.selector,
                                                   'ranker': self, })
示例#4
0
文件: candgen.py 项目: UFAL-DSG/tgen
    def train(self, das_file, ttree_file):
        """``Training'' the generator (collect counts of DAIs and corresponding t-nodes).

        @param das_file: file with training DAs
        @param ttree_file: file with training t-trees (YAML or pickle)
        """
        # read training data
        log_info('Reading ' + ttree_file)
        ttrees = ttrees_from_doc(read_ttrees(ttree_file), self.language, self.selector)
        log_info('Reading ' + das_file)
        das = read_das(das_file)

        # collect counts
        log_info('Collecting counts')
        child_type_counts = {}
        child_num_counts = defaultdict(Counter)
        max_total_nodes = defaultdict(int)
        max_level_nodes = defaultdict(Counter)

        for ttree, da in zip(ttrees, das):
            # all nodes including the root, fetched once and reused below
            nodes = ttree.get_descendants(add_self=True)
            # (parent ID, child ID) of every non-root node; this does not depend
            # on the DAI, so it is computed once per tree rather than once per DAI
            edges = [(self._parent_node_id(tnode.parent),
                      (tnode.formeme, tnode.t_lemma, tnode > tnode.parent))
                     for tnode in ttree.get_descendants()]
            # number of nodes at each tree depth
            level_nodes = Counter(tnode.get_depth() for tnode in nodes)

            # counts for number of children
            for tnode in nodes:
                child_num_counts[self._parent_node_id(tnode)][len(tnode.get_children())] += 1

            for dai in da:
                # counts for formeme/lemma given DAI (a DAI only gets an entry
                # when the tree has at least one non-root node)
                if edges:
                    dai_counts = child_type_counts.setdefault(dai, defaultdict(Counter))
                    for parent_id, child_id in edges:
                        dai_counts[parent_id][child_id] += 1
                # counts for max. number of nodes (total and per level)
                max_total_nodes[dai] = max(max_total_nodes[dai], len(nodes))
                for level, count in level_nodes.items():
                    max_level_nodes[dai][level] = max(max_level_nodes[dai][level], count)

        # prune counts
        if self.prune_threshold > 1:
            # iterate over a snapshot so that entries may be deleted in the loop
            for dai, forms in list(child_type_counts.items()):
                self._prune(forms)
                if not forms:
                    del child_type_counts[dai]
            self._prune(child_num_counts)

        # transform counts
        self.child_type_counts = child_type_counts
        self.child_num_cdfs = self.cdfs_from_counts(child_num_counts)
        # largest number of children observed for each parent ID
        self.max_children = {par_id: max(counts)
                             for par_id, counts in child_num_counts.items()}
        self.exp_child_num = self.exp_from_cdfs(self.child_num_cdfs)

        if self.node_limits:
            # per DAI: {'total': max. node count, <depth>: max. nodes at that depth}
            self.node_limits = {dai: {'total': max_total}
                                for dai, max_total in max_total_nodes.items()}
            for dai, max_levels in max_level_nodes.items():
                self.node_limits[dai].update(max_levels)
        else:
            self.node_limits = None

        # Determine compatible DAIs for given lemmas/nodes (according to the compatibility setting)
        if self.compatible_dais_type:
            self.compatible_dais = self._compatibility_table(das, ttrees, lambda da: da.dais)

        # The same for compatible DA slots (the flag is replaced by the table)
        if self.compatible_slots:
            self.compatible_slots = self._compatibility_table(das, ttrees,
                                                              lambda da: [dai.slot for dai in da.dais])

        if self.classif:
            self.classif.train(das_file, ttree_file)
示例#5
0
    def train(self, das_file, ttree_file):
        """``Training'' the generator (collect counts of DAIs and corresponding t-nodes).

        @param das_file: file with training DAs
        @param ttree_file: file with training t-trees (YAML or pickle)
        """
        # read training data
        log_info('Reading ' + ttree_file)
        ttrees = ttrees_from_doc(read_ttrees(ttree_file), self.language,
                                 self.selector)
        log_info('Reading ' + das_file)
        das = read_das(das_file)

        # collect counts
        log_info('Collecting counts')
        child_type_counts = {}
        child_num_counts = defaultdict(Counter)
        max_total_nodes = defaultdict(int)
        max_level_nodes = defaultdict(Counter)

        for ttree, da in zip(ttrees, das):
            # all nodes including the root, fetched once and reused below
            nodes = ttree.get_descendants(add_self=True)
            # (parent ID, child ID) of every non-root node -- DAI-independent,
            # so computed once per tree instead of once per DAI
            edges = [(self._parent_node_id(tnode.parent),
                      (tnode.formeme, tnode.t_lemma, tnode > tnode.parent))
                     for tnode in ttree.get_descendants()]
            # number of nodes at each tree depth
            level_nodes = Counter(tnode.get_depth() for tnode in nodes)

            # counts for number of children
            for tnode in nodes:
                child_num_counts[self._parent_node_id(tnode)][len(
                    tnode.get_children())] += 1

            for dai in da:
                # counts for formeme/lemma given DAI (a DAI only gets an entry
                # when the tree has at least one non-root node)
                if edges:
                    dai_counts = child_type_counts.setdefault(
                        dai, defaultdict(Counter))
                    for parent_id, child_id in edges:
                        dai_counts[parent_id][child_id] += 1
                # counts for max. number of nodes (total and per level)
                max_total_nodes[dai] = max(max_total_nodes[dai], len(nodes))
                for level, count in level_nodes.items():
                    max_level_nodes[dai][level] = max(
                        max_level_nodes[dai][level], count)

        # prune counts
        if self.prune_threshold > 1:
            # iterate over a snapshot so that entries may be deleted in the loop
            for dai, forms in list(child_type_counts.items()):
                self._prune(forms)
                if not forms:
                    del child_type_counts[dai]
            self._prune(child_num_counts)

        # transform counts
        self.child_type_counts = child_type_counts
        self.child_num_cdfs = self.cdfs_from_counts(child_num_counts)
        # largest number of children observed for each parent ID
        self.max_children = {
            par_id: max(counts)
            for par_id, counts in child_num_counts.items()
        }
        self.exp_child_num = self.exp_from_cdfs(self.child_num_cdfs)

        if self.node_limits:
            # per DAI: {'total': max. node count, <depth>: max. nodes at that depth}
            self.node_limits = {
                dai: {'total': max_total}
                for dai, max_total in max_total_nodes.items()
            }
            for dai, max_levels in max_level_nodes.items():
                self.node_limits[dai].update(max_levels)
        else:
            self.node_limits = None

        # Determine compatible DAIs for given lemmas/nodes (according to the compatibility setting)
        if self.compatible_dais_type:
            self.compatible_dais = self._compatibility_table(
                das, ttrees, lambda da: da.dais)

        # The same for compatible DA slots (the flag is replaced by the table)
        if self.compatible_slots:
            self.compatible_slots = self._compatibility_table(
                das, ttrees, lambda da: [dai.name for dai in da.dais])

        if self.classif:
            self.classif.train(das_file, ttree_file)