Exemple #1
0
    def _init_training(self, das_file, ttree_file, data_portion):
        """Initialize training: read the input data, cut it down to the requested size,
        and create the candidate generator and planners if the configuration needs them.

        Stores the (possibly truncated) data in self.train_trees, self.train_das,
        self.train_sents and the default processing order in self.train_order.

        @param das_file: the file to read the training dialogue acts from
        @param ttree_file: the file to read the training t-trees (and sentences) from
        @param data_portion: relative portion of the training data to be used; the \
            resulting number of instances is rounded to the nearest integer
        @raise ValueError: if the 'gen_cur_weights' strategy is requested but no \
            candidate generator is available
        """
        # read input
        log_info('Reading DAs from ' + das_file + '...')
        das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        ttree_doc = read_ttrees(ttree_file)
        sents = sentences_from_doc(ttree_doc, self.language, self.selector)
        trees = trees_from_doc(ttree_doc, self.language, self.selector)

        # make training data smaller if necessary (all three lists are cut at the same index)
        train_size = int(round(data_portion * len(trees)))
        self.train_trees = trees[:train_size]
        self.train_das = das[:train_size]
        self.train_sents = sents[:train_size]
        self.train_order = range(len(self.train_trees))
        log_info('Using %d training instances.' % train_size)

        # initialize candidate generator + sampling planner if a generator model is configured
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(self.candgen_model)
            self.sampling_planner = SamplingPlanner({'language': self.language,
                                                     'selector': self.selector,
                                                     'candgen': self.candgen})

        # the A*search planner cannot work without a candidate generator
        if 'gen_cur_weights' in self.rival_gen_strategy:
            # getattr: the attribute is only assigned above if candgen_model is set, so it
            # may not exist at all at this point
            candgen = getattr(self, 'candgen', None)
            if candgen is None:
                raise ValueError("The 'gen_cur_weights' rival generation strategy requires "
                                 "a candidate generator (set 'candgen_model').")
            self.asearch_planner = ASearchPlanner({'candgen': candgen,
                                                   'language': self.language,
                                                   'selector': self.selector,
                                                   'ranker': self, })
