Пример #1
0
    def __init__(self, src_vocab_size, trg_vocab_size, model_name,
                 hparams_set_name, checkpoint_dir, single_cpu_thread,
                 nizza_unk_id=None):
        """Initializes a nizza predictor.

        Builds the nizza model in its own TensorFlow graph, exposes the
        ``sgnmt_inputs`` / ``sgnmt_targets`` placeholders, the
        ``precomputed`` tensors and the next-word ``log_probs`` tensor, and
        opens a session via ``self.create_session``.

        Args:
            src_vocab_size (int): Source vocabulary size (called
                inputs_vocab_size in nizza)
            trg_vocab_size (int): Target vocabulary size (called
                targets_vocab_size in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory.
                It must contain a ``checkpoint`` index file; the predictor
                will load the top most checkpoint listed in that file.
            single_cpu_thread (bool): If true, prevent tensorflow from
                doing multithreading.
            nizza_unk_id (int): If set, use this as UNK id. Otherwise, the
                nizza is assumed to have no UNKs

        Raises:
            IOError: If ``<checkpoint_dir>/checkpoint`` does not exist.
        """
        super(BaseNizzaPredictor, self).__init__()
        checkpoint_file = "%s/checkpoint" % checkpoint_dir
        if not os.path.isfile(checkpoint_file):
            logging.fatal("Checkpoint file %s not found!" % checkpoint_file)
            # Carry the path in the exception itself so the failure stays
            # diagnosable even when the log output is not visible.
            raise IOError("Checkpoint file %s not found!" % checkpoint_file)
        self._single_cpu_thread = single_cpu_thread
        self._checkpoint_dir = checkpoint_dir
        self._nizza_unk_id = nizza_unk_id
        predictor_graph = tf.Graph()
        with predictor_graph.as_default():
            hparams = registry.get_registered_hparams_set(hparams_set_name)
            hparams.add_hparam("inputs_vocab_size", src_vocab_size)
            hparams.add_hparam("targets_vocab_size", trg_vocab_size)
            run_config = tf.contrib.learn.RunConfig()
            run_config = run_config.replace(model_dir=checkpoint_dir)
            model = registry.get_registered_model(model_name, hparams,
                                                  run_config)
            # Rank-1 id placeholders for a single sentence; expand_dims
            # below prepends a leading dimension of size 1 (batch of one).
            self._inputs_var = tf.placeholder(dtype=tf.int32, shape=[None],
                                              name="sgnmt_inputs")
            self._targets_var = tf.placeholder(dtype=tf.int32, shape=[None],
                                               name="sgnmt_targets")
            features = {"inputs": tf.expand_dims(self._inputs_var, 0),
                        "targets": tf.expand_dims(self._targets_var, 0)}
            mode = tf.estimator.ModeKeys.PREDICT
            self.precomputed = model.precompute(features, mode, hparams)
            # Squeeze axis 0 -- presumably the size-1 batch dimension added
            # above (confirm against predict_next_word's output shape).
            self.log_probs = tf.squeeze(
                model.predict_next_word(features, hparams, self.precomputed),
                0)
            # Created while predictor_graph is the default graph; presumably
            # create_session builds its session on it -- confirm there.
            self.mon_sess = self.create_session(self._checkpoint_dir)
Пример #2
0
def get_hparams():
    """Builds the hyper-parameter set selected on the command line.

    Looks up ``FLAGS.hparams_set`` in the registry, adds the
    ``inputs_vocab_size`` and ``targets_vocab_size`` entries from the
    flags of the same name, then parses the ``FLAGS.hparams`` string.

    Returns:
        An HParams instance.

    Raises:
        ValueError: If FLAGS.hparams_set could not be found in the
            registry.
    """
    hps = registry.get_registered_hparams_set(FLAGS.hparams_set)
    vocab_entries = (
        ("inputs_vocab_size", FLAGS.inputs_vocab_size),
        ("targets_vocab_size", FLAGS.targets_vocab_size),
    )
    for entry_name, entry_value in vocab_entries:
        hps.add_hparam(entry_name, entry_value)
    # The FLAGS.hparams string is parsed last, after the vocab-size
    # entries have been added.
    hps.parse(FLAGS.hparams)
    return hps
