    def set_up_decoder(self, nmt_specs):
        """This method sets up a list of NMT models and BeamSearch
        instances, one for each model in the ensemble. Note that we do
        not use the ``BeamSearch.search`` method for ensemble decoding
        directly.
        
        Args:
            nmt_specs (list): List of (nmt_model_path, nmt_config)
                              tuples, one entry per ensemble member
        """
        self.nmt_models = []
        self.beam_searches = []
        for nmt_model_path, nmt_config in nmt_specs:
            nmt_model = NMTModel(nmt_config)
            nmt_model.set_up()
            loader = LoadNMTUtils(nmt_model_path, nmt_config['saveto'],
                                  nmt_model.search_model)
            loader.load_weights()
            self.nmt_models.append(nmt_model)
            self.beam_searches.append(BeamSearch(samples=nmt_model.samples))
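
A minimal usage sketch for the ensemble variant above, assuming two hypothetical checkpoint/configuration pairs; the loop only unpacks ``(nmt_model_path, nmt_config)`` per entry, so any such iterable works:

# Hypothetical checkpoint paths and configs.
nmt_specs = [("models/sys1/params.npz", config1),
             ("models/sys2/params.npz", config2)]
decoder.set_up_decoder(nmt_specs)
# One NMTModel and one BeamSearch per ensemble member:
assert len(decoder.nmt_models) == len(decoder.beam_searches) == 2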
class BlocksNMTVanillaDecoder(Decoder):
    """Adaptor class for blocks.search.BeamSearch. We implement the
    ``Decoder`` class but ignore functionality for predictors or
    heuristics. Instead, we pass through decoding directly to the 
    blocks beam search module. This is fast, but breaks with the
    predictor framework. It can only be used for pure single system
    NMT decoding. Note that this decoder supports sparse feat maps
    on both source and target side.
    """
    def __init__(self, nmt_model_path, config, decoder_args):
        """Set up the NMT model used by the decoder.
        
        Args:
            nmt_model_path (string):  Path to the NMT model file (.npz)
            config (dict): NMT configuration
            decoder_args (object): Decoder configuration passed through
                                   from configuration API.
        """
        super(BlocksNMTVanillaDecoder, self).__init__(decoder_args)
        self.config = config
        self.set_up_decoder(nmt_model_path)
        self.src_eos = self.src_sparse_feat_map.word2dense(utils.EOS_ID)

    def set_up_decoder(self, nmt_model_path):
        """This method uses the NMT configuration in ``self.config`` to
        initialize the NMT model. This method basically corresponds to 
        ``blocks.machine_translation.main``.
        
        Args:
            nmt_model_path (string):  Path to the NMT model file (.npz)
        """
        self.nmt_model = NMTModel(self.config)
        self.nmt_model.set_up()
        loader = LoadNMTUtils(nmt_model_path, self.config['saveto'],
                              self.nmt_model.search_model)
        loader.load_weights()
        self.src_sparse_feat_map = self.config['src_sparse_feat_map'] \
                if self.config['src_sparse_feat_map'] else FlatSparseFeatMap()
        if self.config['trg_sparse_feat_map']:
            self.trg_sparse_feat_map = self.config['trg_sparse_feat_map']
            self.beam_search = SparseBeamSearch(
                samples=self.nmt_model.samples,
                trg_sparse_feat_map=self.trg_sparse_feat_map)
        else:
            self.trg_sparse_feat_map = FlatSparseFeatMap()
            self.beam_search = BeamSearch(samples=self.nmt_model.samples)

    def decode(self, src_sentence):
        """Decodes a single source sentence with the original blocks
        beam search decoder. Does not use predictors. Note that the
        score breakdowns in returned hypotheses are only on the 
        sentence level, not on the word level. For finer-grained NMT
        scores you need to use the nmt predictor. ``src_sentence`` is a
        list of source word ids representing the source sentence without
        <S> or </S> symbols. As blocks expects to see </S>, this method
        adds it automatically.
        
        Args:
            src_sentence (list): List of source word ids without <S> or
                                 </S> which make up the source sentence
        
        Returns:
            list. A list of ``Hypothesis`` instances ordered by their
            score.
        """
        seq = self.src_sparse_feat_map.words2dense(
            utils.oov_to_unk(src_sentence,
                             self.config['src_vocab_size'])) + [self.src_eos]
        if self.src_sparse_feat_map.dim > 1:  # sparse src feats
            input_ = np.transpose(
                np.tile(seq, (self.config['beam_size'], 1, 1)), (2, 0, 1))
        else:  # word ids on the source side
            input_ = np.tile(seq, (self.config['beam_size'], 1))
        trans, costs = self.beam_search.search(
            input_values={self.nmt_model.sampling_input: input_},
            max_length=3 * len(src_sentence),
            eol_symbol=utils.EOS_ID,
            ignore_first_eol=True)
        hypos = []
        max_len = 0
        for idx in xrange(len(trans)):
            max_len = max(max_len, len(trans[idx]))
            hypo = Hypothesis(trans[idx], -costs[idx])
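            # blocks' BeamSearch returns only a sentence-level cost, so
            # fill the breakdown with neutral (0.0, 1.0) entries and put
            # the full score on the first token.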
            hypo.score_breakdown = len(trans[idx]) * [[(0.0, 1.0)]]
            hypo.score_breakdown[0] = [(-costs[idx], 1.0)]
            hypos.append(hypo)
        self.apply_predictors_count = max_len * self.config['beam_size']
        return hypos

    def has_predictors(self):
        """Always returns true. """
        return True
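
A rough usage sketch for this adaptor; the checkpoint path, configuration, and word ids are hypothetical, while ``decode`` and the ``Hypothesis`` return type come from the class above:

decoder = BlocksNMTVanillaDecoder("models/params.npz", config, decoder_args)
hypos = decoder.decode([23, 7, 901])  # source word ids without <S>/</S>
# decode() returns Hypothesis instances ordered by score; the score
# breakdown carries the sentence-level cost on the first token only.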
def main(config,
         tr_stream,
         dev_stream,
         use_bokeh=False,
         slim_iteration_state=False,
         switch_controller=None,
         reset_epoch=False):
    """This method largely corresponds to the ``main`` method in the
    original Blocks implementation in blocks-examples and most of the
    code is copied from there. Following modifications have been made:
    
    - Support fixing word embedding during training
    - Dropout fix https://github.com/mila-udem/blocks-examples/issues/46
    - If necessary, add the exp3s extension
    
    Args:
        config (dict): NMT config
        tr_stream (DataStream): Training data stream
        dev_stream (DataStream): Validation data stream
        use_bokeh (bool): Whether to use bokeh for plotting
        slim_iteration_state (bool): Whether to store the full iteration
                                     state or only the epoch iterator
                                     without data stream state
        switch_controller (SourceSwitchController): Controlling strategy
                                                    if monolingual data
                                                    is used as well
        reset_epoch (bool): Set epoch_started in main loop status to
                            false. Sometimes required if you change
                            training parameters such as 
                            mono_data_integration
    """

    nmt_model = NMTModel(config)
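    # ``args`` below is assumed to be the module-level command line
    # argument object; it is not a parameter of this function.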
    nmt_model.set_up(make_prunable=(args.prune_every > 0))

    # Set extensions
    logging.info("Initializing extensions")
    extensions = [
        #        FinishAfter(after_n_batches=config['finish_after']),
        FinishAfter(after_n_epochs=config['finish_after']),
        #        TrainingDataMonitoring([nmt_model.cost], after_batch=True),
        TrainingDataMonitoring([nmt_model.cost], after_epoch=True),
        #        Printing(after_batch=True),
        Printing(after_epoch=True),
        CheckpointNMT(config['saveto'],
                      slim_iteration_state,
                      every_n_batches=config['save_freq'])
    ]

    # Add early stopping based on bleu
    if config['bleu_script'] is not None:
        logging.info("Building bleu validator")
        extensions.append(
            AccValidator(
                nmt_model.sampling_input,
                #            BleuValidator(nmt_model.sampling_input,
                samples=nmt_model.samples,
                config=config,
                model=nmt_model.search_model,
                data_stream=dev_stream,
                normalize=config['normalized_bleu'],
                store_full_main_loop=config['store_full_main_loop'],
                #                          every_n_batches=config['bleu_val_freq']))
                every_n_epochs=config['bleu_val_freq']))

    if switch_controller:
        switch_controller.beam_search = BeamSearch(samples=nmt_model.samples)
        switch_controller.src_sentence = nmt_model.sampling_input
        extensions.append(switch_controller)

    # Reload model if necessary
    if config['reload']:
        extensions.append(
            LoadNMT(config['saveto'], slim_iteration_state, reset_epoch))

    # Plot cost in bokeh if necessary
    if use_bokeh and BOKEH_AVAILABLE:
        extensions.append(
            Plot('Decoding cost',
                 channels=[['decoder_cost_cost']],
                 after_batch=True))

    # Add an extension for correct handling of SIGTERM and SIGINT
    extensions.append(AlwaysEpochInterrupt(every_n_batches=1))

    # Set up training algorithm
    logging.info("Initializing training algorithm")
    # https://github.com/mila-udem/blocks-examples/issues/46
    train_params = nmt_model.cg.parameters
    # fs439: fix embeddings?
    if config['fix_embeddings']:
        train_params = []
        embedding_params = [
            'softmax1', 'softmax0', 'maxout_bias', 'embeddings', 'lookuptable',
            'transform_feedback'
        ]
        for p in nmt_model.cg.parameters:
            add_param = True
            for ann in p.tag.annotations:
                if ann.name in embedding_params:
                    logging.info("Do not train %s: %s" % (p, ann))
                    add_param = False
                    break
            if add_param:
                train_params.append(p)
    # Change cost=cost to cg.outputs[0] ?
    cost_func = nmt_model.cg.outputs[0] if config['dropout'] < 1.0 \
                                        else nmt_model.cost
    if config['step_rule'] in ['AdaGrad', 'Adam']:
        step_rule = eval(config['step_rule'])(learning_rate=args.learning_rate)
    else:
        step_rule = eval(config['step_rule'])()
    step_rule = CompositeRule(
        [StepClipping(config['step_clipping']), step_rule])
    if args.prune_every < 1:
        algorithm = GradientDescent(cost=cost_func,
                                    parameters=train_params,
                                    step_rule=step_rule)
    else:
        algorithm = PruningGradientDescent(
            prune_layer_configs=args.prune_layers.split(','),
            prune_layout_path=args.prune_layout_path,
            prune_n_steps=args.prune_n_steps,
            prune_every=args.prune_every,
            prune_reset_every=args.prune_reset_every,
            nmt_model=nmt_model,
            cost=cost_func,
            parameters=train_params,
            step_rule=step_rule)

    # Initialize main loop
    logging.info("Initializing main loop")
    main_loop = MainLoop(model=nmt_model.training_model,
                         algorithm=algorithm,
                         data_stream=tr_stream,
                         extensions=extensions)

    # Reset epoch
    if reset_epoch:
        main_loop.status['epoch_started'] = False

    # Train!
    main_loop.run()
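
A hedged sketch of how ``main`` might be invoked; the config and stream helpers are placeholders, only the keyword arguments are taken from the signature above:

# get_config, get_tr_stream and get_dev_stream are hypothetical helpers
# standing in for the blocks-examples style configuration and streams.
config = get_config()
main(config,
     tr_stream=get_tr_stream(**config),
     dev_stream=get_dev_stream(**config),
     use_bokeh=False,
     slim_iteration_state=True,  # checkpoint only the epoch iterator
     reset_epoch=False)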
class BlocksNMTPredictor(Predictor):
    """This is the neural machine translation predictor. The predicted
    posteriors are equal to the distribution generated by the decoder
    network in NMT. This predictor heavily relies on the NMT example in
    blocks. Note that this predictor cannot be used in combination with
    a target side sparse feature map. See 
    ``BlocksUnboundedNMTPredictor`` for that case.
    """
    def __init__(self, nmt_model_path, gnmt_beta, enable_cache, config):
        """Creates a new NMT predictor.
        
        Args:
            nmt_model_path (string):  Path to the NMT model file (.npz)
            gnmt_beta (float): If greater than 0.0, add a Google NMT
                               style coverage penalization term (Wu et
                               al., 2016) to the predictive scores
            enable_cache (bool):  The NMT predictor usually has a very
                                  limited vocabulary size, and a large
                                  number of UNKs in hypotheses. This
                                  enables reusing already computed
                                  predictor states for hypotheses which
                                  differ only by NMT OOV words.
            config (dict): NMT configuration
        
        Note:
            A target sparse feature map in ``config`` is not supported;
            it is reported via ``logging.fatal`` and then ignored
        """
        super(BlocksNMTPredictor, self).__init__()
        self.gnmt_beta = gnmt_beta
        self.add_gnmt_coverage_term = gnmt_beta > 0.0
        self.config = copy.deepcopy(config)
        self.enable_cache = enable_cache
        self.set_up_predictor(nmt_model_path)
        self.src_eos = self.src_sparse_feat_map.word2dense(utils.EOS_ID)

    def set_up_predictor(self, nmt_model_path):
        """Initializes the predictor with the given NMT model. Code 
        following ``blocks.machine_translation.main``. 
        """
        self.src_vocab_size = self.config['src_vocab_size']
        self.trgt_vocab_size = self.config['trg_vocab_size']
        self.nmt_model = NMTModel(self.config)
        self.nmt_model.set_up()
        loader = LoadNMTUtils(nmt_model_path, self.config['saveto'],
                              self.nmt_model.search_model)
        loader.load_weights()

        self.best_models = []
        self.val_bleu_curve = []
        self.src_sparse_feat_map = self.config['src_sparse_feat_map'] \
                if self.config['src_sparse_feat_map'] else FlatSparseFeatMap()
        if self.config['trg_sparse_feat_map']:
            logging.fatal("Cannot use bounded vocabulary predictor with "
                          "a target sparse feature map. Ignoring...")
        self.search_algorithm = MyopticSearch(samples=self.nmt_model.samples)
        self.search_algorithm.compile()

    def initialize(self, src_sentence):
        """Runs the encoder network to create the source annotations
        for the source sentence. If the cache is enabled, empty the
        cache.
        
        Args:
            src_sentence (list): List of word ids without <S> and </S>
                                 which represent the source sentence.
        """

        self.reset()
        self.posterior_cache = SimpleTrie()
        self.states_cache = SimpleTrie()
        self.consumed = []
        seq = self.src_sparse_feat_map.words2dense(
            utils.oov_to_unk(src_sentence,
                             self.src_vocab_size)) + [self.src_eos]
        if self.src_sparse_feat_map.dim > 1:  # sparse src feats
            input_ = np.transpose(np.tile(seq, (1, 1, 1)), (2, 0, 1))
        else:  # word ids on the source side
            input_ = np.tile(seq, (1, 1))

        input_values = {self.nmt_model.sampling_input: input_}
        self.contexts, self.states, _ = self.search_algorithm.compute_initial_states_and_contexts(
            input_values)
        self.attention_records = (1 + len(src_sentence)) * [0.0]

    def is_history_cachable(self):
        """Returns true if cache is enabled and history contains UNK """
        if not self.enable_cache:
            return False
        for w in self.consumed:
            if w == utils.UNK_ID:
                return True
        return False

    def predict_next(self):
        """Uses cache or runs the decoder network to get the 
        distribution over the next target words.
        
        Returns:
            np array. Full distribution over the entire NMT vocabulary
            for the next target token.
        """
        use_cache = self.is_history_cachable()
        if use_cache:
            posterior = self.posterior_cache.get(self.consumed)
            if posterior is not None:
                logging.debug("Loaded NMT posterior from cache for %s" %
                              self.consumed)
                return self._add_gnmt_beta(posterior)
        # logprobs are negative log probs, i.e. greater than 0
        logprobs = self.search_algorithm.compute_logprobs(
            self.contexts, self.states)
        posterior = np.multiply(logprobs[0], -1.0)
        if use_cache:
            self.posterior_cache.add(self.consumed, posterior)
        return self._add_gnmt_beta(posterior)

    def _add_gnmt_beta(self, posterior):
        """Adds the GNMT coverage penalization term to EOS in 
        ``posterior``
        """
        if self.add_gnmt_coverage_term:
            posterior[utils.EOS_ID] += self.gnmt_beta * sum([
                np.log(max(0.0001, p))
                for p in self.attention_records if p < 1.0
            ])
        return posterior

    def get_unk_probability(self, posterior):
        """Returns the UNK probability defined by NMT. """
        return posterior[
            utils.UNK_ID] if len(posterior) > utils.UNK_ID else NEG_INF

    def consume(self, word):
        """Feeds back ``word`` to the decoder network. This includes 
        embedding of ``word``, running the attention network and update
        the recurrent decoder layer.
        """
        if word >= self.trgt_vocab_size:
            word = utils.UNK_ID
        self.consumed.append(word)
        use_cache = self.is_history_cachable()
        if use_cache:
            s = self.states_cache.get(self.consumed)
            if s is not None:
                logging.debug("Loaded NMT decoder states from cache for %s" %
                              self.consumed)
                self.states = copy.deepcopy(s)
                return
        self.states.update(
            self.search_algorithm.compute_next_states(self.contexts,
                                                      self.states, [word]))
        if use_cache:
            self.states_cache.add(self.consumed, copy.deepcopy(self.states))
        if self.add_gnmt_coverage_term:  # Keep track of attentions
            for pos, att in enumerate(self.states['weights'][0]):
                self.attention_records[pos] += att

    def get_state(self):
        """The NMT predictor state consists of the decoder network 
        state, and (for caching) the current history of consumed words
        """
        return self.states, self.consumed, self.attention_records

    def set_state(self, state):
        """Set the NMT predictor state. """
        self.states, self.consumed, self.attention_records = state

    def reset(self):
        """Deletes the source side annotations and decoder state. """
        self.contexts = None
        self.states = None

    def is_equal(self, state1, state2):
        """Returns true if the history is the same """
        _, consumed1, _ = state1
        _, consumed2, _ = state2
        return consumed1 == consumed2
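
The predictor is driven through the generic initialize/predict_next/consume cycle. A minimal greedy-decoding sketch under that assumption; the source word ids and length limit are hypothetical:

predictor = BlocksNMTPredictor("models/params.npz", 0.0, True, config)
predictor.initialize([23, 7, 901])        # source word ids
trg = []
for _ in range(50):                       # hard length limit
    posterior = predictor.predict_next()  # log probs over the NMT vocab
    word = int(np.argmax(posterior))      # greedy choice
    if word == utils.EOS_ID:
        break
    predictor.consume(word)
    trg.append(word)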
class BlocksUnboundedNMTPredictor(BlocksNMTPredictor,
                                  UnboundedVocabularyPredictor):
    """This is a version of the NMT predictor which assumes an 
    unbounded vocabulary. Therefore, this predictor can only be used 
    when other predictors (like fst) define the words to score. Using
    this predictor is mandatory when a target sparse feature map is
    provided.
    """
    def __init__(self, nmt_model_path, gnmt_beta, config):
        """Creates a new NMT predictor with unbounded vocabulary.
        
        Args:
            nmt_model_path (string):  Path to the NMT model file (.npz)
            gnmt_beta (float): If greater than 0.0, add a Google NMT
                               style coverage penalization term
            config (dict): NMT configuration
        """
        super(BlocksUnboundedNMTPredictor,
              self).__init__(nmt_model_path, gnmt_beta, False, config)

    def set_up_predictor(self, nmt_model_path):
        """Initializes the predictor with the given NMT model. Code 
        following ``blocks.machine_translation.main``. 
        """
        self.src_vocab_size = self.config['src_vocab_size']
        self.trgt_vocab_size = self.config['trg_vocab_size']
        self.nmt_model = NMTModel(self.config)
        self.nmt_model.set_up()
        loader = LoadNMTUtils(nmt_model_path, self.config['saveto'],
                              self.nmt_model.search_model)
        loader.load_weights()

        self.best_models = []
        self.val_bleu_curve = []
        self.src_sparse_feat_map = self.config['src_sparse_feat_map'] \
                if self.config['src_sparse_feat_map'] else FlatSparseFeatMap()
        if self.config['trg_sparse_feat_map']:
            self.trg_sparse_feat_map = self.config['trg_sparse_feat_map']
            self.search_algorithm = MyopticSparseSearch(
                samples=self.nmt_model.samples,
                trg_sparse_feat_map=self.trg_sparse_feat_map)
        else:
            self.trg_sparse_feat_map = FlatSparseFeatMap()
            self.search_algorithm = MyopticSearch(
                samples=self.nmt_model.samples)
        self.search_algorithm.compile()

    def predict_next(self, words):
        """Uses cache or runs the decoder network to get the 
        distribution over the next target words.
        
        Returns:
            np array. Full distribution over the entire NMT vocabulary
            for the next target token.
        """
        logprobs = self.search_algorithm.compute_logprobs(
            self.contexts, self.states)
        if self.trg_sparse_feat_map.dim > 1:
            return {
                w: -sparse.dense_euclidean2(
                    logprobs[0], self.trg_sparse_feat_map.word2dense(w))
                for w in words
            }
        else:
            # logprobs are negative log probs, i.e. greater than 0
            posterior = np.multiply(logprobs[0], -1.0)
            return {w: posterior[w] for w in words}

    def get_unk_probability(self, posterior):
        """Returns negative inf as this is a unbounded predictor. """
        return NEG_INF

    def consume(self, word):
        """Feeds back ``word`` to the decoder network. This includes 
        embedding of ``word``, running the attention network and update
        the recurrent decoder layer.
        """
        if word >= self.trgt_vocab_size:
            word = utils.UNK_ID
        self.consumed.append(word)
        self.states.update(
            self.search_algorithm.compute_next_states(
                self.contexts, self.states,
                [self.trg_sparse_feat_map.word2dense(word)]))

    def is_equal(self, state1, state2):
        """Returns true if the history is the same """
        _, consumed1, _ = state1
        _, consumed2, _ = state2
        return consumed1 == consumed2
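
Since the unbounded variant only scores an externally supplied candidate set, ``predict_next`` takes the words to score. A short sketch; the candidate ids are hypothetical:

predictor = BlocksUnboundedNMTPredictor("models/params.npz", 0.0, config)
predictor.initialize([23, 7, 901])
# Score only candidates proposed by another predictor, e.g. an fst:
scores = predictor.predict_next([5, 17, 42, utils.EOS_ID])
best = max(scores, key=scores.get)
predictor.consume(best)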