Пример #1
0
    def __init__(self, cfg):
        """Initialize decoding options and the optional reranking classifier.

        @param cfg: generator configuration (read via `in`, `[]` and `.get()`); it is also \
            kept as `self.cfg` for save/load and construction of embedding extractors
        """
        super(Seq2SeqBase, self).__init__(cfg)
        # save the whole configuration for later use (save/load, construction of embedding
        # extractors)
        self.cfg = cfg

        # decoding options
        self.beam_size = cfg.get('beam_size', 1)
        self.sample_top_k = cfg.get('sample_top_k', 1)
        self.length_norm_weight = cfg.get('length_norm_weight', 0.0)
        self.context_bleu_weight = cfg.get('context_bleu_weight', 0.0)
        self.context_bleu_metric = cfg.get('context_bleu_metric', 'bleu')
        self.slot_err_stats = None

        # weight of classifier misfits when reranking; set unconditionally so that a
        # classifier assigned to `classif_filter` after construction still has a penalty
        self.misfit_penalty = cfg.get('misfit_penalty', 100)

        self.classif_filter = None
        if 'classif_filter' in cfg:
            # use the specialized settings for the reranking classifier
            rerank_cfg = cfg['classif_filter']
            # plus, copy some settings from the main Seq2Seq module (so we're consistent);
            # note that this updates the nested classifier config in place
            for setting in [
                    'use_tokens', 'embeddings_lowercase',
                    'embeddings_split_plurals'
            ]:
                if setting in cfg:
                    rerank_cfg[setting] = cfg[setting]
            self.classif_filter = RerankingClassifier(rerank_cfg)
Пример #2
0
def rerank_cl_eval(args):
    """Evaluate a trained reranking classifier on test data and log the penalty.

    @param args: list of command-line argument strings to parse
    """
    parser = ArgumentParser(prog=' '.join(sys.argv[0:2]))
    # optional overrides of the loaded classifier's attributes
    parser.add_argument('-l', '--language', type=str,
                        help='Override classifier language (for t-tree input files)')
    parser.add_argument('-s', '--selector', type=str,
                        help='Override classifier selector (for t-tree input files)')
    # positional arguments: model, test DAs, test trees/sentences
    for arg_name, arg_help in (
            ('fname_cl_model', 'Path to trained reranking classifier model'),
            ('fname_test_da', 'Path to test DA file'),
            ('fname_test_sent', 'Path to test trees/sentences file')):
        parser.add_argument(arg_name, type=str, help=arg_help)
    opts = parser.parse_args(args)

    log_info("Loading reranking classifier...")
    classifier = RerankingClassifier.load_from_file(opts.fname_cl_model)
    # apply each override only when it was given on the command line
    for attr_name in ('language', 'selector'):
        override = getattr(opts, attr_name)
        if override is not None:
            setattr(classifier, attr_name, override)

    log_info("Evaluating...")
    total_dais, penalty = classifier.evaluate_file(opts.fname_test_da, opts.fname_test_sent)
    log_info("Penalty: %d, Total DAIs %d." % (penalty, total_dais))
Пример #3
0
def rerank_cl_eval(args):
    """Evaluate a trained reranking classifier on test data and log the penalty.

    Usage: [-l language] [-s selector] fname_cl_model fname_test_da fname_test_sent

    Invalid options and a wrong number of file arguments both exit with the
    "Invalid arguments." message followed by the module docstring.

    @param args: list of command-line argument strings to parse
    """
    from getopt import GetoptError

    def _invalid_args():
        """Exit with the usage message (tolerates a module without a docstring)."""
        sys.exit("Invalid arguments.\n" + (__doc__ or ''))

    try:
        # NOTE(review): '-t' is accepted by the spec but has no effect here
        opts, files = getopt(args, 's:l:t')
    except GetoptError:
        # unknown option / missing option value: same handling as wrong file count
        _invalid_args()

    language = None
    selector = None
    for opt, arg in opts:
        if opt == '-l':
            language = arg
        elif opt == '-s':
            selector = arg

    if len(files) != 3:
        _invalid_args()
    fname_cl_model, fname_test_da, fname_test_sent = files

    log_info("Loading reranking classifier...")
    rerank_cl = RerankingClassifier.load_from_file(fname_cl_model)
    if language is not None:
        rerank_cl.language = language
    if selector is not None:
        rerank_cl.selector = selector

    log_info("Evaluating...")
    tot_len, dist = rerank_cl.evaluate_file(fname_test_da, fname_test_sent)
    log_info("Penalty: %d, Total DAIs %d." % (dist, tot_len))
Пример #4
0
    def load_from_file(model_fname):
        """Load the generator from a file (actually up to three files: one for configuration,
        an optional one for the reranking classifier, and one for the TensorFlow graph, which
        must be stored separately).

        @param model_fname: file name (for the configuration file); the classifier and the TF \
            graph must be stored with different extensions (`.tftreecl` inserted before, and \
            `.tfsess` replacing, an optional `.pickle`/`.gz` suffix)
        @return: the loaded `Seq2SeqGen` instance
        """
        log_info("Loading generator from %s..." % model_fname)
        with file_stream(model_fname, 'rb', encoding=None) as fh:
            data = pickle.load(fh)
            ret = Seq2SeqGen(cfg=data['cfg'])
            ret.load_all_settings(data)

        # suffix rewrites below: dots are escaped so they only match literal dots, and
        # count=1 stops Python 3.7+ from also replacing the empty match left at the very end
        # of the string (which would append the new extension a second time)
        if ret.classif_filter:
            classif_filter_fname = re.sub(r'((\.pickle)?(\.gz)?)$',
                                          r'.tftreecl\1', model_fname, count=1)
            if os.path.isfile(classif_filter_fname):
                ret.classif_filter = RerankingClassifier.load_from_file(
                    classif_filter_fname)
            else:
                # no classifier data on disk: continue without it (falsy marker)
                log_warn("Classification filter data not found, ignoring.")
                ret.classif_filter = False

        # re-build TF graph and restore the TF session
        tf_session_fname = re.sub(r'(\.pickle)?(\.gz)?$', '.tfsess', model_fname, count=1)
        ret._init_neural_network()
        ret.saver.restore(ret.session, tf_session_fname)

        return ret
