Example #1
def create_decoder():
    """Creates the ``Decoder`` instance. This specifies the search 
    strategy used to traverse the space spanned by the predictors. This
    method relies on the global ``args`` variable.
    
    TODO: Refactor to avoid long argument lists
    
    Returns:
        Decoder. Instance of the search strategy
    """

    # Create decoder instance and add predictors
    if args.decoder == "greedy":
        decoder = GreedyDecoder(args)
    elif args.decoder == "beam":
        decoder = BeamDecoder(args, args.hypo_recombination, args.beam,
                              args.pure_heuristic_scores,
                              args.decoder_diversity_factor,
                              args.early_stopping)
    elif args.decoder == "multisegbeam":
        decoder = MultisegBeamDecoder(args, args.hypo_recombination, args.beam,
                                      args.multiseg_tokenizations,
                                      args.early_stopping, args.max_word_len)
    elif args.decoder == "syncbeam":
        decoder = SyncBeamDecoder(args, args.hypo_recombination, args.beam,
                                  args.pure_heuristic_scores,
                                  args.decoder_diversity_factor,
                                  args.early_stopping, args.sync_symbol,
                                  args.max_word_len)
    elif args.decoder == "dfs":
        decoder = DFSDecoder(args, args.early_stopping,
                             args.max_node_expansions)
    elif args.decoder == "restarting":
        decoder = RestartingDecoder(args, args.hypo_recombination,
                                    args.max_node_expansions,
                                    args.low_decoder_memory,
                                    args.restarting_node_score,
                                    args.stochastic_decoder,
                                    args.decode_always_single_step)
    elif args.decoder == "bow":
        decoder = BOWDecoder(args, args.hypo_recombination,
                             args.max_node_expansions, args.stochastic_decoder,
                             args.early_stopping,
                             args.decode_always_single_step)
    elif args.decoder == "flip":
        decoder = FlipDecoder(args, args.trg_test, args.max_node_expansions,
                              args.early_stopping, args.flip_strategy)
    elif args.decoder == "bigramgreedy":
        decoder = BigramGreedyDecoder(args, args.trg_test,
                                      args.max_node_expansions,
                                      args.early_stopping)
    elif args.decoder == "bucket":
        decoder = BucketDecoder(
            args, args.hypo_recombination, args.max_node_expansions,
            args.low_decoder_memory, args.beam, args.pure_heuristic_scores,
            args.decoder_diversity_factor, args.early_stopping,
            args.stochastic_decoder, args.bucket_selector,
            args.bucket_score_strategy, args.collect_statistics)
    elif args.decoder == "astar":
        decoder = AstarDecoder(args, args.beam, args.pure_heuristic_scores,
                               args.early_stopping, max(1, args.nbest))
    elif args.decoder == "vanilla":
        decoder = get_nmt_vanilla_decoder(
            args, args.nmt_path,
            _parse_config_param("nmt_config", get_default_nmt_config()))
        args.predictors = "vanilla"
    else:
        logging.fatal("Decoder %s not available. Please double-check the "
                      "--decoder parameter." % args.decoder)
    add_predictors(decoder)

    # Add heuristics for search strategies like A*
    if args.heuristics:
        add_heuristics(decoder)

    # Update start sentence id if necessary
    if args.range:
        idx, _ = (args.range.split(":") if ":" in args.range
                  else (args.range, 0))
        # -1 because sentence indices start at 1
        decoder.set_start_sen_id(int(idx) - 1)
    return decoder
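
A minimal usage sketch (not part of sgnmt): ``create_decoder`` reads the module-level ``args``, so the caller normally populates it from the command line before calling it. The ``SimpleNamespace`` and the attribute values below are illustrative assumptions about the flag set; the real values come from sgnmt's argument parser, and ``add_predictors`` will require further attributes.

# Hypothetical driver code, assuming it runs in the module that owns ``args``
# and that the attribute names match the CLI flags referenced above.
from types import SimpleNamespace

args = SimpleNamespace(
    decoder="beam",                # selects the BeamDecoder branch
    hypo_recombination=False,
    beam=12,
    pure_heuristic_scores=False,
    decoder_diversity_factor=-1.0,
    early_stopping=True,
    heuristics="",                 # empty -> add_heuristics() is skipped
    range="5:20",                  # start decoding at sentence 5 (1-based)
)

decoder = create_decoder()         # BeamDecoder with predictors attached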
Example #2
File: bow.py  Project: ucam-smt/sgnmt
 def __init__(self,
              main_decoder,
              hypo_recombination,
              trg_test_file, 
              accept_subsets=False,
              accept_duplicates=False,
              heuristic_scores_file="",
              collect_stats_strategy='best',
              heuristic_add_consumed = False, 
              heuristic_add_remaining = True,
              diversity_heuristic_factor = -1.0,
              equivalence_vocab=-1):
     """Creates a new bag-of-words predictor with pre search
     
     Args:
         main_decoder (Decoder): Reference to the main decoder
                                 instance, used to fetch the predictors
         hypo_recombination (bool): Activates hypo recombination for the
                                    pre decoder 
         trg_test_file (string): Path to the plain text file with 
                                 the target sentences. Must have the
                                 same number of lines as the number
                                 of source sentences to decode. The 
                                 word order in the target sentences
                                 is not relevant for this predictor.
         accept_subsets (bool): If true, this predictor permits
                                    EOS even if the bag is not fully
                                    consumed yet
         accept_duplicates (bool): If true, counts are not updated
                                   when a word is consumed. This
                                   means that we allow a word in a
                                   bag to appear multiple times
         heuristic_scores_file (string): Path to the unigram scores 
                                         which are used if this 
                                         predictor estimates future
                                         costs
         collect_stats_strategy (string): best, full, or all. Defines 
                                          how unigram estimates are 
                                          collected for heuristic 
         heuristic_add_consumed (bool): Set to true to add the 
                                        difference between actual
                                        partial score and unigram
                                        estimates of consumed words
                                        to the predictor heuristic
         heuristic_add_remaining (bool): Set to true to add the sum
                                         of unigram scores of words
                                         remaining in the bag to the
                                         predictor heuristic
         equivalence_vocab (int): If positive, predictor states are
                                  considered equal if the
                                  remaining words within that vocab
                                  and OOVs regarding this vocab are
                                  the same. Only relevant when using
                                  hypothesis recombination
     """
     self.main_decoder = main_decoder
     self.pre_decoder = BeamDecoder(CLOSED_VOCAB_SCORE_NORM_NONE,
                                    main_decoder.max_len_factor,
                                    hypo_recombination,
                                    10)
     self.pre_decoder.combine_posteriors = main_decoder.combine_posteriors 
     super(BagOfWordsSearchPredictor, self).__init__(trg_test_file, 
                                                     accept_subsets,
                                                     accept_duplicates,
                                                     heuristic_scores_file,
                                                     collect_stats_strategy,
                                                     heuristic_add_consumed, 
                                                     heuristic_add_remaining,
                                                     diversity_heuristic_factor,
                                                     equivalence_vocab)
     self.pre_mode = False
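
A hedged instantiation sketch: only the main decoder, the recombination flag, and the path to the target-side bags are required; everything else falls back to the defaults documented above. ``my_decoder`` and the file path are placeholders, not values taken from an sgnmt configuration.

# Illustrative only; assumes ``my_decoder`` is an already configured Decoder
# whose predictors the pre decoder can reuse.
predictor = BagOfWordsSearchPredictor(
    main_decoder=my_decoder,
    hypo_recombination=True,            # recombine hypotheses in the pre decoder
    trg_test_file="data/trg.bags.txt",  # placeholder path, one bag per line
    accept_subsets=True,                # allow EOS before the bag is exhausted
    heuristic_scores_file="",           # no future-cost estimates
)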
Example #3
File: bow.py  Project: ucam-smt/sgnmt
class BagOfWordsSearchPredictor(BagOfWordsPredictor):
    """Combines the bag-of-words predictor with a proxy decoding pass
    which creates a skeleton translation.
    """
    
    def __init__(self,
                 main_decoder,
                 hypo_recombination,
                 trg_test_file, 
                 accept_subsets=False,
                 accept_duplicates=False,
                 heuristic_scores_file="",
                 collect_stats_strategy='best',
                 heuristic_add_consumed = False, 
                 heuristic_add_remaining = True,
                 diversity_heuristic_factor = -1.0,
                 equivalence_vocab=-1):
        """Creates a new bag-of-words predictor with pre search
        
        Args:
            main_decoder (Decoder): Reference to the main decoder
                                    instance, used to fetch the predictors
            hypo_recombination (bool): Activates hypo recombination for the
                                       pre decoder 
            trg_test_file (string): Path to the plain text file with 
                                    the target sentences. Must have the
                                    same number of lines as the number
                                    of source sentences to decode. The 
                                    word order in the target sentences
                                    is not relevant for this predictor.
            accept_subsets (bool): If true, this predictor permits
                                       EOS even if the bag is not fully
                                       consumed yet
            accept_duplicates (bool): If true, counts are not updated
                                      when a word is consumed. This
                                      means that we allow a word in a
                                      bag to appear multiple times
            heuristic_scores_file (string): Path to the unigram scores 
                                            which are used if this 
                                            predictor estimates future
                                            costs
            collect_stats_strategy (string): best, full, or all. Defines 
                                             how unigram estimates are 
                                             collected for heuristic 
            heuristic_add_consumed (bool): Set to true to add the 
                                           difference between actual
                                           partial score and unigram
                                           estimates of consumed words
                                           to the predictor heuristic
            heuristic_add_remaining (bool): Set to true to add the sum
                                            of unigram scores of words
                                            remaining in the bag to the
                                            predictor heuristic
            equivalence_vocab (int): If positive, predictor states are
                                     considered equal if the
                                     remaining words within that vocab
                                     and OOVs regarding this vocab are
                                     the same. Only relevant when using
                                     hypothesis recombination
        """
        self.main_decoder = main_decoder
        self.pre_decoder = BeamDecoder(CLOSED_VOCAB_SCORE_NORM_NONE,
                                       main_decoder.max_len_factor,
                                       hypo_recombination,
                                       10)
        self.pre_decoder.combine_posteriors = main_decoder.combine_posteriors 
        super(BagOfWordsSearchPredictor, self).__init__(trg_test_file, 
                                                        accept_subsets,
                                                        accept_duplicates,
                                                        heuristic_scores_file,
                                                        collect_stats_strategy,
                                                        heuristic_add_consumed, 
                                                        heuristic_add_remaining,
                                                        diversity_heuristic_factor,
                                                        equivalence_vocab)
        self.pre_mode = False
    
    def predict_next(self):
        """If in ``pre_mode``, pass through to super class. Otherwise,
        scan skeleton 
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).predict_next()
        if not self.bag: # Empty bag
            return {utils.EOS_ID : 0.0}
        ret = {w : 0.0 for w in self.missing.iterkeys()}
        if self.accept_subsets:
            ret[utils.EOS_ID] = 0.0
        if self.skeleton_pos < len(self.skeleton):
            ret[self.skeleton[self.skeleton_pos]] = 0.0
        return ret
    
    def initialize(self, src_sentence):
        """If in ``pre_mode``, pass through to super class. Otherwise,
        initialize skeleton. 
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).initialize(src_sentence)
        self.pre_mode = True
        old_accept_subsets = self.accept_subsets
        old_accept_duplicates = self.accept_duplicates
        self.accept_subsets = True
        self.accept_duplicates = True
        self.pre_decoder.predictors = self.main_decoder.predictors
        self.pre_decoder.current_sen_id = self.main_decoder.current_sen_id - 1
        hypos = self.pre_decoder.decode(src_sentence)
        score = INF
        if not hypos:
            logging.warn("No hypothesis found by the pre decoder. Effectively "
                         "reducing bowsearch predictor to bow predictor.")
            self.skeleton = []
        else:
            self.skeleton = hypos[0].trgt_sentence
            score = hypos[0].total_score
            if self.skeleton and self.skeleton[-1] == utils.EOS_ID:
                self.skeleton = self.skeleton[:-1] # Remove EOS
        self.skeleton_pos = 0
        self.accept_subsets = old_accept_subsets
        self.accept_duplicates = old_accept_duplicates
        self._set_up_full_mode()
        logging.debug("BOW Skeleton (score=%f missing=%d): %s" % (
                                          score,
                                          sum(self.missing.values()),
                                          self.skeleton))
        self.main_decoder.current_sen_id -= 1
        self.main_decoder.initialize_predictors(src_sentence)
        self.pre_mode = False
    
    def _set_up_full_mode(self):
        """This method initializes ``missing`` by using
        ``self.skeleton`` and ``self.full_bag`` and removes
        duplicates from ``self.skeleton``.
        """
        self.bag = dict(self.full_bag)
        missing = dict(self.full_bag)
        skeleton_no_duplicates = []
        for word in self.skeleton:
            if missing[word] > 0:
                missing[word] -= 1
                skeleton_no_duplicates.append(word)
        self.skeleton = skeleton_no_duplicates
        self.missing = {w: cnt for w, cnt in missing.iteritems() if cnt > 0}
        
    def consume(self, word):
        """Calls super class ``consume``. If not in ``pre_mode``,
        update skeleton info. 
        
        Args:
            word (int): Next word to consume
        """
        super(BagOfWordsSearchPredictor, self).consume(word)
        if self.pre_mode:
            return
        if (self.skeleton_pos < len(self.skeleton) 
                 and word == self.skeleton[self.skeleton_pos]):
            self.skeleton_pos += 1
        elif word in self.missing:
            self.missing[word] -= 1
            if self.missing[word] <= 0:
                del self.missing[word]
    
    def get_state(self):
        """If in pre_mode, state of this predictor is the current bag
        Otherwise, its the bag plus skeleton state
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).get_state()
        return self.bag, self.skeleton_pos, self.missing
    
    def set_state(self, state):
        """If in pre_mode, state of this predictor is the current bag
        Otherwise, its the bag plus skeleton state
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).set_state(state)
        self.bag, self.skeleton_pos, self.missing = state
    
    def is_equal(self, state1, state2):
        """Returns true if the bag and the skeleton states are the same
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).is_equal(state1, 
                                                                   state2)
        return super(BagOfWordsSearchPredictor, self).is_equal(state1[0], 
                                                               state2[0])
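
To make the skeleton bookkeeping in ``consume`` concrete, here is a small standalone sketch of the same logic on plain Python data structures. The word IDs are arbitrary integers chosen for illustration, not entries from any sgnmt vocabulary.

# Toy re-implementation of the branches in consume(): a consumed word either
# advances the skeleton pointer or is crossed off the "missing" multiset.
skeleton = [7, 3, 9]           # proxy translation with duplicates removed
skeleton_pos = 0
missing = {4: 2, 11: 1}        # words still owed to the bag (word id -> count)

def toy_consume(word):
    global skeleton_pos
    if skeleton_pos < len(skeleton) and word == skeleton[skeleton_pos]:
        skeleton_pos += 1          # word continues the skeleton
    elif word in missing:
        missing[word] -= 1         # word is taken from the leftover bag
        if missing[word] <= 0:
            del missing[word]

for w in [7, 4, 3, 4]:
    toy_consume(w)

print(skeleton_pos, missing)       # -> 2 {11: 1}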
Example #4
def create_decoder():
    """Creates the ``Decoder`` instance. This specifies the search 
    strategy used to traverse the space spanned by the predictors. This
    method relies on the global ``args`` variable.
    
    TODO: Refactor to avoid long argument lists
    
    Returns:
        Decoder. Instance of the search strategy
    """
    # Create decoder instance and add predictors
    decoder = None
    try:
        if args.decoder == "greedy":
            decoder = GreedyDecoder(args)
        elif args.decoder == "beam":
            decoder = BeamDecoder(args)
        elif args.decoder == "multisegbeam":
            decoder = MultisegBeamDecoder(args,
                                          args.hypo_recombination,
                                          args.beam,
                                          args.multiseg_tokenizations,
                                          args.early_stopping,
                                          args.max_word_len)
        elif args.decoder == "syncbeam":
            decoder = SyncBeamDecoder(args)
        elif args.decoder == "mbrbeam":
            decoder = MBRBeamDecoder(args)
        elif args.decoder == "sepbeam":
            decoder = SepBeamDecoder(args)
        elif args.decoder == "syntaxbeam":
            decoder = SyntaxBeamDecoder(args)
        elif args.decoder == "combibeam":
            decoder = CombiBeamDecoder(args)
        elif args.decoder == "dfs":
            decoder = DFSDecoder(args)
        elif args.decoder == "restarting":
            decoder = RestartingDecoder(args,
                                        args.hypo_recombination,
                                        args.max_node_expansions,
                                        args.low_decoder_memory,
                                        args.restarting_node_score,
                                        args.stochastic_decoder,
                                        args.decode_always_single_step)
        elif args.decoder == "bow":
            decoder = BOWDecoder(args)
        elif args.decoder == "flip":
            decoder = FlipDecoder(args)
        elif args.decoder == "bigramgreedy":
            decoder = BigramGreedyDecoder(args)
        elif args.decoder == "bucket":
            decoder = BucketDecoder(args,
                                    args.hypo_recombination,
                                    args.max_node_expansions,
                                    args.low_decoder_memory,
                                    args.beam,
                                    args.pure_heuristic_scores,
                                    args.decoder_diversity_factor,
                                    args.early_stopping,
                                    args.stochastic_decoder,
                                    args.bucket_selector,
                                    args.bucket_score_strategy,
                                    args.collect_statistics)
        elif args.decoder == "astar":
            decoder = AstarDecoder(args)
        elif args.decoder == "vanilla":
            decoder = construct_nmt_vanilla_decoder()
            args.predictors = "vanilla"
        else:
            logging.fatal("Decoder %s not available. Please double-check the "
                          "--decoder parameter." % args.decoder)
    except Exception as e:
        logging.fatal("An %s has occurred while initializing the decoder: %s"
                      " Stack trace: %s" % (sys.exc_info()[0],
                                            e,
                                            traceback.format_exc()))
    if decoder is None:
        sys.exit("Could not initialize decoder.")
    add_predictors(decoder)
    # Add heuristics for search strategies like A*
    if args.heuristics:
        add_heuristics(decoder)
    return decoder
Example #5
 def __init__(self,
              main_decoder,
              hypo_recombination,
              trg_test_file,
              accept_subsets=False,
              accept_duplicates=False,
              heuristic_scores_file="",
              collect_stats_strategy='best',
              heuristic_add_consumed=False,
              heuristic_add_remaining=True,
              diversity_heuristic_factor=-1.0,
              equivalence_vocab=-1):
     """Creates a new bag-of-words predictor with pre search
     
     Args:
         main_decoder (Decoder): Reference to the main decoder
                                 instance, used to fetch the predictors
         hypo_recombination (bool): Activates hypo recombination for the
                                    pre decoder 
         trg_test_file (string): Path to the plain text file with 
                                 the target sentences. Must have the
                                 same number of lines as the number
                                 of source sentences to decode. The 
                                 word order in the target sentences
                                 is not relevant for this predictor.
         accept_subsets (bool): If true, this predictor permits
                                    EOS even if the bag is not fully
                                    consumed yet
         accept_duplicates (bool): If true, counts are not updated
                                   when a word is consumed. This
                                   means that we allow a word in a
                                   bag to appear multiple times
         heuristic_scores_file (string): Path to the unigram scores 
                                         which are used if this 
                                         predictor estimates future
                                         costs
         collect_stats_strategy (string): best, full, or all. Defines 
                                          how unigram estimates are 
                                          collected for heuristic 
         heuristic_add_consumed (bool): Set to true to add the 
                                        difference between actual
                                        partial score and unigram
                                        estimates of consumed words
                                        to the predictor heuristic
         heuristic_add_remaining (bool): Set to true to add the sum
                                         of unigram scores of words
                                         remaining in the bag to the
                                         predictor heuristic
         equivalence_vocab (int): If positive, predictor states are
                                  considered equal if the
                                  remaining words within that vocab
                                  and OOVs regarding this vocab are
                                  the same. Only relevant when using
                                  hypothesis recombination
     """
     self.main_decoder = main_decoder
     self.pre_decoder = BeamDecoder(CLOSED_VOCAB_SCORE_NORM_NONE,
                                    main_decoder.max_len_factor,
                                    hypo_recombination, 10)
     self.pre_decoder.combine_posteriors = main_decoder.combine_posteriors
     super(BagOfWordsSearchPredictor,
           self).__init__(trg_test_file, accept_subsets, accept_duplicates,
                          heuristic_scores_file, collect_stats_strategy,
                          heuristic_add_consumed, heuristic_add_remaining,
                          diversity_heuristic_factor, equivalence_vocab)
     self.pre_mode = False
Example #6
class BagOfWordsSearchPredictor(BagOfWordsPredictor):
    """Combines the bag-of-words predictor with a proxy decoding pass
    which creates a skeleton translation.
    """
    def __init__(self,
                 main_decoder,
                 hypo_recombination,
                 trg_test_file,
                 accept_subsets=False,
                 accept_duplicates=False,
                 heuristic_scores_file="",
                 collect_stats_strategy='best',
                 heuristic_add_consumed=False,
                 heuristic_add_remaining=True,
                 diversity_heuristic_factor=-1.0,
                 equivalence_vocab=-1):
        """Creates a new bag-of-words predictor with pre search
        
        Args:
            main_decoder (Decoder): Reference to the main decoder
                                    instance, used to fetch the predictors
            hypo_recombination (bool): Activates hypo recombination for the
                                       pre decoder 
            trg_test_file (string): Path to the plain text file with 
                                    the target sentences. Must have the
                                    same number of lines as the number
                                    of source sentences to decode. The 
                                    word order in the target sentences
                                    is not relevant for this predictor.
            accept_subsets (bool): If true, this predictor permits
                                       EOS even if the bag is not fully
                                       consumed yet
            accept_duplicates (bool): If true, counts are not updated
                                      when a word is consumed. This
                                      means that we allow a word in a
                                      bag to appear multiple times
            heuristic_scores_file (string): Path to the unigram scores 
                                            which are used if this 
                                            predictor estimates future
                                            costs
            collect_stats_strategy (string): best, full, or all. Defines 
                                             how unigram estimates are 
                                             collected for heuristic 
            heuristic_add_consumed (bool): Set to true to add the 
                                           difference between actual
                                           partial score and unigram
                                           estimates of consumed words
                                           to the predictor heuristic
            heuristic_add_remaining (bool): Set to true to add the sum
                                            of unigram scores of words
                                            remaining in the bag to the
                                            predictor heuristic
            equivalence_vocab (int): If positive, predictor states are
                                     considered equal if the
                                     remaining words within that vocab
                                     and OOVs regarding this vocab are
                                     the same. Only relevant when using
                                     hypothesis recombination
        """
        self.main_decoder = main_decoder
        self.pre_decoder = BeamDecoder(CLOSED_VOCAB_SCORE_NORM_NONE,
                                       main_decoder.max_len_factor,
                                       hypo_recombination, 10)
        self.pre_decoder.combine_posteriors = main_decoder.combine_posteriors
        super(BagOfWordsSearchPredictor,
              self).__init__(trg_test_file, accept_subsets, accept_duplicates,
                             heuristic_scores_file, collect_stats_strategy,
                             heuristic_add_consumed, heuristic_add_remaining,
                             diversity_heuristic_factor, equivalence_vocab)
        self.pre_mode = False

    def predict_next(self):
        """If in ``pre_mode``, pass through to super class. Otherwise,
        scan skeleton 
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).predict_next()
        if not self.bag:  # Empty bag
            return {utils.EOS_ID: 0.0}
        ret = {w: 0.0 for w in self.missing.iterkeys()}
        if self.accept_subsets:
            ret[utils.EOS_ID] = 0.0
        if self.skeleton_pos < len(self.skeleton):
            ret[self.skeleton[self.skeleton_pos]] = 0.0
        return ret

    def initialize(self, src_sentence):
        """If in ``pre_mode``, pass through to super class. Otherwise,
        initialize skeleton. 
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor,
                         self).initialize(src_sentence)
        self.pre_mode = True
        old_accept_subsets = self.accept_subsets
        old_accept_duplicates = self.accept_duplicates
        self.accept_subsets = True
        self.accept_duplicates = True
        self.pre_decoder.predictors = self.main_decoder.predictors
        self.pre_decoder.current_sen_id = self.main_decoder.current_sen_id - 1
        hypos = self.pre_decoder.decode(src_sentence)
        score = INF
        if not hypos:
            logging.warn("No hypothesis found by the pre decoder. Effectively "
                         "reducing bowsearch predictor to bow predictor.")
            self.skeleton = []
        else:
            self.skeleton = hypos[0].trgt_sentence
            score = hypos[0].total_score
            if self.skeleton and self.skeleton[-1] == utils.EOS_ID:
                self.skeleton = self.skeleton[:-1]  # Remove EOS
        self.skeleton_pos = 0
        self.accept_subsets = old_accept_subsets
        self.accept_duplicates = old_accept_duplicates
        self._set_up_full_mode()
        logging.debug("BOW Skeleton (score=%f missing=%d): %s" %
                      (score, sum(self.missing.values()), self.skeleton))
        self.main_decoder.current_sen_id -= 1
        self.main_decoder.initialize_predictors(src_sentence)
        self.pre_mode = False

    def _set_up_full_mode(self):
        """This method initializes ``missing`` by using
        ``self.skeleton`` and ``self.full_bag`` and removes
        duplicates from ``self.skeleton``.
        """
        self.bag = dict(self.full_bag)
        missing = dict(self.full_bag)
        skeleton_no_duplicates = []
        for word in self.skeleton:
            if missing[word] > 0:
                missing[word] -= 1
                skeleton_no_duplicates.append(word)
        self.skeleton = skeleton_no_duplicates
        self.missing = {w: cnt for w, cnt in missing.iteritems() if cnt > 0}

    def consume(self, word):
        """Calls super class ``consume``. If not in ``pre_mode``,
        update skeleton info. 
        
        Args:
            word (int): Next word to consume
        """
        super(BagOfWordsSearchPredictor, self).consume(word)
        if self.pre_mode:
            return
        if (self.skeleton_pos < len(self.skeleton)
                and word == self.skeleton[self.skeleton_pos]):
            self.skeleton_pos += 1
        elif word in self.missing:
            self.missing[word] -= 1
            if self.missing[word] <= 0:
                del self.missing[word]

    def get_state(self):
        """If in pre_mode, state of this predictor is the current bag
        Otherwise, its the bag plus skeleton state
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).get_state()
        return self.bag, self.skeleton_pos, self.missing

    def set_state(self, state):
        """If in pre_mode, state of this predictor is the current bag
        Otherwise, its the bag plus skeleton state
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).set_state(state)
        self.bag, self.skeleton_pos, self.missing = state

    def is_equal(self, state1, state2):
        """Returns true if the bag and the skeleton states are the same
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor,
                         self).is_equal(state1, state2)
        return super(BagOfWordsSearchPredictor,
                     self).is_equal(state1[0], state2[0])
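
A short note on the ``is_equal`` convention above: in full (non-pre) mode the state returned by ``get_state`` is the triple ``(bag, skeleton_pos, missing)``, but only the first component is handed to the parent class for comparison, so the skeleton bookkeeping never blocks hypothesis recombination. The sketch below uses plain dict equality as a stand-in for ``BagOfWordsPredictor.is_equal``, which is an assumption about the parent implementation.

# Two illustrative full-mode states with identical bags but different
# skeleton positions; is_equal() above compares only state[0].
state_a = ({5: 1, 8: 2}, 3, {8: 1})   # (bag, skeleton_pos, missing)
state_b = ({5: 1, 8: 2}, 1, {8: 1})

print(state_a[0] == state_b[0])       # -> True: treated as recombinable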