Example #1
def load_model(model, ckpt, session, name):
    start_time = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    utils.print_out("  loaded %s model parameters from %s, time %.2fs" %
                    (name, ckpt, time.time() - start_time))
    return model
def ensure_compatible_hparams(hparams, default_hparams, flags):
    """Make sure the loaded hparams is compatible with new changes."""
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, flags.hparams_path, verbose=not flags.chat)

    # For compatibility: if there are new fields in default_hparams,
    #   add them to the current hparams.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Make sure that the loaded model has latest values for the below keys
    updated_keys = [
        "out_dir", "num_gpus", "test_prefix", "beam_width",
        "length_penalty_weight", "num_train_steps", "number_token",
        "name_token", "gpe_token", "UNAME", "TOKEN"
    ]
    for key in updated_keys:
        if key in default_config and getattr(hparams, key) != default_config[key]:
            if not flags.chat:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                            (key, str(getattr(hparams, key)), str(default_config[key])))
            setattr(hparams, key, default_config[key])
    return hparams
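
The two compatibility passes above are easier to see with plain dicts standing in for the HParams object. A minimal standalone sketch (merge_hparams and its arguments are hypothetical names, not part of the codebase):

def merge_hparams(loaded, defaults, updated_keys, verbose=True):
    """Mimic ensure_compatible_hparams with plain dicts (sketch)."""
    merged = dict(loaded)
    # Pass 1: add fields that exist only in the defaults.
    for key, value in defaults.items():
        if key not in merged:
            merged[key] = value
    # Pass 2: force selected keys to their latest default values.
    for key in updated_keys:
        if key in defaults and merged.get(key) != defaults[key]:
            if verbose:
                print("# Updating hparams.%s: %s -> %s"
                      % (key, merged.get(key), defaults[key]))
            merged[key] = defaults[key]
    return merged

print(merge_hparams({"beam_width": 5}, {"beam_width": 10, "num_gpus": 1},
                    updated_keys=["beam_width"]))
# {'beam_width': 10, 'num_gpus': 1}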
def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        num_translations_per_input=1,
                        decode=True):
  """Decode a test set and compute a score according to the evaluation task."""
  # Decode
  if decode:
    utils.print_out("  decoding to output %s." % trans_file)

    start_time = time.time()
    num_sentences = 0
    with codecs.getwriter("utf-8")(
        tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
      trans_f.write("")  # Write empty string to ensure file is created.

      num_translations_per_input = max(
          min(num_translations_per_input, beam_width), 1)
      while True:
        try:
          nmt_outputs, _ = model.decode(sess)
          if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

          batch_size = nmt_outputs.shape[1]
          num_sentences += batch_size

          for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
              translation = get_translation(
                  nmt_outputs[beam_id],
                  sent_id,
                  tgt_eos=tgt_eos,
                  subword_option=subword_option)
              trans_f.write((translation + b"\n").decode("utf-8"))
        except tf.errors.OutOfRangeError:
          utils.print_time(
              "  done, num sentences %d, num translations per input %d" %
              (num_sentences, num_translations_per_input), start_time)
          break

  # Evaluation
  evaluation_scores = {}
  if ref_file and tf.gfile.Exists(trans_file):
    for metric in metrics:
      score = evaluation_utils.evaluate(
          ref_file,
          trans_file,
          metric,
          subword_option=subword_option)
      evaluation_scores[metric] = score
      utils.print_out("  %s %s: %.1f" % (metric, name, score))

  return evaluation_scores
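
The beam handling above assumes the decoder emits token ids shaped [beam_width, batch_size, time]. A small NumPy sketch (shapes assumed for illustration, not taken from the model code) shows why greedy outputs need the extra axis and how the [beam_id][sent_id] indexing walks the hypotheses:

import numpy as np

beam_width, batch_size, max_time = 3, 2, 4
nmt_outputs = np.arange(beam_width * batch_size * max_time).reshape(
    beam_width, batch_size, max_time)

# Greedy decoding (beam_width == 0) yields [batch_size, time]; expand_dims
# adds a fake beam axis so the same loop handles both cases.
greedy = np.expand_dims(nmt_outputs[0], 0)
assert greedy.shape == (1, batch_size, max_time)

num_translations_per_input = max(min(2, beam_width), 1)
for sent_id in range(batch_size):
    for beam_id in range(num_translations_per_input):
        print(sent_id, beam_id, nmt_outputs[beam_id][sent_id])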
Example #4
def _cell_list(unit_type,
               num_units,
               num_layers,
               num_residual_layers,
               forget_bias,
               dropout,
               mode,
               num_gpus,
               base_gpu=0,
               verbose=True):
    """Create a list of RNN cells."""
    # Multi-GPU
    cell_list = []
    for i in range(num_layers):
        if verbose:
            utils.print_out("  cell %d" % i, new_line=False)
        dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0  # Disable dropout outside training.
        single_cell = _single_cell(
            unit_type=unit_type,
            num_units=num_units,
            forget_bias=forget_bias,
            dropout=dropout,
            residual_connection=(i >= num_layers - num_residual_layers),
            # Apply the residual wrapper to the last num_residual_layers layers.
            device_str=get_device_str(
                i + base_gpu, num_gpus),  # Parallelize computation over GPUs
            verbose=verbose  # Whether to print to stdout
        )
        if verbose:
            utils.print_out("")  # create new line
        cell_list.append(single_cell)

    return cell_list
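
get_device_str is not shown in this example; in the NMT codebase it round-robins layer indices over the available GPUs. A plausible reimplementation, offered as a sketch rather than the verbatim helper:

def get_device_str(device_id, num_gpus):
    """Round-robin a layer index onto the available GPUs (CPU fallback)."""
    if num_gpus == 0:
        return "/cpu:0"
    return "/gpu:%d" % (device_id % num_gpus)

# Layers 0..3 on 2 GPUs alternate devices; base_gpu shifts the assignment.
print([get_device_str(i, 2) for i in range(4)])
# ['/gpu:0', '/gpu:1', '/gpu:0', '/gpu:1']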
def create_new_vocab_file(vocab_file, out_dir):
  """Creates a new vocabulary file prepending three new tokens:
  (1) <unk> for unknown tag, (2) <s> for start of sentence tag, and
  (3) </s> for end of sentence tag."""
  if not tf.io.gfile.exists(vocab_file):
    raise ValueError("vocab_file '%s' does not exist." % vocab_file)

  utils.print_out("# Vocab file %s exists" % vocab_file)
  vocab = []
  with codecs.getreader("utf-8")(tf.io.gfile.GFile(vocab_file, "rb")) as f:
    vocab_size = 0
    for word in f:
      vocab_size += 1
      vocab.append(word.strip())

  assert len(vocab) >= 3
  (unk, sos, eos) = ("<unk>", "<s>", "</s>")
  if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
    utils.print_out("The first 3 vocab words [%s, %s, %s]"
                    " are not [%s, %s, %s]" %
                    (vocab[0], vocab[1], vocab[2], unk, sos, eos))
    vocab = [unk, sos, eos] + vocab
    vocab_size += 3
    # out_dir is where the extended vocab file is written.
    new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file))
    with codecs.getwriter("utf-8")(
        tf.io.gfile.GFile(new_vocab_file, "wb")) as f:
      for word in vocab:
        f.write("%s\n" % word)
    vocab_file = new_vocab_file
  return vocab_file
Example #6
def _cell_list(unit_type, num_units, num_layers, num_residual_layers,
               forget_bias, dropout, mode, num_gpus, base_gpu=0,
               single_cell_fn=None, num_proj=None, num_cells=1):
  """Create a list of RNN cells."""
  if not single_cell_fn:
    single_cell_fn = _single_cell

  # Multi-GPU
  cell_list = []
  for i in range(num_layers):
    utils.print_out("  cell %d" % i, new_line=False)
    single_cell = single_cell_fn(
        unit_type=unit_type,
        num_units=num_units,
        forget_bias=forget_bias,
        dropout=dropout,
        mode=mode,
        residual_connection=(i >= num_layers - num_residual_layers),
        device_str=get_device_str(i + base_gpu, num_gpus),
        num_proj=num_proj,
        num_layers=num_layers
    )
    utils.print_out("")
    cell_list.append(single_cell)

  return cell_list
Example #7
def ensure_compatible_hparams(hparams, default_hparams, hparams_path=""):
    """Make sure the loaded hparams is compatible with new changes."""
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # Set num encoder/decoder layers (for old checkpoints)
    if hasattr(hparams, "num_layers"):
        if not hasattr(hparams, "num_encoder_layers"):
            hparams.add_hparam("num_encoder_layers", hparams.num_layers)
        if not hasattr(hparams, "num_decoder_layers"):
            hparams.add_hparam("num_decoder_layers", hparams.num_layers)

    # For compatibility: if there are new fields in default_hparams,
    #   we add them to the current hparams.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Update all hparams' keys if override_loaded_hparams=True
    overwritten_keys = None
    if getattr(default_hparams, "override_loaded_hparams", None):
        overwritten_keys = default_config.keys()

    if overwritten_keys is not None:
        for key in overwritten_keys:
            if getattr(hparams, key) != default_config[key]:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                                (key, str(getattr(
                                    hparams, key)), str(default_config[key])))
                setattr(hparams, key, default_config[key])
    return hparams
Example #8
def check_vocab(vocab_file, out_dir, sos=None, eos=None, unk=None):
    """Check that vocab_file exists and starts with the unk/sos/eos tokens;
    if the tokens are missing, prepend them and write a new vocab file."""
    if tf.gfile.Exists(vocab_file):
        utils.print_out("# Vocab file %s exists" % vocab_file)
        vocab = []
        with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
            vocab_size = 0
            for word in f:
                vocab_size += 1
                vocab.append(word.strip())

        # Verify if the vocab starts with unk, sos, eos
        # If not, prepend those tokens & generate a new vocab file
        if not unk: unk = UNK
        if not sos: sos = SOS
        if not eos: eos = EOS
        assert len(vocab) >= 3
        if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
            utils.print_out("The first 3 vocab words [%s, %s, %s]"
                            " are not [%s, %s, %s]" %
                            (vocab[0], vocab[1], vocab[2], unk, sos, eos))
            vocab = [unk, sos, eos] + vocab
            vocab_size += 3
            new_vocab_file = os.path.join(out_dir,
                                          os.path.basename(vocab_file))
            with codecs.getwriter("utf-8")(tf.gfile.GFile(
                    new_vocab_file, "wb")) as f:
                for word in vocab:
                    f.write("%s\n" % word)
            vocab_file = new_vocab_file
    else:
        raise ValueError("vocab_file does not exist: " + vocab_file)

    vocab_size = len(vocab)
    return vocab_size, vocab_file
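
The invariant check_vocab enforces can be isolated as a pure function on token lists; a sketch assuming the usual "<unk>", "<s>", "</s>" defaults:

def ensure_special_tokens(vocab, unk="<unk>", sos="<s>", eos="</s>"):
    """Prepend unk/sos/eos unless the vocab already starts with them."""
    if vocab[:3] != [unk, sos, eos]:
        vocab = [unk, sos, eos] + vocab
    return vocab

print(ensure_special_tokens(["hello", "world"]))
# ['<unk>', '<s>', '</s>', 'hello', 'world']
print(ensure_special_tokens(["<unk>", "<s>", "</s>", "hi"]))  # unchanged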
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  # %d" % decode_id)

    iterator_feed_dict = {
        iterator_src_placeholder: [src_data[decode_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    nmt_outputs, attention_summary = model.decode(sess)

    if hparams.beam_width > 0:
        # get the top translation.
        nmt_outputs = nmt_outputs[0]

    translation = nmt_utils.get_translation(
        nmt_outputs,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)
    # utils.print_out("    src: %s" % src_data[decode_id])
    # utils.print_out("    ref: %s" % tgt_data[decode_id])
    # utils.print_out(b"    nmt: " + translation)

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
Example #10
    def _get_learning_rate_decay(self, hparams):
        """Get learning rate decay."""
        if hparams.learning_rate_decay_scheme in ["luong", "luong10"]:
            start_factor = 2
            start_decay_step = int(hparams.num_train_steps / start_factor)
            decay_factor = 0.5

            # decay 5 times
            if hparams.learning_rate_decay_scheme == "luong":
                decay_steps = int(hparams.num_train_steps / (5 * start_factor))
            # decay 10 times
            elif hparams.learning_rate_decay_scheme == "luong10":
                decay_steps = int(hparams.num_train_steps /
                                  (10 * start_factor))
        else:
            start_decay_step = hparams.start_decay_step
            decay_steps = hparams.decay_steps
            decay_factor = hparams.decay_factor
        utils.print_out(
            "  decay_scheme=%s, start_decay_step=%d, decay_steps %d, "
            "decay_factor %g" %
            (hparams.learning_rate_decay_scheme, start_decay_step,
             decay_steps, decay_factor))

        return tf.cond(self.global_step < start_decay_step,
                       lambda: self.learning_rate,
                       lambda: tf.train.exponential_decay(self.learning_rate, (
                           self.global_step - start_decay_step),
                                                          decay_steps,
                                                          decay_factor,
                                                          staircase=True),
                       name="learning_rate_decay_cond")
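
The schedule arithmetic above is easy to verify in isolation: decay starts halfway through training and halves the rate 5 ("luong") or 10 ("luong10") times. A standalone sketch (function name hypothetical):

def luong_schedule(num_train_steps, scheme="luong"):
    start_factor = 2
    start_decay_step = num_train_steps // start_factor
    halvings = 5 if scheme == "luong" else 10
    decay_steps = num_train_steps // (halvings * start_factor)
    return start_decay_step, decay_steps, 0.5

print(luong_schedule(12000, "luong"))    # (6000, 1200, 0.5) -> 5 halvings
print(luong_schedule(12000, "luong10"))  # (6000, 600, 0.5) -> 10 halvings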
Example #11
def load_model(model, ckpt, session, name, verbose=True):
    # Load the model from checkpoint
    start_time = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    if verbose:
        utils.print_out("  loaded %s model parameters from %s, time %.2fs" %
                        (name, ckpt, time.time() - start_time))
    return model
Example #12
def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        bpe_delimiter,
                        beam_width,
                        tgt_sos,
                        tgt_eos,
                        decode=True):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s." % trans_file)

        start_time = time.time()
        num_sentences = 0
        with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                      mode="wb")) as trans_f:
            trans_f.write("")  # Write empty string to ensure file is created.

            while True:
                try:
                    nmt_outputs, _, _ = model.decode(sess)

                    if beam_width > 0:
                        # get the top translation.
                        nmt_outputs = nmt_outputs[0]

                    num_sentences += len(nmt_outputs)
                    for sent_id in range(len(nmt_outputs)):
                        translation = get_translation(
                            nmt_outputs,
                            sent_id,
                            tgt_sos=tgt_sos,
                            tgt_eos=tgt_eos,
                            bpe_delimiter=bpe_delimiter)
                        trans_f.write((translation + b"\n").decode("utf-8"))
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d" % num_sentences, start_time)
                    break

    # Evaluation
    evaluation_scores = {}
    if ref_file and tf.gfile.Exists(trans_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(ref_file,
                                              trans_file,
                                              metric,
                                              bpe_delimiter=bpe_delimiter)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
def create_model_Alveo(model, session, name):
    """Create translation model and initialize it with fresh parameters in session."""
    start_time = time.time()
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    utils.print_out("  created %s model with fresh parameters, time %.2fs" %
                    (name, time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step
def decode_and_evaluate(name,
                        model,
                        sess,
                        output_file,
                        reference_file,
                        metrics,
                        bpe_delimiter,
                        beam_width,
                        eos,
                        number_token=None,
                        name_token=None,
                        decode=True):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s." % output_file)
        start_time = time.time()
        num_sentences = 0
        with tf.gfile.GFile(output_file, mode="w+") as out_f:
            out_f.write("")  # Write empty string to ensure file is created.

            while True:
                try:
                    # Get the response(s) for each input in the batch (whole file in this case)
                    # ToDo: adapt for architectures
                    outputs, infer_summary = model.decode(sess)

                    if beam_width > 0:
                        # Get the top response if we used beam_search
                        outputs = outputs[0]

                    num_sentences += len(outputs)
                    # Iterate over the outputs and write them to file
                    for sent_id in range(len(outputs)):
                        response = postprocess_output(outputs, sent_id, eos,
                                                      bpe_delimiter,
                                                      number_token, name_token)
                        out_f.write("%s\n" % response)
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d" % num_sentences, start_time)
                    break

    # Evaluation
    evaluation_scores = {}
    if reference_file and tf.gfile.Exists(output_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(ref_file=reference_file,
                                              trans_file=output_file,
                                              metric=metric,
                                              bpe_delimiter=bpe_delimiter)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
Example #15
def train_main(flags, default_hparams, train_fn, target_session=""):
    """Run main."""
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    ## Train / Decode
    root = default_hparams.out_model_info.split('/')
    # Parent directory of out_model_info, with a trailing slash.
    parent_path = ''.join(part + '/' for part in root[:-1])

    flags.out_dir = parent_path + flags.out_dir
    out_dir = flags.out_dir
    default_hparams.out_dir = out_dir
    if not tf.gfile.Exists(out_dir): tf.gfile.MakeDirs(out_dir)

    # Load hparams.

    hparams = create_or_load_hparams(out_dir,
                                     default_hparams,
                                     flags.hparams_path,
                                     save_hparams=(jobid == 0))
    hparams.out_dir = out_dir

    # Echo a few key hparams for debugging.
    print(hparams.num_units)
    print(hparams.dropout)
    print(hparams.attention)
    print(hparams.train_src)
    print(hparams.dev_src)

    print(hparams.out_model_info)
    print(hparams.out_dir)
    #print(hparams.out_hparam)
    #print(hparams.best_metric_path)

    train_fn(hparams, target_session=target_session)

    out_model_info = flags.out_model_info
    with open(out_model_info, 'w') as f:
        f.write(out_dir)
Example #16
def create_or_load_model(model, model_dir, session, name):
  """Create translation model and initialize or load parameters in session."""
  latest_ckpt = tf.train.latest_checkpoint(model_dir)
  if latest_ckpt:
    model = load_model(model, latest_ckpt, session, name)
  else:
    start_time = time.time()
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    utils.print_out("  created %s model with fresh parameters, time %.2fs" %
                    (name, time.time() - start_time))

  global_step = model.global_step.eval(session=session)
  return model, global_step
Example #17
    def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
        """Maximum decoding steps at inference time."""
        if hparams.tgt_max_len_infer:
            maximum_iterations = hparams.tgt_max_len_infer
            utils.print_out("  decoding maximum_iterations %d" %
                            maximum_iterations)
        else:
            # TODO(thangluong): add decoding_length_factor flag
            decoding_length_factor = 2.0
            max_encoder_length = tf.reduce_max(source_sequence_length)
            maximum_iterations = tf.to_int32(
                tf.round(
                    tf.to_float(max_encoder_length) * decoding_length_factor))
        return maximum_iterations
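
Stripped of the TensorFlow ops, the fallback bound is simply "decode for at most twice the longest source sentence in the batch". The same rule in plain Python (function name hypothetical):

def infer_maximum_iterations(source_sequence_length, tgt_max_len_infer=None,
                             decoding_length_factor=2.0):
    if tgt_max_len_infer:
        return tgt_max_len_infer
    return int(round(max(source_sequence_length) * decoding_length_factor))

print(infer_maximum_iterations([7, 12, 9]))      # 24
print(infer_maximum_iterations([7, 12, 9], 30))  # 30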
Example #18
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder, iterator_batch_size_placeholder,
                   summary_writer):
    """Pick a random sentence and decode it.

    Args:
        iterator_src_placeholder, iterator_batch_size_placeholder: placeholders
            used to initialize the iterator with the sampled sentence.
    """
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  Decoding sentence %d" % decode_id)
    # Wrap the random sentence in a batch of size 1.
    sentence = [src_data[decode_id]]
    # Create the feed-dict for the iterator
    iterator_feed_dict = {
        iterator_src_placeholder: sentence,
        iterator_batch_size_placeholder: 1
    }
    # Initialize the iterator
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    # Get the response. The summary is only used in attention models, which we do not use at the moment.
    response, attention_summary = model.decode(sess)

    if hparams.beam_width > 0:
        response = response[0]
    # Postprocess the response
    response = chatbot_utils.postprocess_output(response, sentence_id=0, eos=hparams.eos,
                                                bpe_delimiter=hparams.bpe_delimiter,
                                                number_token=hparams.number_token, name_token=hparams.name_token)

    # ToDo: Add attention summary here if deciding to use attention
    # Add the print to check the model's progress
    utils.print_out("    src: %s" % src_data[decode_id])
    utils.print_out("    ref: %s" % tgt_data[decode_id])
    utils.print_out("    Chatbot: %s" % response)
Example #19
    def build_graph(self, hparams, scope=None):
        """Subclass must implement this method.

        Creates a sequence-to-sequence model with dynamic RNN decoder API.
        Args:
          hparams: Hyperparameter configurations.
          scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

        Returns:
          A tuple of the form (logits, loss, final_context_state),
          where:
            logits: float32 Tensor [batch_size x num_decoder_symbols].
            loss: the total loss / batch_size.
            final_context_state: The final state of decoder RNN.

        Raises:
          ValueError: if encoder_type differs from mono and bi, or
            attention_option is not (luong | scaled_luong |
            bahdanau | normed_bahdanau).
        """
        if self.verbose:
            utils.print_out("# creating %s graph" % self.mode)

        dtype = tf.float32
        # TODO: Check if these have a reason to not call self.
        num_layers = hparams.num_layers
        num_gpus = hparams.num_gpus
        with tf.variable_scope(scope or "dynamic_seq2seq", dtype=dtype):
            # Encoder
            encoder_state = self._build_encoder(hparams)

            # Decoder
            logits, sample_id, final_context_state = self._build_decoder(
                encoder_state, hparams)

            # Loss
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                # Compute it on the same gpu as the last cell
                with tf.device(
                        model_helper.get_device_str(num_layers - 1, num_gpus)):
                    loss = self._compute_loss(logits)
            else:
                # Cannot compute loss because we have no target outputs
                loss = None

        return logits, loss, final_context_state, sample_id
Example #20
def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best_dev):
    """External evaluation such as BLEU and ROUGE scores. If save_on_best_dev
    is set, keep the best scores in the hparams."""
    out_dir = hparams.out_dir
    # Avoids running eval when global step is 0
    decode = global_step > 0
    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)
    # Initialize the iterator
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    # Create the output file for the logs
    output_file = os.path.join(out_dir, "output_%s" % label)
    # Get the scores for the metrics
    scores = chatbot_utils.decode_and_evaluate(
        name=label,
        model=model,
        sess=sess,
        output_file=output_file,
        reference_file=tgt_file,
        metrics=hparams.metrics,
        bpe_delimiter=hparams.bpe_delimiter,
        beam_width=hparams.beam_width,
        eos=hparams.eos,
        number_token=hparams.number_token,
        name_token=hparams.name_token,
        decode=decode
    )
    # Create the summaries and also save the best
    if decode:
        for metric in hparams.metrics:
            # Create the summary
            utils.add_summary(summary_writer, global_step, "%s_%s" % (label, metric),
                              scores[metric])
            # Is the current metric score better than the last
            if save_on_best_dev and scores[metric] > getattr(hparams, "best_" + metric):
                # Update the hparams score
                setattr(hparams, "best_" + metric, scores[metric])
                # Save the model which got the best for this metric to file
                model.saver.save(sess,
                                 os.path.join(getattr(hparams, "best_" + metric + "_dir"), "dialogue.ckpt"),
                                 global_step=model.global_step)  # For safety
    # Save the hparams to file
    utils.save_hparams(out_dir, hparams, verbose=True)

    return scores
Example #21
def create_or_load_model(model, model_dir, session, out_dir, name):
    """Create translation model and initialize or load parameters in session."""
    start_time = time.time()
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model.saver.restore(session, latest_ckpt)
        misc_utils.print_out(
            "  loaded %s model parameters from %s, time %.2fs" %
            (name, latest_ckpt, time.time() - start_time))
    else:
        misc_utils.print_out(
            "  created %s model with fresh parameters, time %.2fs." %
            (name, time.time() - start_time))
        session.run(tf.global_variables_initializer())

    global_step = model.global_step.eval(session=session)
    return model, global_step
def _decode_inference_indices(model,
                              sess,
                              output_infer_file,
                              output_infer_summary_prefix,
                              inference_indices,
                              eos,
                              bpe_delimiter,
                              number_token=None,
                              name_token=None):
    """Decode only a specific set of sentences, indicated by inference_indices.

    :param output_infer_file: path of the file to write the decoded responses to
    :param output_infer_summary_prefix: prefix for per-sentence summary outputs
    :param inference_indices: a list of sentence indices
    :param eos: the eos token
    :param bpe_delimiter: delimiter used for byte-pair entries
    """
    utils.print_out("  decoding to output %s, num sents %d." %
                    (output_infer_file, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer_file,
                                                  'wb')) as f:
        f.write("")  # Write empty string to ensure that the file is created
        # Get the outputs
        outputs, infer_summary = model.decode(sess)

        # Iterate over the sentences we want to process. Use the index to process sentences and the
        # decode_id to create logs
        for sentence_id, decode_id in enumerate(inference_indices):
            # Get the response
            response = chatbot_utils.postprocess_output(
                outputs,
                sentence_id=sentence_id,
                eos=eos,
                bpe_delimiter=bpe_delimiter,
                number_token=number_token,
                name_token=name_token)
            # TODO: add inference_summary if deciding to use attention

            # Write the response to file
            f.write("%s\n" % response)
            utils.print_out("%s\n" % response)
    utils.print_time("  done", start_time)
Example #23
def check_stats(stats, global_step, steps_per_stats, hparams, log_f):
    """Print statistics and also check for overflow."""
    # Print statistics for the previous epoch.
    avg_step_time = stats["step_time"] / steps_per_stats
    avg_grad_norm = stats["grad_norm"] / steps_per_stats
    train_ppl = utils.safe_exp(stats["loss"] / stats["predict_count"])
    speed = stats["total_count"] / (1000 * stats["step_time"])
    utils.print_out(
        "  global step %d lr %g "
        "step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s" %
        (global_step, stats["learning_rate"], avg_step_time, speed, train_ppl,
         avg_grad_norm, _get_best_results(hparams)), log_f)

    # Check for overflow
    is_overflow = False
    if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20:
        utils.print_out("  step %d overflow, stop early" % global_step, log_f)
        is_overflow = True

    return is_overflow
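
utils.safe_exp is not shown here; presumably it guards math.exp against overflow. Under that assumption, the perplexity computation and overflow check reduce to this standalone sketch:

import math

def safe_exp(value):
    """exp(value), returning inf instead of raising OverflowError (sketch)."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")

# Perplexity is exp(loss per predicted token); non-finite or huge values
# signal divergence, and check_stats stops training early on them.
train_ppl = safe_exp(4.2e5 / 1e5)  # exp(4.2) ~= 66.7
is_overflow = (math.isnan(train_ppl) or math.isinf(train_ppl)
               or train_ppl > 1e20)
print(train_ppl, is_overflow)      # 66.68... False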
Example #24
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_sos=hparams.sos,
                tgt_eos=hparams.eos,
                bpe_delimiter=hparams.bpe_delimiter)
        else:
            nmt_utils.decode_and_evaluate("infer",
                                          loaded_infer_model,
                                          sess,
                                          output_infer,
                                          ref_file=None,
                                          metrics=hparams.metrics,
                                          bpe_delimiter=hparams.bpe_delimiter,
                                          beam_width=hparams.beam_width,
                                          tgt_sos=hparams.sos,
                                          tgt_eos=hparams.eos)
Example #25
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_sos, tgt_eos, bpe_delimiter):
    """Decoding only a specific set of sentences."""
    utils.print_out("  decoding to output %s, num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_sos=tgt_sos,
                tgt_eos=tgt_eos,
                bpe_delimiter=bpe_delimiter)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(
                    decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                with tf.gfile.GFile(image_file, mode="w") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            trans_f.write("%s\n" % translation)
            utils.print_out(b"%s\n" % translation)
    utils.print_time("  done", start_time)
Example #26
def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0
    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = nmt_utils.decode_and_evaluate(label,
                                           model,
                                           sess,
                                           output,
                                           ref_file=tgt_file,
                                           metrics=hparams.metrics,
                                           bpe_delimiter=hparams.bpe_delimiter,
                                           beam_width=hparams.beam_width,
                                           tgt_sos=hparams.sos,
                                           tgt_eos=hparams.eos,
                                           decode=decode)
    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(
                    hparams, "best_" + metric):
                setattr(hparams, "best_" + metric, scores[metric])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             "best_" + metric + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
        utils.save_hparams(out_dir, hparams)
    return scores
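
The save-on-best logic in this and the previous example reduces to tracking one best score per metric and persisting only on improvement. A minimal sketch with plain dicts (all names hypothetical):

def update_best(best_scores, new_scores, save_fn):
    for metric, score in new_scores.items():
        if score > best_scores.get(metric, float("-inf")):
            best_scores[metric] = score
            save_fn(metric)  # e.g. write translate.ckpt for this metric
    return best_scores

best = {"bleu": 21.3}
update_best(best, {"bleu": 22.1, "rouge": 30.0},
            save_fn=lambda m: print("saving best-%s checkpoint" % m))
print(best)  # {'bleu': 22.1, 'rouge': 30.0}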
Example #27
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
    """Make sure the loaded hparams is compatible with new changes."""
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # For compatibility: if there are new fields in default_hparams,
    #   we add them to the current hparams.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Update all hparams' keys if override_loaded_hparams=True
    if default_hparams.override_loaded_hparams:
        for key in default_config:
            if getattr(hparams, key) != default_config[key]:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                                (key, str(getattr(
                                    hparams, key)), str(default_config[key])))
                setattr(hparams, key, default_config[key])
    return hparams
Example #28
    def _get_learning_rate_warmup(self, hparams):
        """Get learning rate warmup."""
        warmup_steps = hparams.warmup_steps
        warmup_scheme = hparams.warmup_scheme
        utils.print_out(
            "  learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" %
            (hparams.learning_rate, warmup_steps, warmup_scheme))

        # Apply inverse decay if global steps less than warmup steps.
        # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
        # When step < warmup_steps,
        #   learning_rate *= warmup_factor ** (warmup_steps - step)
        if warmup_scheme == "t2t":
            # 0.01^(1/warmup_steps): start with a learning rate 100 times
            # smaller than the base rate.
            warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
            inv_decay = warmup_factor**(tf.to_float(warmup_steps -
                                                    self.global_step))
        else:
            raise ValueError("Unknown warmup scheme %s" % warmup_scheme)

        return tf.cond(self.global_step < hparams.warmup_steps,
                       lambda: inv_decay * self.learning_rate,
                       lambda: self.learning_rate,
                       name="learning_rate_warump_cond")
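
The t2t warmup is easy to check numerically: warmup_factor is chosen so the rate starts 100 times smaller and reaches the base learning rate exactly at warmup_steps. A NumPy sketch of the schedule (constants chosen for illustration):

import numpy as np

learning_rate, warmup_steps = 1.0, 8
warmup_factor = np.exp(np.log(0.01) / warmup_steps)
steps = np.arange(warmup_steps + 1)
inv_decay = warmup_factor ** (warmup_steps - steps)
lr = np.where(steps < warmup_steps, inv_decay * learning_rate, learning_rate)
print(np.round(lr, 4))  # 0.01 at step 0, rising to 1.0 at step 8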
Example #29
def run_main(flags,
             default_hparams,
             train_fn,
             inference_fn,
             target_session=""):
    """Run main."""
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    ## Train / Decode
    out_dir = flags.out_dir
    if not tf.gfile.Exists(out_dir): tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    hparams = create_or_load_hparams(out_dir, default_hparams,
                                     flags.hparams_path)

    if flags.inference_input_file:
        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            hparams.inference_indices = [
                int(token) for token in flags.inference_list.split(",")
            ]

        # Inference
        trans_file = flags.inference_output_file
        ckpt = flags.ckpt
        if not ckpt:
            ckpt = tf.train.latest_checkpoint(out_dir)
        inference_fn(ckpt, flags.inference_input_file, trans_file, hparams,
                     num_workers, jobid)

        # Evaluation
        ref_file = flags.inference_ref_file
        if ref_file and tf.gfile.Exists(trans_file):
            for metric in hparams.metrics:
                score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                                  hparams.bpe_delimiter)
                utils.print_out("  %s: %.1f" % (metric, score))
    else:
        # Train
        train_fn(hparams, target_session=target_session)
Example #30
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    iterator_feed_dict = {
        iterator_src_placeholder: src_data[-hparams.infer_batch_size:],
        iterator_batch_size_placeholder: hparams.infer_batch_size,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    nmt_outputs, att_w_history, ext_w_history = model.decode(sess)

    if hparams.beam_width > 0:
        # get the top translation.
        nmt_outputs = nmt_outputs[0]

    nmt_outputs = np.asarray(nmt_outputs)

    outputs = []
    for i in range(hparams.infer_batch_size):
        tmp = {}
        translation = nmt_utils.get_translation(
            nmt_outputs,
            sent_id=i,
            tgt_sos=hparams.sos,
            tgt_eos=hparams.eos,
            bpe_delimiter=hparams.bpe_delimiter)
        if i <= 5:  # print only the first few examples
            utils.print_out("    src: %s" %
                            src_data[-hparams.infer_batch_size + i])
            utils.print_out("    ref: %s" %
                            tgt_data[-hparams.infer_batch_size + i])
            utils.print_out(b"    nmt: %s" % translation)
        tmp['src'] = src_data[-hparams.infer_batch_size + i]
        tmp['ref'] = tgt_data[-hparams.infer_batch_size + i]
        tmp['nmt'] = translation
        if att_w_history is not None:
            tmp['attention_head'] = att_w_history[-hparams.infer_batch_size +
                                                  i]
        if ext_w_history is not None:
            for j, ext_head in enumerate(ext_w_history):
                tmp['ext_head_{0}'.format(j)] = ext_head[
                    -hparams.infer_batch_size + i]
        outputs.append(tmp)

    if hparams.record_w_history and outputs:
        # Avoid creating an empty pickle when there is nothing to record.
        path = os.path.join(hparams.out_dir,
                            'heads_step_{0}.pickle'.format(global_step))
        with open(path, 'wb') as f:
            pickle.dump(outputs, f)