Пример #5
0
    def load_from_file(model_fname):
        """Load the generator from a file (actually up to three files: one for configuration,
        an optional one for the reranking classifier, and one for the TensorFlow graph, which
        must be stored separately).

        @param model_fname: file name (for the configuration file); the classifier and the TF \
            graph must be stored with different extensions (`.tftreecl` inserted before, and \
            `.tfsess` replacing, an optional `.pickle`/`.gz` suffix)
        @return: the loaded `Seq2SeqGen` instance
        """
        log_info("Loading generator from %s..." % model_fname)
        with file_stream(model_fname, 'rb', encoding=None) as fh:
            data = pickle.load(fh)
            ret = Seq2SeqGen(cfg=data['cfg'])
            ret.load_all_settings(data)

        # suffix rewrites below: dots are escaped so they only match literal dots, and
        # count=1 stops Python 3.7+ from also replacing the empty match left at the very end
        # of the string (which would append the new extension a second time)
        if ret.classif_filter:
            classif_filter_fname = re.sub(r'((\.pickle)?(\.gz)?)$', r'.tftreecl\1', model_fname,
                                          count=1)
            if os.path.isfile(classif_filter_fname):
                ret.classif_filter = RerankingClassifier.load_from_file(classif_filter_fname)
            else:
                # no classifier data on disk: continue without it (falsy marker)
                log_warn("Classification filter data not found, ignoring.")
                ret.classif_filter = False

        # re-build TF graph and restore the TF session
        tf_session_fname = re.sub(r'(\.pickle)?(\.gz)?$', '.tfsess', model_fname, count=1)
        ret._init_neural_network()
        ret.saver.restore(ret.session, tf_session_fname)

        return ret
Пример #6
0
def rerank_cl_train(args):
    """Train a reranking classifier and save it -- either on its own, or placed into an
    existing seq2seq model (replacing that model's classifier), saving the whole model.

    @param args: list of command-line argument strings to parse
    """
    parser = ArgumentParser(prog=' '.join(sys.argv[0:2]))
    parser.add_argument(
        '-a', '--add-to-seq2seq', type=str,
        help='Replace trained classifier in an existing seq2seq model (path to file)')
    positional_args = [
        ('fname_config', 'Reranking classifier configuration file path'),
        ('fname_da_train', 'Training DAs file path'),
        ('fname_trees_train', 'Training trees/sentences file path'),
        ('fname_cl_model', 'Path for the output trained model'),
    ]
    for arg_name, arg_help in positional_args:
        parser.add_argument(arg_name, type=str, help=arg_help)
    opts = parser.parse_args(args)

    host_model_path = opts.add_to_seq2seq
    # the host seq2seq model (if any) is loaded before training starts
    if host_model_path:
        host_model = Seq2SeqBase.load_from_file(host_model_path)

    classifier = RerankingClassifier(Config(opts.fname_config))
    classifier.train(opts.fname_da_train, opts.fname_trees_train)

    if not host_model_path:
        classifier.save_to_file(opts.fname_cl_model)
        return
    host_model.classif_filter = classifier
    host_model.save_to_file(opts.fname_cl_model)
Пример #7
0
def rerank_cl_train(args):
    """Train a reranking classifier and save it -- either on its own, or placed into an
    existing seq2seq model (replacing that model's classifier), saving the whole model.

    Usage: [-a seq2seq_model] fname_config fname_da_train fname_trees_train fname_cl_model

    Invalid options and a wrong number of file arguments both exit with the
    "Invalid arguments." message followed by the module docstring.

    @param args: list of command-line argument strings to parse
    """
    from getopt import GetoptError

    def _invalid_args():
        """Exit with the usage message (tolerates a module without a docstring)."""
        sys.exit("Invalid arguments.\n" + (__doc__ or ''))

    try:
        opts, files = getopt(args, 'a:')
    except GetoptError:
        # unknown option / missing option value: same handling as wrong file count
        _invalid_args()

    load_seq2seq_model = None
    for opt, arg in opts:
        if opt == '-a':
            load_seq2seq_model = arg

    if len(files) != 4:
        _invalid_args()
    fname_config, fname_da_train, fname_trees_train, fname_cl_model = files

    if load_seq2seq_model:
        tgen = Seq2SeqBase.load_from_file(load_seq2seq_model)

    config = Config(fname_config)
    rerank_cl = RerankingClassifier(config)
    rerank_cl.train(fname_da_train, fname_trees_train)

    if load_seq2seq_model:
        # replace the loaded model's classifier and save the whole model
        tgen.classif_filter = rerank_cl
        tgen.save_to_file(fname_cl_model)
    else:
        rerank_cl.save_to_file(fname_cl_model)
Пример #8
0
    def build_ensemble(self, models, rerank_settings=None, rerank_params=None):
        """Build the ensemble model: construct every member network, load its parameters and
        (optionally) set up the reranking classifier.

        @param models: list of tuples (settings, parameter set) of all models in the ensemble; \
            each settings object must contain a 'cfg' key
        @param rerank_settings: settings of the reranking classifier (must contain 'cfg'), or \
            None if no classifier should be built
        @param rerank_params: parameter set of the reranking classifier (only used when \
            rerank_settings is given)
        """
        for member_settings, member_params in models:
            member = Seq2SeqGen(member_settings['cfg'])
            member.load_all_settings(member_settings)
            member._init_neural_network()
            member.set_model_params(member_params)
            self.gens.append(member)

        # all members share embedding IDs, so the first member's embeddings serve everyone
        # (raises IndexError if no member has been added)
        first_member = self.gens[0]
        self.da_embs = first_member.da_embs
        self.tree_embs = first_member.tree_embs

        if rerank_settings is None:
            return
        self.classif_filter = classifier = RerankingClassifier(cfg=rerank_settings['cfg'])
        classifier.load_all_settings(rerank_settings)
        classifier._init_neural_network()
        classifier.set_model_params(rerank_params)
Пример #9
0
def rerank_cl_train(args):
    """Train a reranking classifier on trees and save it -- on its own, or inside an existing
    seq2seq model whose classifier it replaces.

    @param args: list of command-line argument strings to parse
    """
    ap = ArgumentParser(prog=' '.join(sys.argv[0:2]))
    ap.add_argument('-a', '--add-to-seq2seq', type=str,
                    help='Replace trained classifier in an existing seq2seq model (path to file)')
    ap.add_argument('fname_config', type=str, help='Reranking classifier configuration file path')
    ap.add_argument('fname_da_train', type=str, help='Training DAs file path')
    ap.add_argument('fname_trees_train', type=str, help='Training trees file path (must be trees!)')
    ap.add_argument('fname_cl_model', type=str, help='Path for the output trained model')
    parsed = ap.parse_args(args)
    seq2seq_path = parsed.add_to_seq2seq

    # load the host seq2seq model (if requested) before training
    if seq2seq_path:
        host_model = Seq2SeqBase.load_from_file(seq2seq_path)

    reranker = RerankingClassifier(Config(parsed.fname_config))
    reranker.train(parsed.fname_da_train, parsed.fname_trees_train)

    if not seq2seq_path:
        reranker.save_to_file(parsed.fname_cl_model)
        return
    host_model.classif_filter = reranker
    host_model.save_to_file(parsed.fname_cl_model)