Пример #3
0
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 model_name,
                 hparams_set_name,
                 checkpoint_dir,
                 single_cpu_thread,
                 nizza_unk_id=None):
        """Initializes a nizza predictor.

        Builds the nizza model in its own TensorFlow graph, exposes the
        ``sgnmt_inputs`` / ``sgnmt_targets`` placeholders, the
        ``precomputed`` tensors and the next-word ``log_probs`` tensor, and
        opens a session via ``self.create_session``.

        Args:
            src_vocab_size (int): Source vocabulary size (called
                inputs_vocab_size in nizza)
            trg_vocab_size (int): Target vocabulary size (called
                targets_vocab_size in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory.
                It must contain a ``checkpoint`` index file; the predictor
                will load the top most checkpoint listed in that file.
            single_cpu_thread (bool): If true, prevent tensorflow from
                doing multithreading.
            nizza_unk_id (int): If set, use this as UNK id. Otherwise, the
                nizza is assumed to have no UNKs

        Raises:
            IOError: If ``<checkpoint_dir>/checkpoint`` does not exist.
        """
        super(BaseNizzaPredictor, self).__init__()
        checkpoint_file = "%s/checkpoint" % checkpoint_dir
        if not os.path.isfile(checkpoint_file):
            logging.fatal("Checkpoint file %s not found!" % checkpoint_file)
            # Carry the path in the exception itself so the failure stays
            # diagnosable even when the log output is not visible.
            raise IOError("Checkpoint file %s not found!" % checkpoint_file)
        self._single_cpu_thread = single_cpu_thread
        self._checkpoint_dir = checkpoint_dir
        self._nizza_unk_id = nizza_unk_id
        predictor_graph = tf.Graph()
        with predictor_graph.as_default():
            hparams = registry.get_registered_hparams_set(hparams_set_name)
            hparams.add_hparam("inputs_vocab_size", src_vocab_size)
            hparams.add_hparam("targets_vocab_size", trg_vocab_size)
            run_config = tf.contrib.learn.RunConfig()
            run_config = run_config.replace(model_dir=checkpoint_dir)
            model = registry.get_registered_model(model_name, hparams,
                                                  run_config)
            # Rank-1 id placeholders for a single sentence; expand_dims
            # below prepends a leading dimension of size 1 (batch of one).
            self._inputs_var = tf.placeholder(dtype=tf.int32,
                                              shape=[None],
                                              name="sgnmt_inputs")
            self._targets_var = tf.placeholder(dtype=tf.int32,
                                               shape=[None],
                                               name="sgnmt_targets")
            features = {
                "inputs": tf.expand_dims(self._inputs_var, 0),
                "targets": tf.expand_dims(self._targets_var, 0)
            }
            mode = tf.estimator.ModeKeys.PREDICT
            self.precomputed = model.precompute(features, mode, hparams)
            # Squeeze axis 0 -- presumably the size-1 batch dimension added
            # above (confirm against predict_next_word's output shape).
            self.log_probs = tf.squeeze(
                model.predict_next_word(features, hparams, self.precomputed),
                0)
            # Created while predictor_graph is the default graph; presumably
            # create_session builds its session on it -- confirm there.
            self.mon_sess = self.create_session(self._checkpoint_dir)