Exemple #2
0
class BasePerceptronRanker(Ranker):
    """Base class for perceptron-style rankers of candidate trees.

    Implements the training loop (passes over the training data, generation of rival
    candidates, scoring, diagnostics) and the future promise estimates. Feature
    extraction, scoring internals and weight storage are not implemented here; the
    corresponding methods raise NotImplementedError or are referenced without being
    defined in this class (_score, _update_weights).
    """

    def __init__(self, cfg):
        """Set up the ranker according to the given configuration.

        @param cfg: a dictionary with settings; if it is empty or None, defaults are used
        """
        if not cfg:
            cfg = {}
        self.passes = cfg.get('passes', 5)
        self.alpha = cfg.get('alpha', 1)
        self.language = cfg.get('language', 'en')
        self.selector = cfg.get('selector', '')
        # initialize diagnostics
        self.lists_analyzer = None
        self.evaluator = None
        self.prune_feats = cfg.get('prune_feats', 1)
        self.rival_number = cfg.get('rival_number', 10)
        self.averaging = cfg.get('averaging', False)
        self.randomize = cfg.get('randomize', False)
        self.future_promise_weight = cfg.get('future_promise_weight', 1.0)
        self.future_promise_type = cfg.get('future_promise_type', 'expected_children')
        self.rival_gen_strategy = cfg.get('rival_gen_strategy', ['other_inst'])
        self.rival_gen_max_iter = cfg.get('rival_gen_max_iter', 50)
        self.rival_gen_max_defic_iter = cfg.get('rival_gen_max_defic_iter', 3)
        self.rival_gen_beam_size = cfg.get('rival_gen_beam_size')
        self.candgen_model = cfg.get('candgen_model')
        self.diffing_trees = cfg.get('diffing_trees', False)
        # candidate generator and planners are created in _init_training() if needed;
        # they are defined here so that they can always be safely tested against None
        self.candgen = None
        self.sampling_planner = None
        self.asearch_planner = None

    def score(self, cand_tree, da):
        """Score the given tree in the context of the given dialogue act.
        @param cand_tree: the candidate tree to be scored, as a TreeData object
        @param da: a DialogueAct object representing the input dialogue act
        @return: the result of self._score() applied to the extracted features
        """
        return self._score(self._extract_feats(cand_tree, da))

    def score_all(self, cand_trees, da):
        """Array version of the score() function: score each of the given trees
        in the context of the same dialogue act and return the list of scores."""
        return [self.score(cand_tree, da) for cand_tree in cand_trees]

    def _extract_feats(self, tree, da):
        """Extract features for the given tree and dialogue act.
        To be overridden by derived classes."""
        raise NotImplementedError

    def train(self, das_file, ttree_file, data_portion=1.0):
        """Run training on the given training data.

        @param das_file: the file to read the training dialogue acts from
        @param ttree_file: the file to read the training t-trees from
        @param data_portion: relative portion of the training data to be used
        """
        self._init_training(das_file, ttree_file, data_portion)
        for iter_no in xrange(1, self.passes + 1):
            # start each pass from the natural order, shuffle it if required
            self.train_order = range(len(self.train_trees))
            if self.randomize:
                rnd.shuffle(self.train_order)
            log_info("Train order: " + str(self.train_order))
            self._training_pass(iter_no)
            if self.evaluator.tree_accuracy() == 1:  # if tree accuracy is 1, we won't learn anything anymore
                break
        # averaged perceptron – average the weights obtained after each pass
        if self.averaging is True:
            self.set_weights_iter_average()

    def _init_training(self, das_file, ttree_file, data_portion):
        """Initialize training: read the input data, cut it down to the requested size,
        and create the candidate generator and planners if the configuration needs them.

        @param das_file: the file to read the training dialogue acts from
        @param ttree_file: the file to read the training t-trees (and sentences) from
        @param data_portion: relative portion of the training data to be used; the \
            resulting number of instances is rounded to the nearest integer
        @raise ValueError: if the 'gen_cur_weights' strategy is requested but no \
            candidate generator is available
        """
        # read input
        log_info('Reading DAs from ' + das_file + '...')
        das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        ttree_doc = read_ttrees(ttree_file)
        sents = sentences_from_doc(ttree_doc, self.language, self.selector)
        trees = trees_from_doc(ttree_doc, self.language, self.selector)

        # make training data smaller if necessary (all three lists are cut at the same index)
        train_size = int(round(data_portion * len(trees)))
        self.train_trees = trees[:train_size]
        self.train_das = das[:train_size]
        self.train_sents = sents[:train_size]
        self.train_order = range(len(self.train_trees))
        log_info('Using %d training instances.' % train_size)

        # initialize candidate generator + sampling planner if a generator model is configured
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(self.candgen_model)
            self.sampling_planner = SamplingPlanner({'language': self.language,
                                                     'selector': self.selector,
                                                     'candgen': self.candgen})

        # the A*search planner cannot work without a candidate generator
        if 'gen_cur_weights' in self.rival_gen_strategy:
            # getattr: instances that did not go through __init__ may lack the attribute
            candgen = getattr(self, 'candgen', None)
            if candgen is None:
                raise ValueError("The 'gen_cur_weights' rival generation strategy requires "
                                 "a candidate generator (set 'candgen_model').")
            self.asearch_planner = ASearchPlanner({'candgen': candgen,
                                                   'language': self.language,
                                                   'selector': self.selector,
                                                   'ranker': self, })

    def _training_pass(self, pass_no):
        """Run one training pass, update weights (store them for possible averaging),
        and store diagnostic values.

        NOTE(review): self.train_feats is read here but never assigned in this class;
        presumably derived classes fill it in during training initialization -- verify.

        @param pass_no: the number of the current pass (used for logging and for \
            pass-dependent iteration limits)
        """

        pass_start_time = time.time()
        self.reset_diagnostics()
        self.update_weights_sum()

        log_debug('\n***\nTR %05d:' % pass_no)

        rgen_max_iter = self._get_num_iters(pass_no, self.rival_gen_max_iter)
        rgen_max_defic_iter = self._get_num_iters(pass_no, self.rival_gen_max_defic_iter)
        rgen_beam_size = self.rival_gen_beam_size

        for tree_no in self.train_order:
            gold_da = self.train_das[tree_no]
            gold_tree = self.train_trees[tree_no]
            gold_feats = self.train_feats[tree_no]

            for strategy in self.rival_gen_strategy:
                # obtain some 'rival', alternative incorrect candidates
                rival_das, rival_trees, rival_feats = self._get_rival_candidates(tree_no, strategy, rgen_max_iter,
                                                                                 rgen_max_defic_iter, rgen_beam_size)
                # no rivals at all (e.g., the search only produced the gold tree): there is
                # nothing to compare against, and max() below would fail on an empty list
                if not rival_feats:
                    log_debug('TTREE-NO: %04d, STRATEGY: %s -- no rival candidates, skipping' %
                              (tree_no, strategy))
                    continue

                # the gold candidate always comes first, i.e., at index 0
                cands = [gold_feats] + rival_feats

                # score them along with the right one
                scores = [self._score(cand) for cand in cands]
                top_cand_idx = scores.index(max(scores))
                rival_scores = scores[1:]
                top_rival_score = max(rival_scores)
                top_rival_idx = rival_scores.index(top_rival_score)
                top_rival_tree = rival_trees[top_rival_idx]
                top_rival_da = rival_das[top_rival_idx]

                # find the top-scoring generated tree, evaluate against gold t-tree
                # (disregarding whether it was selected as the best one)
                self.evaluator.append(TreeNode(gold_tree), TreeNode(top_rival_tree), scores[0], top_rival_score)

                # debug print: candidate trees
                log_debug('TTREE-NO: %04d, SEL_CAND: %04d, LEN: %02d' % (tree_no, top_cand_idx, len(cands)))
                log_debug('SENT: %s' % self.train_sents[tree_no])
                log_debug('ALL CAND TREES:')
                for ttree, score in zip([gold_tree] + rival_trees, scores):
                    log_debug("%12.5f" % score, "\t", ttree)

                # update weights if the system doesn't give the highest score to the right one
                # (if the top candidate is not the gold one, it is identical to the top rival)
                if top_cand_idx != 0:
                    self._update_weights(gold_da, top_rival_da, gold_tree, top_rival_tree,
                                         gold_feats, cands[top_cand_idx])

        # store a copy of the current weights for averaging
        self.store_iter_weights()

        # debug print: current weights and pass accuracy
        log_debug(self._feat_val_str(), '\n***')
        log_debug('PASS ACCURACY: %.3f' % self.evaluator.tree_accuracy())

        # print and return statistics
        self._print_pass_stats(pass_no, datetime.timedelta(seconds=(time.time() - pass_start_time)))

    def diffing_trees_with_scores(self, da, good_tree, bad_tree):
        """For debugging purposes. Return a printout of diffing trees between the chosen candidate
        and the gold tree, along with scores.

        @param da: the dialogue act used as the context for scoring
        @param good_tree: the gold tree
        @param bad_tree: the chosen candidate tree
        @return: a multi-line string listing the common subtree and both sets of diffing \
            subtrees, each with its score
        """
        good_sts, bad_sts = good_tree.diffing_trees(bad_tree, symmetric=False)
        comm_st = good_tree.get_common_subtree(bad_tree)
        parts = ['Common subtree: %.3f' % self.score(comm_st, da) + "\t" + unicode(comm_st) + "\n",
                 "Good subtrees:\n"]
        for good_st in good_sts:
            parts.append("%.3f" % self.score(good_st, da) + "\t" + unicode(good_st) + "\n")
        parts.append("Bad subtrees:\n")
        for bad_st in bad_sts:
            parts.append("%.3f" % self.score(bad_st, da) + "\t" + unicode(bad_st) + "\n")
        return ''.join(parts)

    def _get_num_iters(self, cur_pass_no, iter_setting):
        """Return the maximum number of iterations (total/deficit) given the current pass.
        Used to keep track of variable iteration number setting in configuration.

        @param cur_pass_no: the number of the current pass
        @param iter_setting: either a single value valid for all passes, or a list/tuple \
            of (starting pass number, value) pairs; the pairs are scanned in the given order \
            and the scan stops at the first pair that starts after the current pass
        @return: the value valid for the current pass (0 if no pair applies yet)
        """
        if not isinstance(iter_setting, (list, tuple)):
            return iter_setting  # a single setting for all passes
        ret = 0
        for set_pass_no, set_iter_no in iter_setting:
            if set_pass_no > cur_pass_no:
                break
            ret = set_iter_no
        return ret

    def _print_pass_stats(self, pass_no, pass_duration):
        """Print pass statistics from internal evaluator fields and given pass duration.

        @param pass_no: the number of the pass the statistics belong to
        @param pass_duration: the duration of the pass (anything convertible by str())
        """
        log_info('Pass %05d -- tree-level accuracy: %.4f' % (pass_no, self.evaluator.tree_accuracy()))
        log_info(' * Generated trees NODE scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1())
        log_info(' * Generated trees DEP  scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1(EvalTypes.DEP))
        log_info(' * Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' %
                 self.lists_analyzer.stats())
        log_info(' * Tree size stats:\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s' %
                 self.evaluator.tree_size_stats())
        log_info(' * Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s' %
                 self.evaluator.common_subtree_stats())
        log_info(' * Score stats\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s'
                 % self.evaluator.score_stats())
        log_info(' * Duration: %s' % str(pass_duration))

    def _feat_val_str(self, sep='\n', nonzero=False):
        """Return feature names and values for printing. To be overridden in derived classes
        (this implementation ignores its parameters and returns an empty string)."""
        return ''

    def _sample_rival_idxs(self, tree_no):
        """Randomly select indexes of training instances other than the given one.

        The indexes are sampled from all but the last position; if the current instance is
        hit, the last position is used instead, so the current instance is never returned.
        At most self.rival_number indexes are returned (fewer if there are not enough
        other training instances).

        @param tree_no: the index of the current training instance (to be avoided)
        @return: a list of distinct training instance indexes
        """
        last_idx = len(self.train_trees) - 1
        # rnd.sample would fail if asked for more items than there are to choose from
        num_rivals = min(self.rival_number, max(last_idx, 0))
        return [last_idx if idx == tree_no else idx
                for idx in rnd.sample(xrange(last_idx), num_rivals)]

    def _get_rival_candidates(self, tree_no, strategy, max_iter, max_defic_iter, beam_size):
        """Generate some rival candidates for a DA and the correct (gold) tree,
        given the current rival generation strategy.

        Supported strategies: 'other_inst' (trees of other training instances with the
        current DA), 'other_da' (the gold tree with DAs of other training instances),
        'random' (trees from the sampling planner), 'gen_cur_weights' (trees from the
        A*search planner). Any other value yields no trees and no features.

        NOTE: the list of DAs is always filled to self.rival_number for strategies other
        than 'other_da', so it may be longer than the lists of trees and features.

        TODO: checking for trees identical to the gold one slows down the process

        TODO: remove other generation strategies, remove support for multiple generated trees

        @param tree_no: the index of the current training data item (tree, DA)
        @param strategy: the rival generation strategy to be used
        @param max_iter: maximum number of iterations, passed to the A*search planner
        @param max_defic_iter: maximum number of deficit iterations, passed to the A*search planner
        @param beam_size: beam size, passed to the A*search planner
        @rtype: tuple of three lists: one of DAs, one of TreeData's, one of arrays
        @return: an array of rival DAs, an array of rival trees and an array of \
            the corresponding features
        """
        da = self.train_das[tree_no]
        gold_tree = self.train_trees[tree_no]

        rival_das, rival_trees, rival_feats = [], [], []

        if strategy != 'other_da':
            rival_das = [da] * self.rival_number

        # use current DA but change trees when computing features
        if strategy == 'other_inst':
            other_inst_trees = [self.train_trees[rival_idx]
                                for rival_idx in self._sample_rival_idxs(tree_no)]
            rival_trees.extend(other_inst_trees)
            rival_feats.extend([self._extract_feats(tree, da) for tree in other_inst_trees])

        # use the current gold tree but change DAs when computing features
        elif strategy == 'other_da':
            other_inst_das = [self.train_das[rival_idx]
                              for rival_idx in self._sample_rival_idxs(tree_no)]
            rival_das.extend(other_inst_das)
            rival_trees.extend([gold_tree] * len(other_inst_das))
            rival_feats.extend([self._extract_feats(gold_tree, other_da)
                                for other_da in other_inst_das])

        # candidates generated using the random planner (use the current DA)
        elif strategy == 'random':
            random_trees = []
            while len(random_trees) < self.rival_number:
                tree = self.sampling_planner.generate_tree(da)
                if tree != gold_tree:  # don't generate trees identical to the gold one
                    random_trees.append(tree)
            rival_trees.extend(random_trees)
            rival_feats.extend([self._extract_feats(tree, da) for tree in random_trees])

        # candidates generated using the A*search planner, which uses this ranker with current
        # weights to guide the search, and the current DA as the input
        # TODO: use just one!, others are meaningless
        elif strategy == 'gen_cur_weights':
            open_list, close_list = self.asearch_planner.run(da, max_iter, max_defic_iter, beam_size)
            self.lists_analyzer.append(gold_tree, open_list, close_list)
            gen_trees = []
            # take trees from the close list until there are enough of them, skipping the gold one
            while close_list and len(gen_trees) < self.rival_number:
                tree = close_list.pop()[0]
                if tree != gold_tree:
                    gen_trees.append(tree)
            rival_trees.extend(gen_trees)
            rival_feats.extend([self._extract_feats(tree, da) for tree in gen_trees])

        # return all resulting candidates
        return rival_das, rival_trees, rival_feats

    def get_weights(self):
        """Return the current ranker weights (parameters). To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights(self, w):
        """Set new ranker weights. To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_average(self, ws):
        """Set the weights as the average of the given array of weights (used in parallel training).
        To be overridden by derived classes."""
        raise NotImplementedError

    def store_iter_weights(self):
        """Remember the current weights to be used for averaging.
        To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_iter_average(self):
        """Set new weights as the average of all remembered weights. To be overridden by
        derived classes."""
        raise NotImplementedError

    def get_weights_sum(self):
        """Return weights size in order to weigh future promise against them.
        To be overridden by derived classes."""
        raise NotImplementedError

    def update_weights_sum(self):
        """Update the current weights size for future promise weighing.
        To be overridden by derived classes."""
        raise NotImplementedError

    def reset_diagnostics(self):
        """Reset the evaluation statistics (Evaluator and ASearchListsAnalyzer objects)."""
        self.evaluator = Evaluator()
        self.lists_analyzer = ASearchListsAnalyzer()

    def get_diagnostics(self):
        """Return the current evaluation statistics (a tuple of Evaluator and ASearchListsAnalyzer
        objects)."""
        return self.evaluator, self.lists_analyzer

    def set_diagnostics_average(self, diags):
        """Given an array of evaluation statistics objects, average them and store in this ranker
        instance.

        @param diags: an array of (Evaluator, ASearchListsAnalyzer) pairs, such as those \
            returned by get_diagnostics(); they are merged into freshly reset diagnostics
        """
        self.reset_diagnostics()
        for evaluator, lists_analyzer in diags:
            self.evaluator.merge(evaluator)
            self.lists_analyzer.merge(lists_analyzer)

    def get_future_promise(self, cand_tree):
        """Compute expected future promise for a tree.

        The estimate depends on self.future_promise_type and is always multiplied by the
        current weights size and by self.future_promise_weight.

        @param cand_tree: the candidate tree
        @return: the future promise estimate
        """
        w_sum = self.get_weights_sum()
        if self.future_promise_type == 'num_nodes':
            # the promise decreases with tree size and is zero from 10 nodes up
            return w_sum * self.future_promise_weight * max(0, 10 - len(cand_tree))
        elif self.future_promise_type == 'norm_exp_children':
            # candidate generator's estimate, normalized by tree size
            return (self.candgen.get_future_promise(cand_tree) / len(cand_tree)) * w_sum * self.future_promise_weight
        elif self.future_promise_type == 'ands':
            # count the children missing for each 'and' node to have at least two of them
            prom = 0
            for idx, node in enumerate(cand_tree.nodes):
                if node.t_lemma == 'and':
                    num_kids = cand_tree.children_num(idx)
                    prom += max(0, 2 - num_kids)
            return prom * w_sum * self.future_promise_weight
        else:  # expected children (default)
            return self.candgen.get_future_promise(cand_tree) * w_sum * self.future_promise_weight

    def get_future_promise_all(self, cand_trees):
        """Array version of get_future_promise."""
        return [self.get_future_promise(cand_tree) for cand_tree in cand_trees]