Пример #10
0
def rerank_cl_eval(args):
    """Evaluate a trained reranking classifier on test trees and log the penalty.

    @param args: list of command-line argument strings to parse
    """
    ap = ArgumentParser(prog=' '.join(sys.argv[0:2]))
    ap.add_argument('-l', '--language', type=str,
                    help='Override classifier language (for t-tree input files)')
    ap.add_argument('-s', '--selector', type=str,
                    help='Override classifier selector (for t-tree input files)')
    ap.add_argument('fname_cl_model', type=str, help='Path to trained reranking classifier model')
    ap.add_argument('fname_test_da', type=str, help='Path to test DA file')
    ap.add_argument('fname_test_sent', type=str, help='Path to test trees file (must be trees!)')
    parsed = ap.parse_args(args)

    log_info("Loading reranking classifier...")
    reranker = RerankingClassifier.load_from_file(parsed.fname_cl_model)
    # override language/selector only when explicitly given
    overrides = [('language', parsed.language), ('selector', parsed.selector)]
    for attr_name, value in overrides:
        if value is not None:
            setattr(reranker, attr_name, value)

    log_info("Evaluating...")
    result = reranker.evaluate_file(parsed.fname_test_da, parsed.fname_test_sent)
    total_dais, penalty = result
    log_info("Penalty: %d, Total DAIs %d." % (penalty, total_dais))
Пример #11
0
    def __init__(self, cfg):
        """Initialize decoding options and the optional reranking classifier.

        @param cfg: generator configuration (read via `in`, `[]` and `.get()`); it is also \
            kept as `self.cfg` for save/load and construction of embedding extractors
        """
        super(Seq2SeqBase, self).__init__(cfg)
        # save the whole configuration for later use (save/load, construction of embedding
        # extractors)
        self.cfg = cfg

        # decoding options
        self.beam_size = cfg.get('beam_size', 1)
        self.sample_top_k = cfg.get('sample_top_k', 1)
        self.length_norm_weight = cfg.get('length_norm_weight', 0.0)
        self.context_bleu_weight = cfg.get('context_bleu_weight', 0.0)
        self.context_bleu_metric = cfg.get('context_bleu_metric', 'bleu')
        self.slot_err_stats = None

        # weight of classifier misfits when reranking; set unconditionally so that a
        # classifier assigned to `classif_filter` after construction still has a penalty
        self.misfit_penalty = cfg.get('misfit_penalty', 100)

        self.classif_filter = None
        if 'classif_filter' in cfg:
            # use the specialized settings for the reranking classifier
            rerank_cfg = cfg['classif_filter']
            # plus, copy some settings from the main Seq2Seq module (so we're consistent);
            # note that this updates the nested classifier config in place
            for setting in ['use_tokens', 'embeddings_lowercase', 'embeddings_split_plurals']:
                if setting in cfg:
                    rerank_cfg[setting] = cfg[setting]
            self.classif_filter = RerankingClassifier(rerank_cfg)