Пример #4
0
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 model_name,
                 hparams_set_name,
                 checkpoint_dir,
                 single_cpu_thread,
                 alpha,
                 beta,
                 shortlist_strategies,
                 trg2src_model_name="",
                 trg2src_hparams_set_name="",
                 trg2src_checkpoint_dir="",
                 max_shortlist_length=0,
                 min_id=0,
                 nizza_unk_id=None):
        """Initializes a lexical nizza predictor.

        The source-to-target model is set up by the superclass constructor.
        If ``trg2src_checkpoint_dir`` is non-empty, a target-to-source nizza
        model is additionally built in its own graph, a session is opened
        on it, and its per-target-word partitions are precomputed.

        Args:
            src_vocab_size (int): Source vocabulary size (called
                inputs_vocab_size in nizza)
            trg_vocab_size (int): Target vocabulary size (called
                targets_vocab_size in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory.
                The predictor will load the top most checkpoint in the
                `checkpoint` file.
            single_cpu_thread (bool): If true, prevent tensorflow from
                doing multithreading.
            alpha (float): Score for each matching word
            beta (float): Penalty for each uncovered word at the end
            shortlist_strategies (string): Comma-separated list of shortlist
                strategies.
            trg2src_model_name (string): Name of the target2source nizza
                model. Only used if ``trg2src_checkpoint_dir`` is set.
            trg2src_hparams_set_name (string): Name of the nizza
                hyper-parameter set for the target2source model. Only used
                if ``trg2src_checkpoint_dir`` is set.
            trg2src_checkpoint_dir (string): Path to the Nizza checkpoint
                directory for the target2source model. The predictor will
                load the top most checkpoint in the `checkpoint` file. If
                empty, no target-to-source model is used.
            max_shortlist_length (int): If a shortlist exceeds this limit,
                initialize the initial coverage with 1 at this position. If
                zero, do not apply any limit
            min_id (int): Do not use IDs below this threshold (filters out
                most frequent words).
            nizza_unk_id (int): If set, use this as UNK id. Otherwise, the
                nizza is assumed to have no UNKs

        Raises:
            IOError: If the checkpoint file of the source-to-target model
                is not found.
        """
        super(LexNizzaPredictor, self).__init__(src_vocab_size,
                                                trg_vocab_size,
                                                model_name,
                                                hparams_set_name,
                                                checkpoint_dir,
                                                single_cpu_thread,
                                                nizza_unk_id=nizza_unk_id)
        self.alpha = alpha
        self.alpha_is_zero = alpha == 0.0
        self.beta = beta
        self.shortlist_strategies = utils.split_comma(shortlist_strategies)
        self.max_shortlist_length = max_shortlist_length
        self.min_id = min_id
        if not trg2src_checkpoint_dir:
            self.use_trg2src = False
            # logging.warn is a deprecated alias; use logging.warning.
            logging.warning(
                "No target-to-source model specified for lexnizza.")
            return
        self.use_trg2src = True
        # NOTE(review): unlike checkpoint_dir, trg2src_checkpoint_dir is not
        # checked for a `checkpoint` file here; confirm that create_session
        # fails loudly when it is missing.
        predictor_graph = tf.Graph()
        with predictor_graph.as_default():
            hparams = registry.get_registered_hparams_set(
                trg2src_hparams_set_name)
            # Vocabulary sizes are swapped: target words are the inputs of
            # the target-to-source model.
            hparams.add_hparam("inputs_vocab_size", trg_vocab_size)
            hparams.add_hparam("targets_vocab_size", src_vocab_size)
            run_config = tf.contrib.learn.RunConfig()
            run_config = run_config.replace(model_dir=trg2src_checkpoint_dir)
            model = registry.get_registered_model(trg2src_model_name,
                                                  hparams, run_config)
            # Feed every target id 0..trg_vocab_size-1 as one input sequence
            # so the whole lexical table is computed in a single pass.
            features = {
                "inputs": tf.expand_dims(tf.range(trg_vocab_size), 0)
            }
            mode = tf.estimator.ModeKeys.PREDICT
            # trg2src_lex_logits has shape [1, trg_vocab_size, src_vocab_size]
            trg2src_lex_logits = model.precompute(features, mode, hparams)
            # Precompute trg2src partitions: log-sum-exp over the last
            # (source vocabulary) axis.
            partitions = tf.reduce_logsumexp(trg2src_lex_logits, axis=-1)
            self._trg2src_src_words_var = tf.placeholder(
                dtype=tf.int32,
                shape=[None],
                name="sgnmt_trg2src_src_words")
            # Transpose to [src_vocab_size, trg_vocab_size], then pick the
            # rows of the requested source words, so trg2src_logits has
            # shape [len(src_words), trg_vocab_size].
            self.trg2src_logits = tf.gather(
                tf.transpose(trg2src_lex_logits[0, :, :]),
                self._trg2src_src_words_var)
            self.trg2src_mon_sess = self.create_session(
                trg2src_checkpoint_dir)
            logging.debug("Precomputing lexnizza trg2src partitions...")
            self.trg2src_partitions = self.trg2src_mon_sess.run(partitions)
