Example no. 1
class BasePerceptronRanker(Ranker):

    def __init__(self, cfg):
        if not cfg:
            cfg = {}
        self.passes = cfg.get('passes', 5)
        self.alpha = cfg.get('alpha', 1)
        self.language = cfg.get('language', 'en')
        self.selector = cfg.get('selector', '')
        # initialize diagnostics
        self.lists_analyzer = None
        self.evaluator = None
        self.prune_feats = cfg.get('prune_feats', 1)
        self.rival_number = cfg.get('rival_number', 10)
        self.averaging = cfg.get('averaging', False)
        self.randomize = cfg.get('randomize', False)
        self.future_promise_weight = cfg.get('future_promise_weight', 1.0)
        self.future_promise_type = cfg.get('future_promise_type', 'expected_children')
        self.rival_gen_strategy = cfg.get('rival_gen_strategy', ['other_inst'])
        self.rival_gen_max_iter = cfg.get('rival_gen_max_iter', 50)
        self.rival_gen_max_defic_iter = cfg.get('rival_gen_max_defic_iter', 3)
        self.rival_gen_beam_size = cfg.get('rival_gen_beam_size')
        self.rival_gen_prune_size = cfg.get('rival_gen_prune_size')
        self.candgen_model = cfg.get('candgen_model')
        self.diffing_trees = cfg.get('diffing_trees', False)

    def score(self, cand_tree, da):
        """Score the given tree in the context of the given dialogue act.
        @param cand_tree: the candidate tree to be scored, as a TreeData object
        @param da: a DA object representing the input dialogue act
        """
        return self._score(self._extract_feats(cand_tree, da))

    def score_all(self, cand_trees, da):
        """Array version of the score() function"""
        return [self.score(cand_tree, da) for cand_tree in cand_trees]

    def _extract_feats(self, tree, da):
        raise NotImplementedError

    def train(self, das_file, ttree_file, data_portion=1.0):
        """Run training on the given training data."""
        self._init_training(das_file, ttree_file, data_portion)
        for iter_no in xrange(1, self.passes + 1):
            self.train_order = range(len(self.train_trees))
            if self.randomize:
                rnd.shuffle(self.train_order)
            log_info("Train order: " + str(self.train_order))
            self._training_pass(iter_no)
            if self.evaluator.tree_accuracy() == 1:  # if tree accuracy is 1, we won't learn anything anymore
                break
        # averaged perceptron – average the weights obtained after each pass
        if self.averaging is True:
            self.set_weights_iter_average()

    def _init_training(self, das_file, ttree_file, data_portion):
        """Initialize training (read input data, fix size, initialize candidate generator
        and planner)"""
        # read input
        log_info('Reading DAs from ' + das_file + '...')
        das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        ttree_doc = read_ttrees(ttree_file)
        sents = sentences_from_doc(ttree_doc, self.language, self.selector)
        trees = trees_from_doc(ttree_doc, self.language, self.selector)

        # make training data smaller if necessary
        train_size = int(round(data_portion * len(trees)))
        self.train_trees = trees[:train_size]
        self.train_das = das[:train_size]
        self.train_sents = sents[:train_size]
        self.train_order = range(len(self.train_trees))
        log_info('Using %d training instances.' % train_size)

        # initialize candidate generator
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(self.candgen_model)
#             self.sampling_planner = SamplingPlanner({'language': self.language,
#                                                      'selector': self.selector,
#                                                      'candgen': self.candgen})

        # check if A*search planner is needed (i.e., any rival generation strategy requires it)
        # and initialize it
        if isinstance(self.rival_gen_strategy[0], tuple):
            asearch_needed = any([s in ['gen_cur_weights', 'gen_update']
                                  for _, ss in self.rival_gen_strategy
                                  for s in ss])
        else:
            asearch_needed = any([s in ['gen_cur_weights', 'gen_update']
                                  for s in self.rival_gen_strategy])
        if asearch_needed:
            assert self.candgen is not None
            self.asearch_planner = ASearchPlanner({'candgen': self.candgen,
                                                   'language': self.language,
                                                   'selector': self.selector,
                                                   'ranker': self, })

    def _training_pass(self, pass_no):
        """Run one training pass, update weights (store them for possible averaging),
        and store diagnostic values."""

        pass_start_time = time.time()
        self.reset_diagnostics()
        self.update_weights_sum()

        log_debug('\n***\nTR %05d:' % pass_no)

        rgen_max_iter = self._get_num_iters(pass_no, self.rival_gen_max_iter)
        rgen_max_defic_iter = self._get_num_iters(pass_no, self.rival_gen_max_defic_iter)
        rgen_beam_size = self.rival_gen_beam_size
        rgen_prune_size = self.rival_gen_prune_size
        rgen_strategy = self._get_rival_gen_strategy(pass_no)

        for tree_no in self.train_order:

            log_debug('TREE-NO: %d' % tree_no)
            log_debug('SENT: %s' % self.train_sents[tree_no])

            gold = Inst(da=self.train_das[tree_no],
                        tree=self.train_trees[tree_no],
                        score=self._score(self.train_feats[tree_no]),
                        feats=self.train_feats[tree_no])

            # obtain some 'rival', alternative incorrect candidates
            for strategy in rgen_strategy:

                # generate using current weights
                if strategy == 'gen_cur_weights':
                    gen = self._gen_cur_weights(gold, rgen_max_iter, rgen_max_defic_iter,
                                                rgen_prune_size, rgen_beam_size)

                # generate while trying to update weights
                elif strategy == 'gen_update':
                    gen = self._gen_update(gold, rgen_max_iter, rgen_max_defic_iter,
                                           rgen_prune_size, rgen_beam_size)

                # check against other possible candidates/combinations
                else:
                    gen = self._get_rival_candidates(gold, tree_no, strategy)

                # evaluate the top-scoring generated tree against gold t-tree
                # (disregarding whether it was selected as the best one)
                self.evaluator.append(TreeNode(gold.tree), TreeNode(gen.tree), gold.score, gen.score)

                # update weights if the system doesn't give the highest score to the gold standard tree
                if gold.score < gen.score:
                    self._update_weights(gold, gen)

        # store a copy of the current weights for averaging
        self.store_iter_weights()

        # debug print: current weights and pass accuracy
        log_debug(self._feat_val_str(), '\n***')
        log_debug('PASS ACCURACY: %.3f' % self.evaluator.tree_accuracy())

        # print and return statistics
        self._print_pass_stats(pass_no, datetime.timedelta(seconds=(time.time() - pass_start_time)))

    def diffing_trees_with_scores(self, da, good_tree, bad_tree):
        """For debugging purposes. Return a printout of diffing trees between the chosen candidate
        and the gold tree, along with scores."""
        good_sts, bad_sts = good_tree.diffing_trees(bad_tree, symmetric=False)
        comm_st = good_tree.get_common_subtree(bad_tree)
        ret = 'Common subtree: %.3f' % self.score(comm_st, da) + "\t" + unicode(comm_st) + "\n"
        ret += "Good subtrees:\n"
        for good_st in good_sts:
            ret += "%.3f" % self.score(good_st, da) + "\t" + unicode(good_st) + "\n"
        ret += "Bad subtrees:\n"
        for bad_st in bad_sts:
            ret += "%.3f" % self.score(bad_st, da) + "\t" + unicode(bad_st) + "\n"
        return ret

    def _get_num_iters(self, cur_pass_no, iter_setting):
        """Return the maximum number of iterations (total/deficit) given the current pass.
        Used to keep track of variable iteration number setting in configuration.

        @param cur_pass_no: the number of the current pass
        @param iter_setting: the iteration number setting (self.rival_gen_max_iter or self.rival_gen_max_defic_iter)
        """
        if isinstance(iter_setting, (list, tuple)):
            ret = 0
            for set_pass_no, set_iter_no in iter_setting:
                if set_pass_no > cur_pass_no:
                    break
                ret = set_iter_no
            return ret
        else:
            return iter_setting  # a single setting for all passes

    def _get_rival_gen_strategy(self, cur_pass_no):
        """Return the rival generation strategy/strategies for the current pass.
        Used to keep track of a variable rival generation setting in the configuration.

        @param cur_pass_no: the number of the current pass
        """
        if isinstance(self.rival_gen_strategy[0], tuple):
            ret = []
            for set_pass_no, strategies in self.rival_gen_strategy:
                if set_pass_no > cur_pass_no:
                    break
                ret = strategies
            return ret
        else:
            return self.rival_gen_strategy  # a single setting for all passes

    def _print_pass_stats(self, pass_no, pass_duration):
        """Print pass statistics from internal evaluator fields and given pass duration."""
        log_info('Pass %05d -- tree-level accuracy: %.4f' % (pass_no, self.evaluator.tree_accuracy()))
        log_info(' * Generated trees NODE scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1())
        log_info(' * Generated trees DEP  scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1(EvalTypes.DEP))
        log_info(' * Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' %
                 self.lists_analyzer.stats())
        log_info(' * Tree size stats:\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s' %
                 self.evaluator.size_stats())
        log_info(' * Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s' %
                 self.evaluator.common_substruct_stats())
        log_info(' * Score stats\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s'
                 % self.evaluator.score_stats())
        log_info(' * Duration: %s' % str(pass_duration))

    def _feat_val_str(self, sep='\n', nonzero=False):
        """Return feature names and values for printing. To be overridden in base classes."""
        return ''

    def _get_rival_candidates(self, gold, tree_no, strategy):
        """Generate some rival candidates for a DA and the correct (gold) tree,
        given a strategy; using other DAs for the correct tree, other trees for the correct
        DA, or random trees.

        NB: This has not been shown to be usable in practice; use _gen_cur_weights() instead.

        TODO: checking for trees identical to the gold one slows down the process

        @param tree_no: the index of the current training data item (tree, DA)
        @rtype: Inst
        @return: the highest-scoring rival candidate (its tree, DA, score, and features)
        """
        train_trees = self.train_trees

        rival_das, rival_trees, rival_feats = [], [], []

        if strategy != 'other_da':
            rival_das = [gold.da] * self.rival_number

        # use current DA but change trees when computing features
        if strategy == 'other_inst':
            # use alternative indexes, avoid the correct one
            rival_idxs = map(lambda idx: len(train_trees) - 1 if idx == tree_no else idx,
                             rnd.sample(xrange(len(train_trees) - 1), self.rival_number))
            other_inst_trees = [train_trees[rival_idx] for rival_idx in rival_idxs]
            rival_trees.extend(other_inst_trees)
            rival_feats.extend([self._extract_feats(tree, gold.da) for tree in other_inst_trees])

        # use the current gold tree but change DAs when computing features
        if strategy == 'other_da':
            rival_idxs = map(lambda idx: len(train_trees) - 1 if idx == tree_no else idx,
                             rnd.sample(xrange(len(train_trees) - 1), self.rival_number))
            other_inst_das = [self.train_das[rival_idx] for rival_idx in rival_idxs]
            rival_das.extend(other_inst_das)
            rival_trees.extend([self.train_trees[tree_no]] * self.rival_number)
            rival_feats.extend([self._extract_feats(self.train_trees[tree_no], da)
                                for da in other_inst_das])

#         # candidates generated using the random planner (use the current DA)
#         if strategy == 'random':
#             random_trees = []
#             while len(random_trees) < self.rival_number:
#                 tree = self.sampling_planner.generate_tree(da)
#                 if (tree != train_trees[tree_no]):  # don't generate trees identical to the gold one
#                     random_trees.append(tree)
#             rival_trees.extend(random_trees)
#             rival_feats.extend([self._extract_feats(tree, da) for tree in random_trees])

        # score them along with the right one
        rival_scores = [self._score(r) for r in rival_feats]
        top_rival_idx = rival_scores.index(max(rival_scores))
        gen = Inst(tree=rival_trees[top_rival_idx],
                   da=rival_das[top_rival_idx],
                   score=rival_scores[top_rival_idx],
                   feats=rival_feats[top_rival_idx])

        # debug print: candidate trees
        log_debug('#RIVALS: %02d' % len(rival_feats))
        log_debug('SEL: GOLD' if gold.score >= gen.score else ('SEL: RIVAL#%d' % top_rival_idx))
        log_debug('ALL CAND TREES:')
        for ttree, score in zip([gold.tree] + rival_trees, [gold.score] + rival_scores):
            log_debug("%12.5f" % score, "\t", ttree)

        return gen

    def _gen_cur_weights(self, gold, max_iter, max_defic_iter, prune_size, beam_size):
        """
        Get the best candidate generated using the A*search planner, which uses this ranker with current
        weights to guide the search, and the current DA as the input.

        @param gold: the gold-standard Inst holding the input DA for generation and the reference tree
        @param max_iter: maximum number of A*-search iterations to run
        @param max_defic_iter: maximum number of deficit A*-search iterations (stopping criterion)
        @param prune_size: beam size for open list pruning
        @param beam_size: beam size for candidate expansion (expand more per iteration if > 1)
        @return: The best generated tree that is different from the gold-standard tree
        @rtype: Inst
        """
        log_debug('GEN-CUR-WEIGHTS')
        # TODO make asearch_planner remember features (for last iteration, maybe)
        self.asearch_planner.run(gold.da, max_iter, max_defic_iter, prune_size, beam_size)
        return self.get_best_generated(gold)

    def get_best_generated(self, gold):
        """Return the best generated tree that is different from the gold-standard tree
        (to be used for updates, if it scores better). Also take care of logging and
        update the analyzer lists.

        @param gold: the gold-standard Inst from which the generated tree must differ
        @rtype: Inst
        """
        self.lists_analyzer.append(gold.tree,
                                   self.asearch_planner.open_list,
                                   self.asearch_planner.close_list)

        gen_tree = gold.tree
        while self.asearch_planner.close_list and gen_tree == gold.tree:
            gen_tree, gen_score = self.asearch_planner.close_list.pop()

        # scores are negative on the close list – reverse the sign
        gen = Inst(tree=gen_tree, da=gold.da, score=-gen_score,
                   feats=self._extract_feats(gen_tree, gold.da))
        log_debug('SEL: GOLD' if gold.score >= gen.score else 'SEL: GEN')
        log_debug("GOLD:\t", "%12.5f" % gold.score, "\t", gold.tree)
        log_debug("GEN :\t", "%12.5f" % gen.score, "\t", gen.tree)
        return gen

    def _gen_update(self, gold, max_iter, max_defic_iter, prune_size, beam_size):
        """Try generating using the current weights, but update the weights after each
        iteration if the result is not going in the right direction (not a subtree of the
        gold-standard tree).

        @param gold: the gold-standard Inst holding the input DA for generation and the reference tree
        @param max_iter: maximum number of A*-search iterations to run
        @param max_defic_iter: maximum number of deficit A*-search iterations (stopping criterion)
        @param prune_size: beam size for open list pruning
        @param beam_size: beam size for candidate expansion (expand more per iteration if > 1)
        @return: The best generated tree that is different from the gold-standard tree
        @rtype: Inst
        """

        log_debug('GEN-UPDATE')
        self.asearch_planner.init_run(gold.da, max_iter, max_defic_iter, prune_size, beam_size)

        while not self.asearch_planner.check_finalize():
            # run one A*search iteration
            self.asearch_planner.run_iter()

            # stop if there's nothing on the open list
            if not self.asearch_planner.open_list:
                break

            # look if we are on the right track to the gold tree
            cur_top, score = self.asearch_planner.open_list.peek()
            csi, _ = gold.tree.common_subtree_idxs(cur_top)

            # if not, update
            if len(csi) != len(cur_top):

                feats = self._extract_feats(cur_top, gold.da)
                gen = Inst(tree=cur_top, da=gold.da, feats=feats, score=score)

                # for small wrong trees,
                # fake the current open list to only include a subtree of the gold tree
                # TODO fake it better, include more variants
                # update using a subtree of the gold tree
                if len(cur_top) < len(gold.tree):
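                    # NB: the slice end below is negative, so only enough of the depth-ordered
                    # extra gold nodes are kept for the faked gold subtree to have exactly as
                    # many nodes as cur_top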
                    diff = sorted(list(set(range(len(gold.tree))) - set(csi)),
                                  cmp=gold.tree._compare_node_depth)
                    gold_sub = gold.tree.get_subtree(csi + diff[0:len(cur_top) - len(gold.tree)])

                    self.asearch_planner.open_list.clear()
                    self.asearch_planner.open_list.push(gold_sub, score)
                    # TODO speed up by remembering the features in planner
                    feats = self._extract_feats(gold_sub, gold.da)
                    gold_sub = Inst(tree=gold_sub, da=gold.da, feats=feats, score=0)
                    self._update_weights(gold_sub, gen)

                # otherwise, update using the full gold tree
                else:
                    self._update_weights(gold, gen)

        return self.get_best_generated(gold)

    def get_weights(self):
        """Return the current ranker weights (parameters). To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights(self, w):
        """Set new ranker weights. To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_average(self, ws):
        """Set the weights as the average of the given array of weights (used in parallel training).
        To be overridden by derived classes."""
        raise NotImplementedError

    def store_iter_weights(self):
        """Remember the current weights to be used for averaging.
        To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_iter_average(self):
        """Set new weights as the average of all remembered weights. To be overridden by
        derived classes."""
        raise NotImplementedError

    def get_weights_sum(self):
        """Return weights size in order to weigh future promise against them."""
        raise NotImplementedError

    def update_weights_sum(self):
        """Update the current weights size for future promise weighining."""
        raise NotImplementedError

    def reset_diagnostics(self):
        """Reset the evaluation statistics (Evaluator and ASearchListsAnalyzer objects)."""
        self.evaluator = Evaluator()
        self.lists_analyzer = ASearchListsAnalyzer()

    def get_diagnostics(self):
        """Return the current evaluation statistics (a tuple of Evaluator and ASearchListsAnalyzer
        objects."""
        return self.evaluator, self.lists_analyzer

    def set_diagnostics_average(self, diags):
        """Given an array of evaluation statistics objects, average them an store in this ranker
        instance."""
        self.reset_diagnostics()
        for evaluator, lists_analyzer in diags:
            self.evaluator.merge(evaluator)
            self.lists_analyzer.merge(lists_analyzer)

    def get_future_promise(self, cand_tree):
        """Compute expected future promise for a tree."""
        w_sum = self.get_weights_sum()
        if self.future_promise_type == 'num_nodes':
            return w_sum * self.future_promise_weight * max(0, 10 - len(cand_tree))
        elif self.future_promise_type == 'norm_exp_children':
            return (self.candgen.get_future_promise(cand_tree) / len(cand_tree)) * w_sum * self.future_promise_weight
        elif self.future_promise_type == 'ands':
            prom = 0
            for idx, node in enumerate(cand_tree.nodes):
                if node.t_lemma == 'and':
                    num_kids = cand_tree.children_num(idx)
                    prom += max(0, 2 - num_kids)
            return prom * w_sum * self.future_promise_weight
        else:  # expected children (default)
            return self.candgen.get_future_promise(cand_tree) * w_sum * self.future_promise_weight

    def get_future_promise_all(self, cand_trees):
        """Array version of get_future_promise."""
        return [self.get_future_promise(cand_tree) for cand_tree in cand_trees]
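
The per-pass settings resolved by _get_num_iters() and _get_rival_gen_strategy() may be given either as a single value or as a list of (pass_no, value) pairs, where each value takes effect from the given pass onwards. A minimal sketch of such a configuration follows; the concrete values, file name, and subclass name are illustrative assumptions only, not defaults of the class.

# hypothetical configuration, for illustration only
cfg = {
    'passes': 3,
    'randomize': True,
    # use 'other_inst' rivals from pass 1, switch to A*-search generation from pass 2
    'rival_gen_strategy': [(1, ['other_inst']), (2, ['gen_cur_weights'])],
    # allow more A*-search iterations from pass 3 onwards
    'rival_gen_max_iter': [(1, 50), (3, 200)],
    'candgen_model': 'candgen.pickle.gz',  # required once 'gen_cur_weights' is used
}
ranker = MyPerceptronRanker(cfg)  # hypothetical concrete subclass implementing _extract_feats() etc.

# during pass 2 the ranker would resolve:
#   ranker._get_rival_gen_strategy(2)                    ->  ['gen_cur_weights']
#   ranker._get_num_iters(2, ranker.rival_gen_max_iter)  ->  50
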
Example no. 2
def asearch_gen(args):
    """A*search generation"""
    from pytreex.core.document import Document

    opts, files = getopt(args, 'e:d:w:c:s:')
    eval_file = None
    fname_ttrees_out = None
    cfg_file = None
    eval_selector = ''

    for opt, arg in opts:
        if opt == '-e':
            eval_file = arg
        elif opt == '-s':
            eval_selector = arg
        elif opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))
        elif opt == '-w':
            fname_ttrees_out = arg
        elif opt == '-c':
            cfg_file = arg

    if len(files) != 3:
        sys.exit('Invalid arguments.\n' + __doc__)
    fname_cand_model, fname_rank_model, fname_da_test = files

    log_info('Initializing...')
    candgen = RandomCandidateGenerator.load_from_file(fname_cand_model)
    ranker = PerceptronRanker.load_from_file(fname_rank_model)
    cfg = Config(cfg_file) if cfg_file else {}
    cfg.update({'candgen': candgen, 'ranker': ranker})
    tgen = ASearchPlanner(cfg)

    log_info('Generating...')
    das = read_das(fname_da_test)

    if eval_file is None:
        gen_doc = Document()
    else:
        eval_doc = read_ttrees(eval_file)
        if eval_selector == tgen.selector:
            gen_doc = Document()
        else:
            gen_doc = eval_doc

    # generate and evaluate
    if eval_file is not None:
        # generate + analyze open&close lists
        lists_analyzer = ASearchListsAnalyzer()
        for num, (da, gold_tree) in enumerate(zip(
                das, trees_from_doc(eval_doc, tgen.language, eval_selector)),
                                              start=1):
            log_debug("\n\nTREE No. %03d" % num)
            gen_tree = tgen.generate_tree(da, gen_doc)
            lists_analyzer.append(gold_tree, tgen.open_list, tgen.close_list)
            if gen_tree != gold_tree:
                log_debug("\nDIFFING TREES:\n" +
                          tgen.ranker.diffing_trees_with_scores(
                              da, gold_tree, gen_tree) + "\n")

        log_info('Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' %
                 lists_analyzer.stats())

        # evaluate the generated trees against golden trees
        eval_ttrees = ttrees_from_doc(eval_doc, tgen.language, eval_selector)
        gen_ttrees = ttrees_from_doc(gen_doc, tgen.language, tgen.selector)

        log_info('Evaluating...')
        evaler = Evaluator()
        for eval_bundle, eval_ttree, gen_ttree, da in zip(
                eval_doc.bundles, eval_ttrees, gen_ttrees, das):
            # add some stats about the tree directly into the output file
            add_bundle_text(
                eval_bundle, tgen.language, tgen.selector + 'Xscore',
                "P: %.4f R: %.4f F1: %.4f" %
                p_r_f1_from_counts(*corr_pred_gold(eval_ttree, gen_ttree)))

            # collect overall stats
            evaler.append(eval_ttree, gen_ttree,
                          ranker.score(TreeData.from_ttree(eval_ttree), da),
                          ranker.score(TreeData.from_ttree(gen_ttree), da))
        # print overall stats
        log_info("NODE precision: %.4f, Recall: %.4f, F1: %.4f" %
                 evaler.p_r_f1())
        log_info("DEP  precision: %.4f, Recall: %.4f, F1: %.4f" %
                 evaler.p_r_f1(EvalTypes.DEP))
        log_info("Tree size stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" %
                 evaler.size_stats())
        log_info("Score stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" %
                 evaler.score_stats())
        log_info(
            "Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s" %
            evaler.common_substruct_stats())
    # just generate
    else:
        for da in das:
            tgen.generate_tree(da, gen_doc)

    # write output
    if fname_ttrees_out is not None:
        log_info('Writing output...')
        write_ttrees(gen_doc, fname_ttrees_out)
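
For reference, asearch_gen() takes the options parsed above plus exactly three positional arguments: the candidate generator model, the ranker model, and the test DAs file, in that order. A hypothetical invocation (all file names are made up for illustration) might look like this:

# hypothetical file names, for illustration only
asearch_gen(['-c', 'config.py',            # planner configuration (-c)
             '-e', 'test-ttrees.yaml.gz',  # gold t-trees to evaluate against (-e)
             '-w', 'out-ttrees.yaml.gz',   # where to write the generated t-trees (-w)
             'candgen.pickle.gz',          # candidate generator model
             'ranker.pickle.gz',           # perceptron ranker model
             'test-das.txt'])              # input dialogue acts
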
Example no. 3
 def reset_diagnostics(self):
     """Reset the evaluation statistics (Evaluator and ASearchListsAnalyzer objects)."""
     self.evaluator = Evaluator()
     self.lists_analyzer = ASearchListsAnalyzer()
Example no. 4
class BasePerceptronRanker(Ranker):
    def __init__(self, cfg):
        if not cfg:
            cfg = {}
        self.passes = cfg.get('passes', 5)
        self.alpha = cfg.get('alpha', 1)
        self.language = cfg.get('language', 'en')
        self.selector = cfg.get('selector', '')
        # initialize diagnostics
        self.lists_analyzer = None
        self.evaluator = None
        self.prune_feats = cfg.get('prune_feats', 1)
        self.rival_number = cfg.get('rival_number', 10)
        self.averaging = cfg.get('averaging', False)
        self.randomize = cfg.get('randomize', False)
        self.future_promise_weight = cfg.get('future_promise_weight', 1.0)
        self.future_promise_type = cfg.get('future_promise_type',
                                           'expected_children')
        self.rival_gen_strategy = cfg.get('rival_gen_strategy', ['other_inst'])
        self.rival_gen_max_iter = cfg.get('rival_gen_max_iter', 50)
        self.rival_gen_max_defic_iter = cfg.get('rival_gen_max_defic_iter', 3)
        self.rival_gen_beam_size = cfg.get('rival_gen_beam_size')
        self.rival_gen_prune_size = cfg.get('rival_gen_prune_size')
        self.candgen_model = cfg.get('candgen_model')
        self.diffing_trees = cfg.get('diffing_trees', False)

    def score(self, cand_tree, da):
        """Score the given tree in the context of the given dialogue act.
        @param cand_tree: the candidate tree to be scored, as a TreeData object
        @param da: a DA object representing the input dialogue act
        """
        return self._score(self._extract_feats(cand_tree, da))

    def score_all(self, cand_trees, da):
        """Array version of the score() function"""
        return [self.score(cand_tree, da) for cand_tree in cand_trees]

    def _extract_feats(self, tree, da):
        raise NotImplementedError

    def train(self, das_file, ttree_file, data_portion=1.0):
        """Run training on the given training data."""
        self._init_training(das_file, ttree_file, data_portion)
        for iter_no in range(1, self.passes + 1):
            self.train_order = list(range(len(self.train_trees)))
            if self.randomize:
                rnd.shuffle(self.train_order)
            log_info("Train order: " + str(self.train_order))
            self._training_pass(iter_no)
            # if tree accuracy is 1, we won't learn anything anymore
            if self.evaluator.tree_accuracy() == 1:
                break
        # averaged perceptron – average the weights obtained after each pass
        if self.averaging is True:
            self.set_weights_iter_average()

    def _init_training(self, das_file, ttree_file, data_portion):
        """Initialize training (read input data, fix size, initialize candidate generator
        and planner)"""
        # read input
        log_info('Reading DAs from ' + das_file + '...')
        das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        ttree_doc = read_ttrees(ttree_file)
        sents = sentences_from_doc(ttree_doc, self.language, self.selector)
        trees = trees_from_doc(ttree_doc, self.language, self.selector)

        # make training data smaller if necessary
        train_size = int(round(data_portion * len(trees)))
        self.train_trees = trees[:train_size]
        self.train_das = das[:train_size]
        self.train_sents = sents[:train_size]
        self.train_order = list(range(len(self.train_trees)))
        log_info('Using %d training instances.' % train_size)

        # initialize candidate generator
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(
                self.candgen_model)
#             self.sampling_planner = SamplingPlanner({'language': self.language,
#                                                      'selector': self.selector,
#                                                      'candgen': self.candgen})

        # check if A*search planner is needed (i.e., any rival generation strategy requires it)
        # and initialize it
        if isinstance(self.rival_gen_strategy[0], tuple):
            asearch_needed = any([
                s in ['gen_cur_weights', 'gen_update']
                for _, ss in self.rival_gen_strategy for s in ss
            ])
        else:
            asearch_needed = any([
                s in ['gen_cur_weights', 'gen_update']
                for s in self.rival_gen_strategy
            ])
        if asearch_needed:
            assert self.candgen is not None
            self.asearch_planner = ASearchPlanner({
                'candgen': self.candgen,
                'language': self.language,
                'selector': self.selector,
                'ranker': self,
            })

    def _training_pass(self, pass_no):
        """Run one training pass, update weights (store them for possible averaging),
        and store diagnostic values."""

        pass_start_time = time.time()
        self.reset_diagnostics()
        self.update_weights_sum()

        log_debug('\n***\nTR %05d:' % pass_no)

        rgen_max_iter = self._get_num_iters(pass_no, self.rival_gen_max_iter)
        rgen_max_defic_iter = self._get_num_iters(
            pass_no, self.rival_gen_max_defic_iter)
        rgen_beam_size = self.rival_gen_beam_size
        rgen_prune_size = self.rival_gen_prune_size
        rgen_strategy = self._get_rival_gen_strategy(pass_no)

        for tree_no in self.train_order:

            log_debug('TREE-NO: %d' % tree_no)
            log_debug('SENT: %s' % self.train_sents[tree_no])

            gold = Inst(da=self.train_das[tree_no],
                        tree=self.train_trees[tree_no],
                        score=self._score(self.train_feats[tree_no]),
                        feats=self.train_feats[tree_no])

            # obtain some 'rival', alternative incorrect candidates
            for strategy in rgen_strategy:

                # generate using current weights
                if strategy == 'gen_cur_weights':
                    gen = self._gen_cur_weights(gold, rgen_max_iter,
                                                rgen_max_defic_iter,
                                                rgen_prune_size,
                                                rgen_beam_size)

                # generate while trying to update weights
                elif strategy == 'gen_update':
                    gen = self._gen_update(gold, rgen_max_iter,
                                           rgen_max_defic_iter,
                                           rgen_prune_size, rgen_beam_size)

                # check against other possible candidates/combinations
                else:
                    gen = self._get_rival_candidates(gold, tree_no, strategy)

                # evaluate the top-scoring generated tree against gold t-tree
                # (disregarding whether it was selected as the best one)
                self.evaluator.append(TreeNode(gold.tree), TreeNode(gen.tree),
                                      gold.score, gen.score)

                # update weights if the system doesn't give the highest score to the gold standard tree
                if gold.score < gen.score:
                    self._update_weights(gold, gen)

        # store a copy of the current weights for averaging
        self.store_iter_weights()

        # debug print: current weights and pass accuracy
        log_debug(self._feat_val_str(), '\n***')
        log_debug('PASS ACCURACY: %.3f' % self.evaluator.tree_accuracy())

        # print and return statistics
        self._print_pass_stats(
            pass_no,
            datetime.timedelta(seconds=(time.time() - pass_start_time)))

    def diffing_trees_with_scores(self, da, good_tree, bad_tree):
        """For debugging purposes. Return a printout of diffing trees between the chosen candidate
        and the gold tree, along with scores."""
        good_sts, bad_sts = good_tree.diffing_trees(bad_tree, symmetric=False)
        comm_st = good_tree.get_common_subtree(bad_tree)
        ret = 'Common subtree: %.3f' % self.score(
            comm_st, da) + "\t" + str(comm_st) + "\n"
        ret += "Good subtrees:\n"
        for good_st in good_sts:
            ret += "%.3f" % self.score(good_st,
                                       da) + "\t" + str(good_st) + "\n"
        ret += "Bad subtrees:\n"
        for bad_st in bad_sts:
            ret += "%.3f" % self.score(bad_st, da) + "\t" + str(bad_st) + "\n"
        return ret

    def _get_num_iters(self, cur_pass_no, iter_setting):
        """Return the maximum number of iterations (total/deficit) given the current pass.
        Used to keep track of variable iteration number setting in configuration.

        @param cur_pass_no: the number of the current pass
        @param iter_setting: the iteration number setting (self.rival_gen_max_iter or self.rival_gen_max_defic_iter)
        """
        if isinstance(iter_setting, (list, tuple)):
            ret = 0
            for set_pass_no, set_iter_no in iter_setting:
                if set_pass_no > cur_pass_no:
                    break
                ret = set_iter_no
            return ret
        else:
            return iter_setting  # a single setting for all passes

    def _get_rival_gen_strategy(self, cur_pass_no):
        """Return the rival generation strategy/strategies for the current pass.
        Used to keep track of a variable rival generation setting in the configuration.

        @param cur_pass_no: the number of the current pass
        """
        if isinstance(self.rival_gen_strategy[0], tuple):
            ret = []
            for set_pass_no, strategies in self.rival_gen_strategy:
                if set_pass_no > cur_pass_no:
                    break
                ret = strategies
            return ret
        else:
            return self.rival_gen_strategy  # a single setting for all passes

    def _print_pass_stats(self, pass_no, pass_duration):
        """Print pass statistics from internal evaluator fields and given pass duration."""
        log_info('Pass %05d -- tree-level accuracy: %.4f' %
                 (pass_no, self.evaluator.tree_accuracy()))
        log_info(' * Generated trees NODE scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1())
        log_info(' * Generated trees DEP  scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1(EvalTypes.DEP))
        log_info(' * Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' %
                 self.lists_analyzer.stats())
        log_info(
            ' * Tree size stats:\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s' %
            self.evaluator.size_stats())
        log_info(
            ' * Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s'
            % self.evaluator.common_substruct_stats())
        log_info(' * Score stats\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s' %
                 self.evaluator.score_stats())
        log_info(' * Duration: %s' % str(pass_duration))

    def _feat_val_str(self, sep='\n', nonzero=False):
        """Return feature names and values for printing. To be overridden in base classes."""
        return ''

    def _get_rival_candidates(self, gold, tree_no, strategy):
        """Generate some rival candidates for a DA and the correct (gold) tree,
        given a strategy; using other DAs for the correct tree, other trees for the correct
        DA, or random trees.

        NB: This has not been shown to be usable in practice; use _gen_cur_weights() instead.

        TODO: checking for trees identical to the gold one slows down the process

        @param tree_no: the index of the current training data item (tree, DA)
        @rtype: Inst
        @return: the highest-scoring rival candidate (its tree, DA, score, and features)
        """
        train_trees = self.train_trees

        rival_das, rival_trees, rival_feats = [], [], []

        if strategy != 'other_da':
            rival_das = [gold.da] * self.rival_number

        # use current DA but change trees when computing features
        if strategy == 'other_inst':
            # use alternative indexes, avoid the correct one
            rival_idxs = [
                len(train_trees) - 1 if idx == tree_no else idx
                for idx in rnd.sample(range(len(train_trees) -
                                            1), self.rival_number)
            ]
            other_inst_trees = [
                train_trees[rival_idx] for rival_idx in rival_idxs
            ]
            rival_trees.extend(other_inst_trees)
            rival_feats.extend([
                self._extract_feats(tree, gold.da) for tree in other_inst_trees
            ])

        # use the current gold tree but change DAs when computing features
        if strategy == 'other_da':
            rival_idxs = [
                len(train_trees) - 1 if idx == tree_no else idx
                for idx in rnd.sample(range(len(train_trees) -
                                            1), self.rival_number)
            ]
            other_inst_das = [
                self.train_das[rival_idx] for rival_idx in rival_idxs
            ]
            rival_das.extend(other_inst_das)
            rival_trees.extend([self.train_trees[tree_no]] * self.rival_number)
            rival_feats.extend([
                self._extract_feats(self.train_trees[tree_no], da)
                for da in other_inst_das
            ])


#         # candidates generated using the random planner (use the current DA)
#         if strategy == 'random':
#             random_trees = []
#             while len(random_trees) < self.rival_number:
#                 tree = self.sampling_planner.generate_tree(da)
#                 if (tree != train_trees[tree_no]):  # don't generate trees identical to the gold one
#                     random_trees.append(tree)
#             rival_trees.extend(random_trees)
#             rival_feats.extend([self._extract_feats(tree, da) for tree in random_trees])

        # score them along with the right one
        rival_scores = [self._score(r) for r in rival_feats]
        top_rival_idx = rival_scores.index(max(rival_scores))
        gen = Inst(tree=rival_trees[top_rival_idx],
                   da=rival_das[top_rival_idx],
                   score=rival_scores[top_rival_idx],
                   feats=rival_feats[top_rival_idx])

        # debug print: candidate trees
        log_debug('#RIVALS: %02d' % len(rival_feats))
        log_debug('SEL: GOLD' if gold.score >= gen.score else (
            'SEL: RIVAL#%d' % top_rival_idx))
        log_debug('ALL CAND TREES:')
        for ttree, score in zip([gold.tree] + rival_trees,
                                [gold.score] + rival_scores):
            log_debug("%12.5f" % score, "\t", ttree)

        return gen

    def _gen_cur_weights(self, gold, max_iter, max_defic_iter, prune_size,
                         beam_size):
        """
        Get the best candidate generated using the A*search planner, which uses this ranker with current
        weights to guide the search, and the current DA as the input.

        @param gold: the gold-standard Inst holding the input DA for generation and the reference tree
        @param max_iter: maximum number of A*-search iterations to run
        @param max_defic_iter: maximum number of deficit A*-search iterations (stopping criterion)
        @param prune_size: beam size for open list pruning
        @param beam_size: beam size for candidate expansion (expand more per iteration if > 1)
        @return: The best generated tree that is different from the gold-standard tree
        @rtype: Inst
        """
        log_debug('GEN-CUR-WEIGHTS')
        # TODO make asearch_planner remember features (for last iteration, maybe)
        self.asearch_planner.run(gold.da, max_iter, max_defic_iter, prune_size,
                                 beam_size)
        return self.get_best_generated(gold)

    def get_best_generated(self, gold):
        """Return the best generated tree that is different from the gold-standard tree
        (to be used for updates, if it scores better). Also take care of logging and
        update the analyzer lists.

        @param gold: the gold-standard Inst from which the generated tree must differ
        @rtype: Inst
        """
        self.lists_analyzer.append(gold.tree, self.asearch_planner.open_list,
                                   self.asearch_planner.close_list)

        gen_tree = gold.tree
        while self.asearch_planner.close_list and gen_tree == gold.tree:
            gen_tree, gen_score = self.asearch_planner.close_list.pop()

        # scores are negative on the close list – reverse the sign
        gen = Inst(tree=gen_tree,
                   da=gold.da,
                   score=-gen_score,
                   feats=self._extract_feats(gen_tree, gold.da))
        log_debug('SEL: GOLD' if gold.score >= gen.score else 'SEL: GEN')
        log_debug("GOLD:\t", "%12.5f" % gold.score, "\t", gold.tree)
        log_debug("GEN :\t", "%12.5f" % gen.score, "\t", gen.tree)
        return gen

    def _gen_update(self, gold, max_iter, max_defic_iter, prune_size,
                    beam_size):
        """Try generating using the current weights, but update the weights after each
        iteration if the result is not going in the right direction (not a subtree of the
        gold-standard tree).

        @param gold: the gold-standard Inst holding the input DA for generation and the reference tree
        @param max_iter: maximum number of A*-search iterations to run
        @param max_defic_iter: maximum number of deficit A*-search iterations (stopping criterion)
        @param prune_size: beam size for open list pruning
        @param beam_size: beam size for candidate expansion (expand more per iteration if > 1)
        @return: The best generated tree that is different from the gold-standard tree
        @rtype: Inst
        """

        log_debug('GEN-UPDATE')
        self.asearch_planner.init_run(gold.da, max_iter, max_defic_iter,
                                      prune_size, beam_size)

        while not self.asearch_planner.check_finalize():
            # run one A*search iteration
            self.asearch_planner.run_iter()

            # stop if there's nothing on the open list
            if not self.asearch_planner.open_list:
                break

            # look if we are on the right track to the gold tree
            cur_top, score = self.asearch_planner.open_list.peek()
            csi, _ = gold.tree.common_subtree_idxs(cur_top)

            # if not, update
            if len(csi) != len(cur_top):

                feats = self._extract_feats(cur_top, gold.da)
                gen = Inst(tree=cur_top, da=gold.da, feats=feats, score=score)

                # for small wrong trees,
                # fake the current open list to only include a subtree of the gold tree
                # TODO fake it better, include more variants
                # update using a subtree of the gold tree
                if len(cur_top) < len(gold.tree):
                    # sorted() lost its cmp= argument in Python 3; functools.cmp_to_key adapts
                    # the comparator (assuming functools is imported at module level)
                    diff = sorted(list(set(range(len(gold.tree))) - set(csi)),
                                  key=functools.cmp_to_key(gold.tree._compare_node_depth))
                    gold_sub = gold.tree.get_subtree(csi +
                                                     diff[0:len(cur_top) -
                                                          len(gold.tree)])

                    self.asearch_planner.open_list.clear()
                    self.asearch_planner.open_list.push(gold_sub, score)
                    # TODO speed up by remembering the features in planner
                    feats = self._extract_feats(gold_sub, gold.da)
                    gold_sub = Inst(tree=gold_sub,
                                    da=gold.da,
                                    feats=feats,
                                    score=0)
                    self._update_weights(gold_sub, gen)

                # otherwise, update using the full gold tree
                else:
                    self._update_weights(gold, gen)

        return self.get_best_generated(gold)

    def get_weights(self):
        """Return the current ranker weights (parameters). To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights(self, w):
        """Set new ranker weights. To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_average(self, ws):
        """Set the weights as the average of the given array of weights (used in parallel training).
        To be overridden by derived classes."""
        raise NotImplementedError

    def store_iter_weights(self):
        """Remember the current weights to be used for averaging.
        To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_iter_average(self):
        """Set new weights as the average of all remembered weights. To be overridden by
        derived classes."""
        raise NotImplementedError

    def get_weights_sum(self):
        """Return weights size in order to weigh future promise against them."""
        raise NotImplementedError

    def update_weights_sum(self):
        """Update the current weights size for future promise weighining."""
        raise NotImplementedError

    def reset_diagnostics(self):
        """Reset the evaluation statistics (Evaluator and ASearchListsAnalyzer objects)."""
        self.evaluator = Evaluator()
        self.lists_analyzer = ASearchListsAnalyzer()

    def get_diagnostics(self):
        """Return the current evaluation statistics (a tuple of Evaluator and ASearchListsAnalyzer
        objects."""
        return self.evaluator, self.lists_analyzer

    def set_diagnostics_average(self, diags):
        """Given an array of evaluation statistics objects, average them an store in this ranker
        instance."""
        self.reset_diagnostics()
        for evaluator, lists_analyzer in diags:
            self.evaluator.merge(evaluator)
            self.lists_analyzer.merge(lists_analyzer)

    def get_future_promise(self, cand_tree):
        """Compute expected future promise for a tree."""
        w_sum = self.get_weights_sum()
        if self.future_promise_type == 'num_nodes':
            return w_sum * self.future_promise_weight * max(
                0, 10 - len(cand_tree))
        elif self.future_promise_type == 'norm_exp_children':
            return (old_div(
                self.candgen.get_future_promise(cand_tree),
                len(cand_tree))) * w_sum * self.future_promise_weight
        elif self.future_promise_type == 'ands':
            prom = 0
            for idx, node in enumerate(cand_tree.nodes):
                if node.t_lemma == 'and':
                    num_kids = cand_tree.children_num(idx)
                    prom += max(0, 2 - num_kids)
            return prom * w_sum * self.future_promise_weight
        else:  # expected children (default)
            return self.candgen.get_future_promise(
                cand_tree) * w_sum * self.future_promise_weight

    def get_future_promise_all(self, cand_trees):
        """Array version of get_future_promise."""
        return [self.get_future_promise(cand_tree) for cand_tree in cand_trees]
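
As a quick illustration of how get_future_promise() combines the heuristic with the current weight sum, here is the arithmetic for the 'num_nodes' promise type with made-up numbers (w_sum and the tree size are purely illustrative):

# illustrative values only
w_sum = 2.5                   # value of self.get_weights_sum()
future_promise_weight = 1.0   # configuration default
cand_tree_len = 4             # len(cand_tree)
promise = w_sum * future_promise_weight * max(0, 10 - cand_tree_len)
print(promise)                # 15.0
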
Example no. 5
def asearch_gen(args):
    """A*search generation"""
    from pytreex.core.document import Document

    opts, files = getopt(args, 'e:d:w:c:s:')
    eval_file = None
    fname_ttrees_out = None
    cfg_file = None
    eval_selector = ''

    for opt, arg in opts:
        if opt == '-e':
            eval_file = arg
        elif opt == '-s':
            eval_selector = arg
        elif opt == '-d':
            set_debug_stream(file_stream(arg, mode='w'))
        elif opt == '-w':
            fname_ttrees_out = arg
        elif opt == '-c':
            cfg_file = arg

    if len(files) != 3:
        sys.exit('Invalid arguments.\n' + __doc__)
    fname_cand_model, fname_rank_model, fname_da_test = files

    log_info('Initializing...')
    candgen = RandomCandidateGenerator.load_from_file(fname_cand_model)
    ranker = PerceptronRanker.load_from_file(fname_rank_model)
    cfg = Config(cfg_file) if cfg_file else {}
    cfg.update({'candgen': candgen, 'ranker': ranker})
    tgen = ASearchPlanner(cfg)

    log_info('Generating...')
    das = read_das(fname_da_test)

    if eval_file is None:
        gen_doc = Document()
    else:
        eval_doc = read_ttrees(eval_file)
        if eval_selector == tgen.selector:
            gen_doc = Document()
        else:
            gen_doc = eval_doc

    # generate and evaluate
    if eval_file is not None:
        # generate + analyze open&close lists
        lists_analyzer = ASearchListsAnalyzer()
        for num, (da, gold_tree) in enumerate(zip(das,
                                                  trees_from_doc(eval_doc, tgen.language, eval_selector)),
                                              start=1):
            log_debug("\n\nTREE No. %03d" % num)
            gen_tree = tgen.generate_tree(da, gen_doc)
            lists_analyzer.append(gold_tree, tgen.open_list, tgen.close_list)
            if gen_tree != gold_tree:
                log_debug("\nDIFFING TREES:\n" + tgen.ranker.diffing_trees_with_scores(da, gold_tree, gen_tree) + "\n")

        log_info('Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' % lists_analyzer.stats())

        # evaluate the generated trees against golden trees
        eval_ttrees = ttrees_from_doc(eval_doc, tgen.language, eval_selector)
        gen_ttrees = ttrees_from_doc(gen_doc, tgen.language, tgen.selector)

        log_info('Evaluating...')
        evaler = Evaluator()
        for eval_bundle, eval_ttree, gen_ttree, da in zip(eval_doc.bundles, eval_ttrees, gen_ttrees, das):
            # add some stats about the tree directly into the output file
            add_bundle_text(eval_bundle, tgen.language, tgen.selector + 'Xscore',
                            "P: %.4f R: %.4f F1: %.4f" % p_r_f1_from_counts(*corr_pred_gold(eval_ttree, gen_ttree)))

            # collect overall stats
            evaler.append(eval_ttree,
                          gen_ttree,
                          ranker.score(TreeData.from_ttree(eval_ttree), da),
                          ranker.score(TreeData.from_ttree(gen_ttree), da))
        # print overall stats
        log_info("NODE precision: %.4f, Recall: %.4f, F1: %.4f" % evaler.p_r_f1())
        log_info("DEP  precision: %.4f, Recall: %.4f, F1: %.4f" % evaler.p_r_f1(EvalTypes.DEP))
        log_info("Tree size stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" % evaler.size_stats())
        log_info("Score stats:\n * GOLD %s\n * PRED %s\n * DIFF %s" % evaler.score_stats())
        log_info("Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s" %
                 evaler.common_substruct_stats())
    # just generate
    else:
        for da in das:
            tgen.generate_tree(da, gen_doc)

    # write output
    if fname_ttrees_out is not None:
        log_info('Writing output...')
        write_ttrees(gen_doc, fname_ttrees_out)
Example no. 6
File: rank.py Project: fooyou/tgen
class BasePerceptronRanker(Ranker):

    def __init__(self, cfg):
        if not cfg:
            cfg = {}
        self.passes = cfg.get('passes', 5)
        self.alpha = cfg.get('alpha', 1)
        self.language = cfg.get('language', 'en')
        self.selector = cfg.get('selector', '')
        # initialize diagnostics
        self.lists_analyzer = None
        self.evaluator = None
        self.prune_feats = cfg.get('prune_feats', 1)
        self.rival_number = cfg.get('rival_number', 10)
        self.averaging = cfg.get('averaging', False)
        self.randomize = cfg.get('randomize', False)
        self.future_promise_weight = cfg.get('future_promise_weight', 1.0)
        self.future_promise_type = cfg.get('future_promise_type', 'expected_children')
        self.rival_gen_strategy = cfg.get('rival_gen_strategy', ['other_inst'])
        self.rival_gen_max_iter = cfg.get('rival_gen_max_iter', 50)
        self.rival_gen_max_defic_iter = cfg.get('rival_gen_max_defic_iter', 3)
        self.rival_gen_beam_size = cfg.get('rival_gen_beam_size')
        self.candgen_model = cfg.get('candgen_model')
        self.diffing_trees = cfg.get('diffing_trees', False)

    def score(self, cand_tree, da):
        """Score the given tree in the context of the given dialogue act.
        @param cand_tree: the candidate tree to be scored, as a TreeData object
        @param da: a DialogueAct object representing the input dialogue act
        """
        return self._score(self._extract_feats(cand_tree, da))

    def score_all(self, cand_trees, da):
        """Array version of the score() function"""
        return [self.score(cand_tree, da) for cand_tree in cand_trees]

    def _extract_feats(self, tree, da):
        raise NotImplementedError

    def train(self, das_file, ttree_file, data_portion=1.0):
        """Run training on the given training data."""
        self._init_training(das_file, ttree_file, data_portion)
        for iter_no in xrange(1, self.passes + 1):
            self.train_order = range(len(self.train_trees))
            if self.randomize:
                rnd.shuffle(self.train_order)
            log_info("Train order: " + str(self.train_order))
            self._training_pass(iter_no)
            if self.evaluator.tree_accuracy() == 1:  # if tree accuracy is 1, we won't learn anything anymore
                break
        # averaged perceptron – average the weights obtained after each pass
        if self.averaging is True:
            self.set_weights_iter_average()

    def _init_training(self, das_file, ttree_file, data_portion):
        """Initialize training (read input data, fix size, initialize candidate generator
        and planner)"""
        # read input
        log_info('Reading DAs from ' + das_file + '...')
        das = read_das(das_file)
        log_info('Reading t-trees from ' + ttree_file + '...')
        ttree_doc = read_ttrees(ttree_file)
        sents = sentences_from_doc(ttree_doc, self.language, self.selector)
        trees = trees_from_doc(ttree_doc, self.language, self.selector)

        # make training data smaller if necessary
        train_size = int(round(data_portion * len(trees)))
        self.train_trees = trees[:train_size]
        self.train_das = das[:train_size]
        self.train_sents = sents[:train_size]
        self.train_order = range(len(self.train_trees))
        log_info('Using %d training instances.' % train_size)

        # initialize candidate generator + planner if needed
        if self.candgen_model is not None:
            self.candgen = RandomCandidateGenerator.load_from_file(self.candgen_model)
            self.sampling_planner = SamplingPlanner({'language': self.language,
                                                     'selector': self.selector,
                                                     'candgen': self.candgen})
        if 'gen_cur_weights' in self.rival_gen_strategy:
            assert self.candgen is not None
            self.asearch_planner = ASearchPlanner({'candgen': self.candgen,
                                                   'language': self.language,
                                                   'selector': self.selector,
                                                   'ranker': self, })

    def _training_pass(self, pass_no):
        """Run one training pass, update weights (store them for possible averaging),
        and store diagnostic values."""

        pass_start_time = time.time()
        self.reset_diagnostics()
        self.update_weights_sum()

        log_debug('\n***\nTR %05d:' % pass_no)

        rgen_max_iter = self._get_num_iters(pass_no, self.rival_gen_max_iter)
        rgen_max_defic_iter = self._get_num_iters(pass_no, self.rival_gen_max_defic_iter)
        rgen_beam_size = self.rival_gen_beam_size

        for tree_no in self.train_order:
            # obtain some 'rival', alternative incorrect candidates
            gold_da, gold_tree, gold_feats = self.train_das[tree_no], self.train_trees[tree_no], self.train_feats[tree_no]

            for strategy in self.rival_gen_strategy:
                rival_das, rival_trees, rival_feats = self._get_rival_candidates(tree_no, strategy, rgen_max_iter,
                                                                                 rgen_max_defic_iter, rgen_beam_size)
                cands = [gold_feats] + rival_feats

                # score them along with the right one
                scores = [self._score(cand) for cand in cands]
                top_cand_idx = scores.index(max(scores))
                top_rival_idx = scores[1:].index(max(scores[1:]))
                top_rival_tree = rival_trees[top_rival_idx]
                top_rival_da = rival_das[top_rival_idx]

                # find the top-scoring generated tree, evaluate against gold t-tree
                # (disregarding whether it was selected as the best one)
                self.evaluator.append(TreeNode(gold_tree), TreeNode(top_rival_tree), scores[0], max(scores[1:]))

                # debug print: candidate trees
                log_debug('TTREE-NO: %04d, SEL_CAND: %04d, LEN: %02d' % (tree_no, top_cand_idx, len(cands)))
                log_debug('SENT: %s' % self.train_sents[tree_no])
                log_debug('ALL CAND TREES:')
                for ttree, score in zip([gold_tree] + rival_trees, scores):
                    log_debug("%12.5f" % score, "\t", ttree)

                # update weights if the system doesn't give the highest score to the right one
                if top_cand_idx != 0:
                    self._update_weights(gold_da, top_rival_da, gold_tree, top_rival_tree,
                                         gold_feats, cands[top_cand_idx])
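                    # (a concrete subclass would typically apply a structured-perceptron
                    #  update here, roughly w += alpha * (gold_feats - cands[top_cand_idx]);
                    #  the exact form is left to the _update_weights implementation)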

        # store a copy of the current weights for averaging
        self.store_iter_weights()

        # debug print: current weights and pass accuracy
        log_debug(self._feat_val_str(), '\n***')
        log_debug('PASS ACCURACY: %.3f' % self.evaluator.tree_accuracy())

        # print and return statistics
        self._print_pass_stats(pass_no, datetime.timedelta(seconds=(time.time() - pass_start_time)))

    def diffing_trees_with_scores(self, da, good_tree, bad_tree):
        """For debugging purposes. Return a printout of diffing trees between the chosen candidate
        and the gold tree, along with scores."""
        good_sts, bad_sts = good_tree.diffing_trees(bad_tree, symmetric=False)
        comm_st = good_tree.get_common_subtree(bad_tree)
        ret = 'Common subtree: %.3f' % self.score(comm_st, da) + "\t" + unicode(comm_st) + "\n"
        ret += "Good subtrees:\n"
        for good_st in good_sts:
            ret += "%.3f" % self.score(good_st, da) + "\t" + unicode(good_st) + "\n"
        ret += "Bad subtrees:\n"
        for bad_st in bad_sts:
            ret += "%.3f" % self.score(bad_st, da) + "\t" + unicode(bad_st) + "\n"
        return ret

    def _get_num_iters(self, cur_pass_no, iter_setting):
        """Return the maximum number of iterations (total/deficit) given the current pass.
        Used to keep track of variable iteration number setting in configuration."""
        if isinstance(iter_setting, (list, tuple)):
            ret = 0
            for set_pass_no, set_iter_no in iter_setting:
                if set_pass_no > cur_pass_no:
                    break
                ret = set_iter_no
            return ret
        else:
            return iter_setting  # a single setting for all passes
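    # Example (hypothetical config value): rival_gen_max_iter = [(1, 20), (3, 50)]
    # yields 20 iterations for passes 1-2 and 50 iterations from pass 3 onward.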

    def _print_pass_stats(self, pass_no, pass_duration):
        """Print pass statistics from internal evaluator fields and given pass duration."""
        log_info('Pass %05d -- tree-level accuracy: %.4f' % (pass_no, self.evaluator.tree_accuracy()))
        log_info(' * Generated trees NODE scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1())
        log_info(' * Generated trees DEP  scores: P: %.4f, R: %.4f, F: %.4f' %
                 self.evaluator.p_r_f1(EvalTypes.DEP))
        log_info(' * Gold tree BEST: %.4f, on CLOSE: %.4f, on ANY list: %.4f' %
                 self.lists_analyzer.stats())
        log_info(' * Tree size stats:\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s' %
                 self.evaluator.tree_size_stats())
        log_info(' * Common subtree stats:\n -- SIZE: %s\n -- ΔGLD: %s\n -- ΔPRD: %s' %
                 self.evaluator.common_subtree_stats())
        log_info(' * Score stats\n -- GOLD: %s\n -- PRED: %s\n -- DIFF: %s'
                 % self.evaluator.score_stats())
        log_info(' * Duration: %s' % str(pass_duration))

    def _feat_val_str(self, sep='\n', nonzero=False):
        """Return feature names and values for printing. To be overridden in base classes."""
        return ''

    def _get_rival_candidates(self, tree_no, strategy, max_iter, max_defic_iter, beam_size):
        """Generate some rival candidates for a DA and the correct (gold) tree,
        given the current rival generation strategy (self.rival_gen_strategy).

        TODO: checking for trees identical to the gold one slows down the process

        TODO: remove other generation strategies, remove support for multiple generated trees

        @param tree_no: the index of the current training data item (tree, DA)
        @rtype: tuple of three lists: one of DAs, one of TreeData's, one of feature arrays
        @return: a list of rival DAs, a list of rival trees, and a list of the corresponding features
        """
        da = self.train_das[tree_no]
        train_trees = self.train_trees

        rival_das, rival_trees, rival_feats = [], [], []

        if strategy != 'other_da':
            rival_das = [da] * self.rival_number

        # use current DA but change trees when computing features
        if strategy == 'other_inst':
            # use alternative indexes, avoid the correct one
            rival_idxs = map(lambda idx: len(train_trees) - 1 if idx == tree_no else idx,
                             rnd.sample(xrange(len(train_trees) - 1), self.rival_number))
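            # (sampling from xrange(len - 1) and remapping the gold index tree_no to the
            #  last tree guarantees that the gold instance itself is never drawn as a rival)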
            other_inst_trees = [train_trees[rival_idx] for rival_idx in rival_idxs]
            rival_trees.extend(other_inst_trees)
            rival_feats.extend([self._extract_feats(tree, da) for tree in other_inst_trees])

        # use the current gold tree but change DAs when computing features
        if strategy == 'other_da':
            rival_idxs = map(lambda idx: len(train_trees) - 1 if idx == tree_no else idx,
                             rnd.sample(xrange(len(train_trees) - 1), self.rival_number))
            other_inst_das = [self.train_das[rival_idx] for rival_idx in rival_idxs]
            rival_das.extend(other_inst_das)
            rival_trees.extend([self.train_trees[tree_no]] * self.rival_number)
            rival_feats.extend([self._extract_feats(self.train_trees[tree_no], da)
                                for da in other_inst_das])

        # candidates generated using the random planner (use the current DA)
        if strategy == 'random':
            random_trees = []
            while len(random_trees) < self.rival_number:
                tree = self.sampling_planner.generate_tree(da)
                if (tree != train_trees[tree_no]):  # don't generate trees identical to the gold one
                    random_trees.append(tree)
            rival_trees.extend(random_trees)
            rival_feats.extend([self._extract_feats(tree, da) for tree in random_trees])

        # candidates generated using the A*search planner, which uses this ranker with current
        # weights to guide the search, and the current DA as the input
        # TODO: use just one!, others are meaningless
        if strategy == 'gen_cur_weights':
            open_list, close_list = self.asearch_planner.run(da, max_iter, max_defic_iter, beam_size)
            self.lists_analyzer.append(train_trees[tree_no], open_list, close_list)
            gen_trees = []
            while close_list and len(gen_trees) < self.rival_number:
                tree = close_list.pop()[0]
                if tree != train_trees[tree_no]:
                    gen_trees.append(tree)
            rival_trees.extend(gen_trees[:self.rival_number])
            rival_feats.extend([self._extract_feats(tree, da)
                                for tree in gen_trees[:self.rival_number]])

        # return all resulting candidates
        return rival_das, rival_trees, rival_feats

    def get_weights(self):
        """Return the current ranker weights (parameters). To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights(self, w):
        """Set new ranker weights. To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_average(self, ws):
        """Set the weights as the average of the given array of weights (used in parallel training).
        To be overridden by derived classes."""
        raise NotImplementedError

    def store_iter_weights(self):
        """Remember the current weights to be used for averaging.
        To be overridden by derived classes."""
        raise NotImplementedError

    def set_weights_iter_average(self):
        """Set new weights as the average of all remembered weights. To be overridden by
        derived classes."""
        raise NotImplementedError

    def get_weights_sum(self):
        """Return weights size in order to weigh future promise against them."""
        raise NotImplementedError

    def update_weights_sum(self):
        """Update the current weights size for future promise weighining."""
        raise NotImplementedError

    def reset_diagnostics(self):
        """Reset the evaluation statistics (Evaluator and ASearchListsAnalyzer objects)."""
        self.evaluator = Evaluator()
        self.lists_analyzer = ASearchListsAnalyzer()

    def get_diagnostics(self):
        """Return the current evaluation statistics (a tuple of Evaluator and ASearchListsAnalyzer
        objects."""
        return self.evaluator, self.lists_analyzer

    def set_diagnostics_average(self, diags):
        """Given an array of evaluation statistics objects, average them an store in this ranker
        instance."""
        self.reset_diagnostics()
        for evaluator, lists_analyzer in diags:
            self.evaluator.merge(evaluator)
            self.lists_analyzer.merge(lists_analyzer)

    def get_future_promise(self, cand_tree):
        """Compute expected future promise for a tree."""
        w_sum = self.get_weights_sum()
        if self.future_promise_type == 'num_nodes':
            return w_sum * self.future_promise_weight * max(0, 10 - len(cand_tree))
        elif self.future_promise_type == 'norm_exp_children':
            return (self.candgen.get_future_promise(cand_tree) / len(cand_tree)) * w_sum * self.future_promise_weight
        elif self.future_promise_type == 'ands':
            prom = 0
            for idx, node in enumerate(cand_tree.nodes):
                if node.t_lemma == 'and':
                    num_kids = cand_tree.children_num(idx)
                    prom += max(0, 2 - num_kids)
            return prom * w_sum * self.future_promise_weight
        else:  # expected children (default)
            return self.candgen.get_future_promise(cand_tree) * w_sum * self.future_promise_weight
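    # (future promise acts as an optimistic estimate of the score still attainable from a
    #  partial tree, presumably used as a heuristic by the A*-search planner; scaling by the
    #  weights' magnitude keeps it commensurate with the perceptron score)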

    def get_future_promise_all(self, cand_trees):
        """Array version of get_future_promise."""
        return [self.get_future_promise(cand_tree) for cand_tree in cand_trees]
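As a rough usage sketch (not part of the original file): BasePerceptronRanker is abstract (_extract_feats and the weight accessors raise NotImplementedError), so training would go through a concrete subclass; the PerceptronRanker name, the file paths, and the some_tree/some_da variables below are assumptions for illustration only. The 'other_inst' strategy is used here because it does not require a candidate generator model.

# Hypothetical driver, assuming a concrete subclass named PerceptronRanker exists in the
# same module and that the DA and t-tree files are in the usual tgen formats.
ranker = PerceptronRanker({'passes': 10,
                           'alpha': 0.1,
                           'randomize': True,
                           'averaging': True,
                           'rival_gen_strategy': ['other_inst'],
                           'rival_number': 10})
ranker.train('train-das.txt', 'train-ttrees.yaml.gz', data_portion=1.0)
score = ranker.score(some_tree, some_da)  # some_tree: a TreeData object, some_da: a DA object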