Пример #12
0
class Seq2SeqBase(SentencePlanner):
    """A common ancestor for the Plain and Ensemble Seq2Seq generators (decoding methods only).

    Subclasses must implement `_get_greedy_decoder_output`, `_init_beam_search` and
    `_beam_search_step`, and must provide `da_embs` and `tree_embs` (used here, never set
    here)."""

    def __init__(self, cfg):
        """Initialize decoding options and the optional reranking classifier.

        @param cfg: generator configuration (read via `in`, `[]` and `.get()`); it is also \
            kept as `self.cfg` for save/load and construction of embedding extractors
        """
        super(Seq2SeqBase, self).__init__(cfg)
        # save the whole configuration for later use (save/load, construction of embedding
        # extractors)
        self.cfg = cfg

        # decoding options
        self.beam_size = cfg.get('beam_size', 1)
        self.sample_top_k = cfg.get('sample_top_k', 1)
        self.length_norm_weight = cfg.get('length_norm_weight', 0.0)
        self.context_bleu_weight = cfg.get('context_bleu_weight', 0.0)
        self.context_bleu_metric = cfg.get('context_bleu_metric', 'bleu')
        self.slot_err_stats = None

        # weight of classifier misfits in `_rerank_paths`; set unconditionally so that a
        # classifier assigned to `classif_filter` after construction does not crash reranking
        self.misfit_penalty = cfg.get('misfit_penalty', 100)

        self.classif_filter = None
        if 'classif_filter' in cfg:
            # use the specialized settings for the reranking classifier
            rerank_cfg = cfg['classif_filter']
            # plus, copy some settings from the main Seq2Seq module (so we're consistent);
            # note that this updates the nested classifier config in place
            for setting in ['use_tokens', 'embeddings_lowercase', 'embeddings_split_plurals']:
                if setting in cfg:
                    rerank_cfg[setting] = cfg[setting]
            self.classif_filter = RerankingClassifier(rerank_cfg)

    def process_das(self, das, gold_trees=None):
        """
        Process a list of input DAs, return the corresponding trees (using the generator
        network with current parameters).

        Beam search is used only when `beam_size` > 1 and exactly one DA is given; otherwise
        greedy decoding is used. A cost can only come from greedy decoding.

        @param das: input DAs
        @param gold_trees: (optional) gold trees against which cost is computed (greedy \
            decoding only)
        @return: generated trees as `TreeData` instances, plus the cost if one was computed
        """
        # encoder inputs
        enc_inputs = cut_batch_into_steps(
            [self.da_embs.get_embeddings(da) for da in das])

        if self.beam_size > 1 and len(das) == 1:
            dec_output_ids = self._beam_search(enc_inputs, das[0])
            dec_cost = None
        else:
            dec_output_ids, dec_cost = self._greedy_decoding(enc_inputs, gold_trees)

        # transpose so that each row holds the output IDs of one DA
        dec_trees = [self.tree_embs.ids_to_tree(ids) for ids in dec_output_ids.transpose()]

        # return result (trees and optionally cost)
        if dec_cost is None:
            return dec_trees
        return dec_trees, dec_cost

    def _greedy_decoding(self, enc_inputs, gold_trees):
        """Run greedy decoding with the given encoder inputs; optionally use given gold trees
        as decoder inputs for cost computation.

        @return: a tuple (decoder output IDs, cost or None)
        """
        # prepare decoder inputs (either fake, or true but used just for cost computation)
        if gold_trees is None:
            empty_tree_emb = self.tree_embs.get_embeddings(TreeData())
            dec_inputs = cut_batch_into_steps([empty_tree_emb for _ in enc_inputs[0]])
        else:
            dec_inputs = cut_batch_into_steps(
                [self.tree_embs.get_embeddings(tree) for tree in gold_trees])

        # run the decoding per se
        dec_output_ids, dec_cost = self._get_greedy_decoder_output(
            enc_inputs, dec_inputs, compute_cost=gold_trees is not None)

        return dec_output_ids, dec_cost

    def _get_greedy_decoder_output(self, enc_inputs, dec_inputs, compute_cost=False):
        """Run greedy decoding; to be implemented by subclasses.

        @return: a tuple (decoder output IDs, cost or None)
        """
        raise NotImplementedError()

    class DecodingPath(object):
        """A decoding path to be used in beam search."""

        __slots__ = ['stop_token_id', 'dec_inputs', 'dec_states', 'logprob', '_length']

        def __init__(self, stop_token_id, dec_inputs=(), dec_states=(), logprob=0.0,
                     length=-1):
            """Create a path (the given input/state sequences are copied).

            @param stop_token_id: ID of the token after which the path length stops growing
            @param dec_inputs: decoder inputs so far
            @param dec_states: decoder states so far
            @param logprob: accumulated log probability
            @param length: path length; a negative value means "use len(dec_inputs)"
            """
            self.stop_token_id = stop_token_id
            # immutable defaults + list() copies: paths never share their lists
            self.dec_inputs = list(dec_inputs)
            self.dec_states = list(dec_states)
            self.logprob = logprob
            self._length = length if length >= 0 else len(dec_inputs)

        def expand(self, max_variants, dec_out_probs, dec_state):
            """Expand the path with all possible outputs, updating the log probabilities.

            @param max_variants: expand to this number of variants at maximum, discard the less \
                probable ones
            @param dec_out_probs: the decoder output probabilities for the current step \
                (1-D numpy array indexed by token ID)
            @param dec_state: the decoder hidden state for the current step
            @return: a list of the (up to max_variants) continuations of this path
            """
            num_outputs = len(dec_out_probs)
            if max_variants >= num_outputs:
                # np.argpartition requires kth < size; every output is kept anyway
                top_n_idx = np.arange(num_outputs)
            else:
                # select only up to max_variants most probable variants (unordered)
                top_n_idx = np.argpartition(-dec_out_probs, max_variants)[:max_variants]

            # the length only grows while the stop token has not been produced yet
            still_open = len(self) == len(self.dec_inputs)
            path_class = type(self)
            ret = []
            for idx in top_n_idx:
                expanded = path_class(self.stop_token_id, self.dec_inputs, self.dec_states,
                                      self.logprob, len(self))
                if still_open and idx != self.stop_token_id:
                    expanded._length += 1
                expanded.logprob += np.log(dec_out_probs[idx])
                expanded.dec_inputs.append(np.array(idx, ndmin=1))
                expanded.dec_states.append(dec_state)
                ret.append(expanded)

            return ret

        def __len__(self):
            """Return decoding path length (number of decoder input tokens)."""
            return self._length

    def _beam_search(self, enc_inputs, da):
        """Run beam search decoding.

        @param enc_inputs: encoder inputs for a single DA (true batches are not supported)
        @param da: the input DA (used for reranking and slot error statistics)
        @return: token IDs of the selected path, as a numpy array
        """
        # true "batches" not implemented
        assert len(enc_inputs[0]) == 1

        # run greedy decoder for comparison (debugging purposes); NB: the greedy decoding is
        # performed regardless of the logging level, as the message is built eagerly
        log_debug("GREEDY DEC WOULD RETURN:\n" + " ".join(
            self.tree_embs.ids_to_strings([
                out_tok[0]
                for out_tok in self._greedy_decoding(enc_inputs, None)[0]
            ])))

        # initialize
        self._init_beam_search(enc_inputs)
        empty_tree_emb = self.tree_embs.get_embeddings(TreeData())
        dec_inputs = cut_batch_into_steps([empty_tree_emb])

        paths = [self.DecodingPath(stop_token_id=self.tree_embs.STOP,
                                   dec_inputs=[dec_inputs[0]])]

        def length_norm_logprob(path):
            """Length-weighted logprob of a path (beam ordering key)."""
            return path.logprob / (len(path) ** self.length_norm_weight)

        # beam search steps (at most one per decoder input position)
        for step in xrange(len(dec_inputs)):

            new_paths = []
            for path in paths:
                out_probs, st = self._beam_search_step(path.dec_inputs, path.dec_states)
                new_paths.extend(path.expand(self.beam_size, out_probs, st))

            # keep the best `beam_size` paths, best first (stable sort)
            paths = sorted(new_paths, key=length_norm_logprob,
                           reverse=True)[:self.beam_size]

            if all(p.dec_inputs[-1] == self.tree_embs.VOID for p in paths):
                break  # stop decoding if we have reached the end in all paths

            log_debug(("\nBEAM SEARCH STEP %d\n" % step) +
                      "\n".join([("%f\t" % p.logprob) + " ".join(
                          self.tree_embs.ids_to_strings(
                              [inp[0] for inp in p.dec_inputs]))
                                 for p in paths]) + "\n")

        # rerank paths by their distance to the input DA
        if self.classif_filter or self.context_bleu_weight:
            paths = self._rerank_paths(paths, da)

        # measure slot error on the top k paths
        if self.slot_err_stats:
            for path in paths[:self.sample_top_k]:
                self.slot_err_stats.append(
                    da,
                    self.tree_embs.ids_to_strings(
                        [inp[0] for inp in path.dec_inputs]))

        # select the "best" path -- either the best, or one in top k
        if self.sample_top_k > 1:
            best_path = self._sample_path(paths[:self.sample_top_k])
        else:
            best_path = paths[0]

        # return just the best path (as token IDs)
        return np.array(best_path.dec_inputs)

    def _init_beam_search(self, enc_inputs):
        """Prepare beam search for the given encoder inputs; to be implemented by subclasses."""
        raise NotImplementedError()

    def _beam_search_step(self, dec_inputs, dec_states):
        """Run one beam search step; to be implemented by subclasses.

        @return: a tuple (output probabilities, decoder state)
        """
        raise NotImplementedError()

    def _rerank_paths(self, paths, da):
        """Rerank the n-best decoded paths according to the reranking classifier and/or
        BLEU against context.

        Note that the paths' `logprob` values are modified in place.

        @return: the paths sorted by their adjusted logprob, best first
        """
        trees = [self.tree_embs.ids_to_tree(np.array(path.dec_inputs).transpose()[0])
                 for path in paths]

        # rerank using BLEU against context if set to do so
        if self.context_bleu_weight:
            # NOTE(review): here `da[0]` is treated as a sequence of (form, tag) pairs --
            # presumably the context utterance; confirm against callers
            bm = BLEUMeasure(max_ngram=2)
            bleus = []
            for path, tree in zip(paths, trees):
                bm.reset()
                bm.append([(n.t_lemma, None) for n in tree.nodes[1:]], [da[0]])
                bleu = (bm.ngram_precision() if self.context_bleu_metric
                        == 'ngram_prec' else bm.bleu())
                bleus.append(bleu)
                path.logprob += self.context_bleu_weight * bleu

            log_debug(("BLEU for context: %s\n\n" %
                       " ".join([form for form, _ in da[0]])) +
                      "\n".join([("%.5f\t" % b) +
                                 " ".join([n.t_lemma for n in t.nodes[1:]])
                                 for b, t in zip(bleus, trees)]))

        # add distances to logprob so that non-fitting will be heavily penalized
        if self.classif_filter:
            self.classif_filter.init_run(da)
            fits = self.classif_filter.dist_to_cur_da(trees)
            for path, fit in zip(paths, fits):
                path.logprob -= self.misfit_penalty * fit

            log_debug(("Misfits for DA: %s\n\n" % str(da)) +
                      "\n".join([("%.5f\t" % fit) + " ".join(
                          [unicode(n.t_lemma) for n in tree.nodes[1:]])
                                 for fit, tree in zip(fits, trees)]))

        # adjust paths for length (if set to do so)
        if self.length_norm_weight:
            for path in paths:
                path.logprob /= len(path) ** self.length_norm_weight

        return sorted(paths, key=lambda path: path.logprob, reverse=True)

    def _sample_path(self, paths):
        """Sample one path from the given (top k) paths, proportionally to their probabilities.

        @param paths: non-empty list of decoding paths
        @return: the selected path
        """
        # convert the logprobs to a probability distribution, proportionate to their sizes;
        # subtracting the maximum avoids underflow (result is unnormalized until divided)
        logprobs = [p.logprob for p in paths]
        max_logprob = max(logprobs)
        probs = [math.exp(lp - max_logprob) for lp in logprobs]
        sum_prob = sum(probs)
        probs = [p / sum_prob for p in probs]  # normalized

        # select the path based on a draw from the uniform distribution
        draw = rnd.random()
        cum = 0.0  # building cumulative distribution function on-the-fly
        selected = -1  # falls back to the last path if rounding leaves the draw uncovered
        for idx, prob in enumerate(probs):
            high = cum + prob
            if cum <= draw < high:  # the draw has hit this index in the CDF
                selected = idx
                break
            cum = high

        return paths[selected]

    def generate_tree(self, da, gen_doc=None):
        """Generate one tree, saving it into the document provided (if applicable).

        @param da: the input DA
        @param gen_doc: the document where the tree should be saved (defaults to None); a list \
            receives (lemma, None) tuples, any other truthy value is treated as a Pytreex document
        @return: the generated tree
        """
        # generate the tree
        log_debug("GENERATE TREE FOR DA: " + unicode(da))
        tree = self.process_das([da])[0]
        log_debug("RESULT: %s" % unicode(tree))
        # if requested, append the result to the "document"
        # just lists (generated tokens only, disregarding syntax; keep None for POS tags)
        if isinstance(gen_doc, list):
            # ignore tree technical root, take just "lemmas"
            gen_doc.append([(n.t_lemma, None) for n in tree.nodes[1:]])
        # full Pytreex documents (full trees)
        elif gen_doc:
            zone = self.get_target_zone(gen_doc)
            zone.ttree = tree.create_ttree()
            zone.sentence = unicode(da)
        # return the result
        return tree

    def init_slot_err_stats(self):
        """Initialize slot error statistics accumulator."""
        self.slot_err_stats = SlotErrAnalyzer()

    def get_slot_err_stats(self):
        """Return current slot error statistics, as a string (requires a previous call to
        `init_slot_err_stats`)."""
        return ("Slot error: %.2f (M: %d, S: %d, T: %d)" %
                (self.slot_err_stats.slot_error(), self.slot_err_stats.missing,
                 self.slot_err_stats.superfluous, self.slot_err_stats.total))

    @staticmethod
    def load_from_file(model_fname):
        """Detect correct model type (plain/ensemble) and start loading.

        The first pickled object decides: if it is a class, that class is used as the model
        type; otherwise the plain `Seq2SeqGen` is assumed.
        """
        model_type = Seq2SeqGen  # default to plain generator
        with file_stream(model_fname, 'rb', encoding=None) as fh:
            data = pickle.load(fh)
            if isinstance(data, type):
                model_type = data

        return model_type.load_from_file(model_fname)