Пример #5
0
    def __init__(self, src_vocab_size, trg_vocab_size, model_name,
                 hparams_set_name, checkpoint_dir, single_cpu_thread,
                 alpha, beta, shortlist_strategies,
                 trg2src_model_name="", trg2src_hparams_set_name="",
                 trg2src_checkpoint_dir="",
                 max_shortlist_length=0,
                 min_id=0,
                 nizza_unk_id=None):
        """Initializes a lexical nizza predictor.

        The source-to-target model is set up by the superclass constructor.
        If ``trg2src_checkpoint_dir`` is non-empty, a target-to-source nizza
        model is additionally built in its own graph, a session is opened
        on it, and its per-target-word partitions are precomputed.

        Args:
            src_vocab_size (int): Source vocabulary size (called
                inputs_vocab_size in nizza)
            trg_vocab_size (int): Target vocabulary size (called
                targets_vocab_size in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory.
                The predictor will load the top most checkpoint in the
                `checkpoint` file.
            single_cpu_thread (bool): If true, prevent tensorflow from
                doing multithreading.
            alpha (float): Score for each matching word
            beta (float): Penalty for each uncovered word at the end
            shortlist_strategies (string): Comma-separated list of shortlist
                strategies.
            trg2src_model_name (string): Name of the target2source nizza
                model. Only used if ``trg2src_checkpoint_dir`` is set.
            trg2src_hparams_set_name (string): Name of the nizza
                hyper-parameter set for the target2source model. Only used
                if ``trg2src_checkpoint_dir`` is set.
            trg2src_checkpoint_dir (string): Path to the Nizza checkpoint
                directory for the target2source model. The predictor will
                load the top most checkpoint in the `checkpoint` file. If
                empty, no target-to-source model is used.
            max_shortlist_length (int): If a shortlist exceeds this limit,
                initialize the initial coverage with 1 at this position. If
                zero, do not apply any limit
            min_id (int): Do not use IDs below this threshold (filters out
                most frequent words).
            nizza_unk_id (int): If set, use this as UNK id. Otherwise, the
                nizza is assumed to have no UNKs

        Raises:
            IOError: If the checkpoint file of the source-to-target model
                is not found.
        """
        super(LexNizzaPredictor, self).__init__(
                src_vocab_size, trg_vocab_size, model_name, hparams_set_name,
                checkpoint_dir, single_cpu_thread, nizza_unk_id=nizza_unk_id)
        self.alpha = alpha
        self.alpha_is_zero = alpha == 0.0
        self.beta = beta
        self.shortlist_strategies = utils.split_comma(shortlist_strategies)
        self.max_shortlist_length = max_shortlist_length
        self.min_id = min_id
        if not trg2src_checkpoint_dir:
            self.use_trg2src = False
            # logging.warn is a deprecated alias; use logging.warning.
            logging.warning(
                "No target-to-source model specified for lexnizza.")
            return
        self.use_trg2src = True
        # NOTE(review): unlike checkpoint_dir, trg2src_checkpoint_dir is not
        # checked for a `checkpoint` file here; confirm that create_session
        # fails loudly when it is missing.
        predictor_graph = tf.Graph()
        with predictor_graph.as_default():
            hparams = registry.get_registered_hparams_set(
                trg2src_hparams_set_name)
            # Vocabulary sizes are swapped: target words are the inputs of
            # the target-to-source model.
            hparams.add_hparam("inputs_vocab_size", trg_vocab_size)
            hparams.add_hparam("targets_vocab_size", src_vocab_size)
            run_config = tf.contrib.learn.RunConfig()
            run_config = run_config.replace(model_dir=trg2src_checkpoint_dir)
            model = registry.get_registered_model(trg2src_model_name,
                                                  hparams, run_config)
            # Feed every target id 0..trg_vocab_size-1 as one input sequence
            # so the whole lexical table is computed in a single pass.
            features = {"inputs": tf.expand_dims(tf.range(trg_vocab_size), 0)}
            mode = tf.estimator.ModeKeys.PREDICT
            # trg2src_lex_logits has shape [1, trg_vocab_size, src_vocab_size]
            trg2src_lex_logits = model.precompute(features, mode, hparams)
            # Precompute trg2src partitions: log-sum-exp over the last
            # (source vocabulary) axis.
            partitions = tf.reduce_logsumexp(trg2src_lex_logits, axis=-1)
            self._trg2src_src_words_var = tf.placeholder(
                dtype=tf.int32, shape=[None], name="sgnmt_trg2src_src_words")
            # Transpose to [src_vocab_size, trg_vocab_size], then pick the
            # rows of the requested source words, so trg2src_logits has
            # shape [len(src_words), trg_vocab_size].
            self.trg2src_logits = tf.gather(
                tf.transpose(trg2src_lex_logits[0, :, :]),
                self._trg2src_src_words_var)
            self.trg2src_mon_sess = self.create_session(
                trg2src_checkpoint_dir)
            logging.debug("Precomputing lexnizza trg2src partitions...")
            self.trg2src_partitions = self.trg2src_mon_sess.run(partitions)