Пример #13
0
class Seq2SeqEnsemble(Seq2SeqBase):
    """Ensemble Sequence-to-Sequence models (averaging outputs of networks with different random
    initialization)."""

    def __init__(self, cfg):
        super(Seq2SeqEnsemble, self).__init__(cfg)
        # member generators (filled in by `build_ensemble`)
        self.gens = []

    def build_ensemble(self, models, rerank_settings=None, rerank_params=None):
        """Build the ensemble model (build all networks and load their parameters).

        @param models: list of tuples (settings, parameter set) of all models in the ensemble
        @param rerank_settings: settings of the reranking classifier (must contain 'cfg'), or \
            None if there is no classifier
        @param rerank_params: parameter set of the reranking classifier (used only together \
            with rerank_settings)
        """

        for setting, parset in models:
            model = Seq2SeqGen(setting['cfg'])
            model.load_all_settings(setting)
            model._init_neural_network()
            model.set_model_params(parset)
            self.gens.append(model)

        # embedding IDs should be the same for all models, it is safe to use them directly
        # (raises IndexError if no member generator has been added)
        self.da_embs = self.gens[0].da_embs
        self.tree_embs = self.gens[0].tree_embs

        if rerank_settings is not None:
            self.classif_filter = RerankingClassifier(cfg=rerank_settings['cfg'])
            self.classif_filter.load_all_settings(rerank_settings)
            self.classif_filter._init_neural_network()
            self.classif_filter.set_model_params(rerank_params)

    def _get_greedy_decoder_output(self, enc_inputs, dec_inputs, compute_cost=False):
        """Run greedy decoding with the given inputs; return decoder outputs and the cost
        (if required). For ensemble decoding, the greedy search is implemented as a beam
        search with a beam size of 1.

        @param enc_inputs: encoder inputs (list of token IDs)
        @param dec_inputs: decoder inputs (list of token IDs); ignored -- decoding always \
            starts from an empty tree's embedding
        @param compute_cost: if True, decoding cost is computed (the dec_inputs must be valid trees)
        @return a tuple of list of decoder outputs + decoding cost (None if not required)
        """
        # TODO batches and cost computation not implemented
        assert len(enc_inputs[0]) == 1 and not compute_cost

        self._init_beam_search(enc_inputs)

        # for simplicity, this is implemented exactly like a beam search, but with a path sized one
        empty_tree_emb = self.tree_embs.get_embeddings(TreeData())
        dec_inputs = cut_batch_into_steps([empty_tree_emb])
        path = self.DecodingPath(stop_token_id=self.tree_embs.STOP, dec_inputs=[dec_inputs[0]])

        for _ in xrange(len(dec_inputs)):
            out_probs, st = self._beam_search_step(path.dec_inputs, path.dec_states)
            path = path.expand(1, out_probs, st)[0]

            if path.dec_inputs[-1] == self.tree_embs.VOID:
                break  # stop decoding if we have reached the end of path

        # return just token IDs, ignore cost computation here
        return np.array(path.dec_inputs), None

    def _init_beam_search(self, enc_inputs):
        """Initialize beam search for the current DA (with the given encoder inputs)
        for all member generators."""
        for gen in self.gens:
            gen._init_beam_search(enc_inputs)

    def _beam_search_step(self, dec_inputs, dec_states):
        """Run one step of beam search decoding with the given decoder inputs and
        (previous steps') outputs and states. Outputs are averaged over all member generators,
        states are kept separately.

        @param dec_states: per-step ensemble states; each holds one entry per member generator
        @return: a tuple (averaged output distribution, list of member states)
        """
        ensemble_state = []
        ensemble_output = None

        for gen_no, gen in enumerate(self.gens):
            # pick this member's own state out of every previous step's ensemble state
            member_states = [step_state[gen_no] for step_state in dec_states]
            output, state = gen._beam_search_step(dec_inputs, member_states)
            ensemble_state.append(state)
            # softmax-normalize the member output; subtracting the maximum first keeps
            # np.exp from overflowing without changing the result
            output = np.asarray(output)
            output = np.exp(output - np.max(output))
            output /= np.sum(output)
            if ensemble_output is None:
                ensemble_output = output
            else:
                ensemble_output += output

        ensemble_output /= float(len(self.gens))

        return ensemble_output, ensemble_state

    @staticmethod
    def load_from_file(model_fname):
        """Load the whole ensemble from a file (load settings and model parameters, then build the
        ensemble network).

        Expected pickle sequence (as written by `save_to_file`): class identifier, cfg, list of
        member (settings, params) and -- if 'classif_filter' is in cfg -- the classifier
        settings and parameters.

        @raise ValueError: if the file does not start with the `Seq2SeqEnsemble` identifier
        """
        # TODO support for lexicalizer

        log_info("Loading ensemble generator from %s..." % model_fname)

        with file_stream(model_fname, 'rb', encoding=None) as fh:
            typeid = pickle.load(fh)
            if typeid != Seq2SeqEnsemble:
                raise ValueError('Wrong type identifier in file %s' % model_fname)
            cfg = pickle.load(fh)
            ret = Seq2SeqEnsemble(cfg)
            gens_dump = pickle.load(fh)
            if 'classif_filter' in cfg:
                rerank_settings = pickle.load(fh)
                rerank_params = pickle.load(fh)
            else:
                rerank_settings = None
                rerank_params = None

        ret.build_ensemble(gens_dump, rerank_settings, rerank_params)
        return ret

    def save_to_file(self, model_fname):
        """Save the whole ensemble into a file (get all settings and parameters, dump them in a
        pickle).

        Written pickle sequence: class identifier, cfg, list of member (settings, params) and,
        if a classifier is present, its settings and parameters.
        """
        # TODO support for lexicalizer
        # NOTE(review): the classifier data is written when `self.classif_filter` is truthy,
        # but `load_from_file` reads it when 'classif_filter' is in cfg -- these should agree

        log_info("Saving generator to %s..." % model_fname)
        with file_stream(model_fname, 'wb', encoding=None) as fh:
            pickle.dump(self.__class__, fh, protocol=pickle.HIGHEST_PROTOCOL)
            pickle.dump(self.cfg, fh, protocol=pickle.HIGHEST_PROTOCOL)

            gens_dump = []
            for gen in self.gens:
                setting = gen.get_all_settings()
                parset = gen.get_model_params()
                setting['classif_filter'] = self.classif_filter is not None
                gens_dump.append((setting, parset))

            pickle.dump(gens_dump, fh, protocol=pickle.HIGHEST_PROTOCOL)

            if self.classif_filter:
                pickle.dump(self.classif_filter.get_all_settings(), fh,
                            protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(self.classif_filter.get_model_params(), fh,
                            protocol=pickle.HIGHEST_PROTOCOL)
Пример #14
0
class Seq2SeqBase(SentencePlanner):
    """A common ancestor for the Plain and Ensemble Seq2Seq generators (decoding methods only)."""

    def __init__(self, cfg):
        """Initialize decoding options and, if configured, the reranking classifier.

        @param cfg: configuration dictionary; the decoding options (`beam_size`, `sample_top_k`, \
            `length_norm_weight`, `context_bleu_weight`, `context_bleu_metric`) are read with \
            defaults, an optional `classif_filter` sub-dictionary configures the reranking \
            classifier and `misfit_penalty` sets the weight of its distance in reranking
        """
        super(Seq2SeqBase, self).__init__(cfg)
        # save the whole configuration for later use (save/load, construction of embedding
        # extractors)
        self.cfg = cfg

        # decoding options
        self.beam_size = cfg.get('beam_size', 1)
        self.sample_top_k = cfg.get('sample_top_k', 1)
        self.length_norm_weight = cfg.get('length_norm_weight', 0.0)
        self.context_bleu_weight = cfg.get('context_bleu_weight', 0.0)
        self.context_bleu_metric = cfg.get('context_bleu_metric', 'bleu')
        self.slot_err_stats = None

        self.classif_filter = None
        if 'classif_filter' in cfg:
            # use the specialized settings for the reranking classifier; work on a shallow
            # copy so that the caller's configuration (also kept in self.cfg) is not modified
            rerank_cfg = dict(cfg['classif_filter'])
            # plus, copy some settings from the main Seq2Seq module (so we're consistent)
            for setting in ['use_tokens', 'embeddings_lowercase', 'embeddings_split_plurals']:
                if setting in cfg:
                    rerank_cfg[setting] = cfg[setting]
            self.classif_filter = RerankingClassifier(rerank_cfg)
            # multiplier of the classifier's misfit distance, subtracted from path logprobs
            # in _rerank_paths
            self.misfit_penalty = cfg.get('misfit_penalty', 100)

    def process_das(self, das, gold_trees=None):
        """
        Decode a list of input DAs into trees using the current network parameters.

        Beam search is used only for a single DA with `beam_size` > 1; otherwise the DAs are
        decoded greedily, and the cost against `gold_trees` is computed if they are given.

        @param das: list of input DAs
        @param gold_trees: (optional) gold trees against which cost is computed
        @return: list of generated `TreeData` instances, or a pair (trees, cost) when a cost \
            is available
        """
        enc_inputs = cut_batch_into_steps([self.da_embs.get_embeddings(da) for da in das])

        use_beam = self.beam_size > 1 and len(das) == 1
        if use_beam:
            output_ids, cost = self._beam_search(enc_inputs, das[0]), None
        else:
            output_ids, cost = self._greedy_decoding(enc_inputs, gold_trees)

        trees = [self.tree_embs.ids_to_tree(column) for column in output_ids.transpose()]
        if cost is not None:
            return trees, cost
        return trees

    def _greedy_decoding(self, enc_inputs, gold_trees):
        """Decode the given encoder inputs greedily.

        @param enc_inputs: encoder inputs, cut into steps
        @param gold_trees: None, or gold trees to be fed to the decoder (used just for cost \
            computation)
        @return: a pair (decoder output IDs, cost), as produced by `_get_greedy_decoder_output`
        """
        with_cost = gold_trees is not None
        if with_cost:
            batch = [self.tree_embs.get_embeddings(tree) for tree in gold_trees]
        else:
            # fake decoder inputs: one empty-tree embedding per batch item
            blank = self.tree_embs.get_embeddings(TreeData())
            batch = [blank for _ in enc_inputs[0]]
        dec_inputs = cut_batch_into_steps(batch)

        output_ids, cost = self._get_greedy_decoder_output(
                enc_inputs, dec_inputs, compute_cost=with_cost)
        return output_ids, cost

    def _get_greedy_decoder_output(self, enc_inputs, dec_inputs, compute_cost=False):
        """Run the decoder greedily over the given inputs; must be implemented by subclasses.

        (The stub previously lacked `self`, so its first parameter silently received the
        instance; positional callers are unaffected by the fix.)

        @param enc_inputs: encoder inputs, cut into steps
        @param dec_inputs: decoder inputs, cut into steps
        @param compute_cost: whether the cost against `dec_inputs` should be computed
        @return: a pair (decoder output IDs, cost) -- `_greedy_decoding` unpacks it this way
        @raise NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    class DecodingPath(object):
        """A decoding path to be used in beam search.

        Holds the decoder inputs (token IDs) and decoder states collected so far and the
        path's total log probability. The path length is the number of decoder inputs before
        the stop token: tokens appended after the stop token are stored but not counted.
        """

        __slots__ = ['stop_token_id', 'dec_inputs', 'dec_states', 'logprob', '_length']

        def __init__(self, stop_token_id, dec_inputs=None, dec_states=None, logprob=0.0,
                     length=-1):
            """Create a new path.

            @param stop_token_id: ID of the token that ends the path
            @param dec_inputs: initial decoder inputs (copied; empty by default)
            @param dec_states: initial decoder states (copied; empty by default)
            @param logprob: initial log probability of the path
            @param length: path length; if negative, the number of decoder inputs is used
            """
            self.stop_token_id = stop_token_id
            # always copy, so that expanded paths never share lists with their parent
            self.dec_inputs = list(dec_inputs) if dec_inputs is not None else []
            self.dec_states = list(dec_states) if dec_states is not None else []
            self.logprob = logprob
            self._length = length if length >= 0 else len(self.dec_inputs)

        def expand(self, max_variants, dec_out_probs, dec_state):
            """Expand the path with the most probable outputs, updating the log probabilities.

            @param max_variants: expand to this number of variants at maximum, discard the less \
                probable ones
            @param dec_out_probs: the decoder output probabilities for the current step \
                (1-D numpy array indexed by output token ID)
            @param dec_state: the decoder hidden state for the current step
            @return: a list of new paths, one per selected output (this path is not modified)
            """
            num_outputs = len(dec_out_probs)
            if max_variants < num_outputs:
                # partial sort: the max_variants most probable outputs come first (unordered)
                top_n_idx = np.argpartition(-dec_out_probs, max_variants)[:max_variants]
            else:
                # argpartition needs kth < number of outputs; every output is kept anyway
                top_n_idx = np.arange(num_outputs)

            # once the stop token has been generated, further tokens don't add to the length
            still_open = len(self) == len(self.dec_inputs)
            # create paths of this path's own class instead of hard-wiring a subclass name
            path_cls = type(self)

            ret = []
            for idx in top_n_idx:
                expanded = path_cls(self.stop_token_id, self.dec_inputs, self.dec_states,
                                    self.logprob, len(self))
                if still_open and idx != self.stop_token_id:
                    expanded._length += 1
                expanded.logprob += np.log(dec_out_probs[idx])
                expanded.dec_inputs.append(np.array(idx, ndmin=1))
                expanded.dec_states.append(dec_state)
                ret.append(expanded)

            return ret

        def __len__(self):
            """Return decoding path length (number of decoder input tokens)."""
            return self._length

    def _beam_search(self, enc_inputs, da):
        """Run beam search decoding for a single input DA.

        The beam is kept sorted by length-normalized logprob; decoding ends early once every
        path in the beam ends with the VOID token. The final beam is optionally reranked
        (reranking classifier / context BLEU), slot error statistics are collected on its top
        k paths if enabled, and one path is returned -- the best one, or a sample from the
        top k when `sample_top_k` > 1.

        @param enc_inputs: encoder inputs (cut into steps) for a batch of exactly one DA
        @param da: the input DA (used for reranking and slot error statistics)
        @return: the selected path's decoder inputs (token IDs) as a numpy array
        """
        # true "batches" not implemented
        assert len(enc_inputs[0]) == 1

        # run greedy decoder for comparison (debugging purposes)
        log_debug("GREEDY DEC WOULD RETURN:\n" +
                  " ".join(self.tree_embs.ids_to_strings(
                      [out_tok[0] for out_tok in self._greedy_decoding(enc_inputs, None)[0]])))

        # initialize
        self._init_beam_search(enc_inputs)
        empty_tree_emb = self.tree_embs.get_embeddings(TreeData())
        dec_inputs = cut_batch_into_steps([empty_tree_emb])

        paths = [self.DecodingPath(stop_token_id=self.tree_embs.STOP, dec_inputs=[dec_inputs[0]])]

        def length_norm_score(path):
            """Length-weighted logprob of a path (sort key for the beam)."""
            return path.logprob / (len(path) ** self.length_norm_weight)

        # beam search steps (at most as many as there are decoder input steps)
        for step in xrange(len(dec_inputs)):

            new_paths = []

            for path in paths:
                out_probs, st = self._beam_search_step(path.dec_inputs, path.dec_states)
                new_paths.extend(path.expand(self.beam_size, out_probs, st))

            # a key function computes each score once per path (a cmp function recomputes it
            # on every comparison, and `sorted(cmp=...)` no longer exists in Python 3)
            paths = sorted(new_paths, key=length_norm_score, reverse=True)[:self.beam_size]

            if all(p.dec_inputs[-1] == self.tree_embs.VOID for p in paths):
                break  # stop decoding if we have reached the end in all paths

            log_debug(("\nBEAM SEARCH STEP %d\n" % step) +
                      "\n".join([("%f\t" % p.logprob) +
                                 " ".join(self.tree_embs.ids_to_strings([inp[0] for inp in p.dec_inputs]))
                                 for p in paths]) + "\n")

        # rerank paths by their distance to the input DA
        if self.classif_filter or self.context_bleu_weight:
            paths = self._rerank_paths(paths, da)

        # measure slot error on the top k paths; compare against None explicitly so that the
        # accumulator is used regardless of its own truth value
        if self.slot_err_stats is not None:
            for path in paths[:self.sample_top_k]:
                self.slot_err_stats.append(
                        da, self.tree_embs.ids_to_strings([inp[0] for inp in path.dec_inputs]))

        # select the "best" path -- either the best, or one in top k
        if self.sample_top_k > 1:
            best_path = self._sample_path(paths[:self.sample_top_k])
        else:
            best_path = paths[0]

        # return just the best path (as token IDs)
        return np.array(best_path.dec_inputs)

    def _init_beam_search(self, enc_inputs):
        """Prepare beam search for the given encoder inputs; subclasses must override this."""
        raise NotImplementedError

    def _beam_search_step(self, dec_inputs, dec_states):
        """Run one decoder step for a beam search path; subclasses must override this.

        `_beam_search` expects a pair (output probabilities, decoder state) in return.
        """
        raise NotImplementedError

    def _rerank_paths(self, paths, da):
        """Adjust the logprobs of the n-best paths and return the paths re-sorted by them.

        Depending on the settings, each path's `logprob` is (in this order) increased by the
        weighted BLEU (or n-gram precision) of the path against the context `da[0]`,
        decreased by `misfit_penalty` times the reranking classifier's distance to the DA,
        and divided by the path length raised to `length_norm_weight`. The path objects are
        modified in place.

        @param paths: the decoded paths
        @param da: the input DA (its first element is used as the context for BLEU)
        @return: a new list of the paths, highest adjusted logprob first
        """
        trees = []
        for cand in paths:
            token_ids = np.array(cand.dec_inputs).transpose()[0]
            trees.append(self.tree_embs.ids_to_tree(token_ids))

        # bonus for n-gram overlap with the context (if set to do so)
        if self.context_bleu_weight:
            context = da[0]
            meter = BLEUMeasure(max_ngram=2)
            scores = []
            for cand, tree in zip(paths, trees):
                meter.reset()
                meter.append([(node.t_lemma, None) for node in tree.nodes[1:]], [context])
                if self.context_bleu_metric == 'ngram_prec':
                    score = meter.ngram_precision()
                else:
                    score = meter.bleu()
                scores.append(score)
                cand.logprob += self.context_bleu_weight * score

            bleu_header = "BLEU for context: %s\n\n" % " ".join([form for form, _ in context])
            bleu_lines = [("%.5f\t" % b) + " ".join([nd.t_lemma for nd in t.nodes[1:]])
                          for b, t in zip(scores, trees)]
            log_debug(bleu_header + "\n".join(bleu_lines))

        # penalty for not fitting the DA according to the reranking classifier
        if self.classif_filter:
            self.classif_filter.init_run(da)
            fits = self.classif_filter.dist_to_cur_da(trees)
            for cand, fit in zip(paths, fits):
                cand.logprob -= self.misfit_penalty * fit

            misfit_header = "Misfits for DA: %s\n\n" % str(da)
            misfit_lines = [("%.5f\t" % f) + " ".join([unicode(nd.t_lemma) for nd in t.nodes[1:]])
                            for f, t in zip(fits, trees)]
            log_debug(misfit_header + "\n".join(misfit_lines))

        # length normalization (if set to do so)
        if self.length_norm_weight:
            exponent = self.length_norm_weight
            for cand in paths:
                cand.logprob /= len(cand) ** exponent

        return sorted(paths, key=lambda cand: cand.logprob, reverse=True)

    def _sample_path(self, paths):
        """Randomly select one path, with probability proportional to exp(logprob).

        @param paths: non-empty list of candidate paths (e.g. the top k after reranking)
        @return: the selected path
        """
        logprobs = [p.logprob for p in paths]
        max_logprob = max(logprobs)
        if max_logprob == float('-inf'):
            # every path is impossible: exp(-inf - -inf) would be NaN and no path could ever be
            # hit by the draw, so fall back to a uniform choice
            weights = [1.0 for _ in logprobs]
        else:
            # discount by the maximum to avoid underflow (the result is unnormalized)
            weights = [math.exp(lp - max_logprob) for lp in logprobs]
        total = sum(weights)
        probs = [w / total for w in weights]  # normalized

        # select the path based on a draw from the uniform distribution
        draw = rnd.random()
        cum = 0.0  # building cumulative distribution function on-the-fly
        for idx, prob in enumerate(probs):
            high = cum + prob
            if cum <= draw < high:  # the draw has hit this index in the CDF
                return paths[idx]
            cum = high

        # rounding may leave the CDF total slightly below 1.0 -- the draw then belongs to the
        # last path
        return paths[-1]

    def generate_tree(self, da, gen_doc=None):
        """Generate a tree for one DA and optionally store the result in `gen_doc`.

        @param da: the input DA
        @param gen_doc: None (nothing is stored), a list (a list of (lemma, None) pairs is \
            appended, excluding the technical root), or a Pytreex document (a t-tree is created \
            in the target zone)
        @return: the generated tree
        """
        log_debug("GENERATE TREE FOR DA: " + unicode(da))
        tree = self.process_das([da])[0]
        log_debug("RESULT: %s" % unicode(tree))

        if isinstance(gen_doc, list):
            # generated tokens only, no syntax; POS tags are unknown, hence None
            gen_doc.append([(node.t_lemma, None) for node in tree.nodes[1:]])
        elif gen_doc:
            target = self.get_target_zone(gen_doc)
            target.ttree = tree.create_ttree()
            target.sentence = unicode(da)

        return tree

    def init_slot_err_stats(self):
        """Start a fresh slot error statistics accumulator (replacing any previous one)."""
        accumulator = SlotErrAnalyzer()
        self.slot_err_stats = accumulator

    def get_slot_err_stats(self):
        """Format the accumulated slot error statistics as a one-line summary string."""
        stats = self.slot_err_stats
        values = (stats.slot_error(), stats.missing, stats.superfluous, stats.total)
        return "Slot error: %.2f (M: %d, S: %d, T: %d)" % values

    @staticmethod
    def load_from_file(model_fname):
        """Load a model, dispatching to the right class's own `load_from_file`.

        If the first pickled object in the file is a class (as written e.g. by the ensemble's
        `save_to_file`), that class loads the model; otherwise the file is treated as a plain
        `Seq2SeqGen` model.
        NOTE: this unpickles the file -- only load model files from trusted sources.
        """
        with file_stream(model_fname, 'rb', encoding=None) as fh:
            first_item = pickle.load(fh)

        if isinstance(first_item, type):
            model_type = first_item
        else:
            model_type = Seq2SeqGen  # default to plain generator
        return model_type.load_from_file(model_fname)