Example #1
def check_vocab(vocab_file, out_dir, check_special_token=True, sos=None,
                eos=None, unk=None):
  """Check if vocab_file doesn't exist, create from corpus_file."""
  if tf.gfile.Exists(vocab_file):
    utils.print_out("# Vocab file %s exists" % vocab_file)
    vocab, vocab_size = load_vocab(vocab_file)
    if check_special_token:
      # Verify if the vocab starts with unk, sos, eos
      # If not, prepend those tokens & generate a new vocab file
      if not unk: unk = UNK
      if not sos: sos = SOS
      if not eos: eos = EOS
      assert len(vocab) >= 3
      if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
        utils.print_out("The first 3 vocab words [%s, %s, %s]"
                        " are not [%s, %s, %s]" %
                        (vocab[0], vocab[1], vocab[2], unk, sos, eos))
        vocab = [unk, sos, eos] + vocab
        vocab_size += 3
        new_vocab_file = os.path.join(out_dir, os.path.basename(vocab_file))
        with codecs.getwriter("utf-8")(
            tf.gfile.GFile(new_vocab_file, "wb")) as f:
          for word in vocab:
            f.write("%s\n" % word)
        vocab_file = new_vocab_file
  else:
    raise ValueError("vocab_file '%s' does not exist." % vocab_file)

  vocab_size = len(vocab)
  return vocab_size, vocab_file
Example #2
def load_model(model, ckpt, session, name):
    start_time = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    utils.print_out("  loaded %s model parameters from %s, time %.2fs" %
                    (name, ckpt, time.time() - start_time))
    return model
Example #3
def _cell_list(unit_type,
               num_units,
               num_layers,
               num_residual_layers,
               forget_bias,
               dropout,
               mode,
               num_gpus,
               base_gpu=0,
               single_cell_fn=None,
               residual_fn=None):
    """Create a list of RNN cells."""
    if not single_cell_fn:
        single_cell_fn = _single_cell

    # Multi-GPU
    cell_list = []
    for i in range(num_layers):
        utils.print_out("  cell %d" % i, new_line=False)
        single_cell = single_cell_fn(
            unit_type=unit_type,
            num_units=num_units,
            forget_bias=forget_bias,
            dropout=dropout,
            mode=mode,
            residual_connection=(i >= num_layers - num_residual_layers),
            device_str=get_device_str(i + base_gpu, num_gpus),
            residual_fn=residual_fn)
        utils.print_out("")
        cell_list.append(single_cell)

    return cell_list
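
In the NMT codebase this helper is consumed by model_helper.create_rnn_cell (called in Example #15). Roughly, the wrapper builds the per-layer list and stacks it; a sketch, not the verbatim implementation:

def create_rnn_cell(unit_type, num_units, num_layers, num_residual_layers,
                    forget_bias, dropout, mode, num_gpus, base_gpu=0,
                    single_cell_fn=None):
    """Sketch of the stacking wrapper around _cell_list."""
    cell_list = _cell_list(unit_type=unit_type,
                           num_units=num_units,
                           num_layers=num_layers,
                           num_residual_layers=num_residual_layers,
                           forget_bias=forget_bias,
                           dropout=dropout,
                           mode=mode,
                           num_gpus=num_gpus,
                           base_gpu=base_gpu,
                           single_cell_fn=single_cell_fn)
    if len(cell_list) == 1:  # a single layer needs no MultiRNNCell wrapper
        return cell_list[0]
    return tf.contrib.rnn.MultiRNNCell(cell_list)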
Example #4
File: train.py Project: buzzf/NLP
def print_step_info(prefix, global_step, info, result_summary, log_f):
    """Print all info at the current global step."""
    utils.print_out(
        "%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s, %s" %
        (prefix, global_step, info["learning_rate"], info["avg_step_time"],
         info["speed"], info["train_ppl"], info["avg_grad_norm"],
         result_summary, time.ctime()), log_f)
Example #5
File: train.py Project: buzzf/NLP
def _external_eval(model,
                   global_step,
                   sess,
                   hparams,
                   iterator,
                   iterator_feed_dict,
                   tgt_file,
                   label,
                   summary_writer,
                   save_on_best,
                   avg_ckpts=False):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0

    if avg_ckpts:
        label = "avg_" + label

    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = nmt_utils.decode_and_evaluate(
        label,
        model,
        sess,
        output,
        ref_file=tgt_file,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        decode=decode)
    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            if avg_ckpts:
                best_metric_label = "avg_best_" + metric
            else:
                best_metric_label = "best_" + metric

            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(
                    hparams, best_metric_label):
                setattr(hparams, best_metric_label, scores[metric])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             best_metric_label + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
        utils.save_hparams(out_dir, hparams)
    return scores
Example #6
File: train.py Project: buzzf/NLP
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  # %d" % decode_id)

    iterator_feed_dict = {
        iterator_src_placeholder: [src_data[decode_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    nmt_outputs, attention_summary = model.decode(sess)

    if hparams.beam_width > 0:
        # get the top translation.
        nmt_outputs = nmt_outputs[0]

    translation = nmt_utils.get_translation(
        nmt_outputs,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)
    utils.print_out("    src: %s" % src_data[decode_id])
    utils.print_out("    ref: %s" % tgt_data[decode_id])
    utils.print_out(b"    nmt: " + translation)

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
Example #7
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session."""
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        utils.print_out(
            "  created %s model with fresh parameters, time %.2fs" %
            (name, time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step
Example #8
File: train.py Project: buzzf/NLP
def process_stats(stats, info, global_step, steps_per_stats, log_f):
    """Update info and check for overflow."""
    # Update info
    info["avg_step_time"] = stats["step_time"] / steps_per_stats
    info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats
    info["train_ppl"] = utils.safe_exp(stats["loss"] / stats["predict_count"])
    info["speed"] = stats["total_count"] / (1000 * stats["step_time"])

    # Check for overflow
    is_overflow = False
    train_ppl = info["train_ppl"]
    if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20:
        utils.print_out("  step %d overflow, stop early" % global_step, log_f)
        is_overflow = True

    return is_overflow
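
process_stats assumes a stats dict that the training loop accumulates between reports. The actual init_stats lives elsewhere in train.py; this is only a minimal sketch consistent with the keys read above:

def init_stats():
    """Zeroed accumulators matching the keys process_stats reads."""
    return {"step_time": 0.0,      # seconds spent in train steps
            "loss": 0.0,           # summed training loss
            "predict_count": 0.0,  # number of predicted target words
            "total_count": 0.0,    # total source + target words processed
            "grad_norm": 0.0}      # summed gradient norms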
Example #9
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input)
Example #10
def _create_pretrained_emb_from_txt(vocab_file,
                                    embed_file,
                                    num_trainable_tokens=3,
                                    dtype=tf.float32,
                                    scope=None):
    """Load pretrain embeding from embed_file, and return an embedding matrix.

  Args:
    embed_file: Path to a Glove formated embedding txt file.
    num_trainable_tokens: Make the first n tokens in the vocab file as trainable
      variables. Default is 3, which is "<unk>", "<s>" and "</s>".
  """
    vocab, _ = vocab_utils.load_vocab(vocab_file)
    trainable_tokens = vocab[:num_trainable_tokens]

    utils.print_out("# Using pretrained embedding: %s." % embed_file)
    utils.print_out("  with trainable tokens: ")

    emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
    for token in trainable_tokens:
        utils.print_out("    %s" % token)
        if token not in emb_dict:
            emb_dict[token] = [0.0] * emb_size

    emb_mat = np.array([emb_dict[token] for token in vocab],
                       dtype=dtype.as_numpy_dtype())
    emb_mat = tf.constant(emb_mat)
    emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
    with tf.variable_scope(scope or "pretrain_embeddings",
                           dtype=dtype) as scope:
        with tf.device(_get_embed_device(num_trainable_tokens)):
            emb_mat_var = tf.get_variable("emb_mat_var",
                                          [num_trainable_tokens, emb_size])
    return tf.concat([emb_mat_var, emb_mat_const], 0)
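
For reference, vocab_utils.load_embed_txt parses GloVe-style text: one token per line followed by its space-separated vector. A toy file, with a hypothetical path and 4-dimensional vectors purely for illustration:

with open("/tmp/toy_embed.txt", "w") as f:
    f.write("the 0.1 0.2 0.3 0.4\n")
    f.write("cat -0.5 0.0 0.7 0.1\n")

The function above then keeps the first num_trainable_tokens rows as a trainable variable and freezes the rest as a constant, so only the special tokens' embeddings are updated during training.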
Example #11
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_eos, subword_option):
    """Decoding only a specific set of sentences."""
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(
                    decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                with tf.gfile.GFile(image_file, mode="w") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            trans_f.write("%s\n" % translation)
            utils.print_out(translation + b"\n")
    utils.print_time("  done", start_time)
Example #12
File: nmt.py Project: buzzf/NLP
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
    """Make sure the loaded hparams is compatible with new changes."""
    default_hparams = utils.maybe_parse_standard_hparams(
        default_hparams, hparams_path)

    # For compatibility: if there are new fields in default_hparams,
    #   add them to the current hparams.
    default_config = default_hparams.values()
    config = hparams.values()
    for key in default_config:
        if key not in config:
            hparams.add_hparam(key, default_config[key])

    # Update all hparams' keys if override_loaded_hparams=True
    if default_hparams.override_loaded_hparams:
        for key in default_config:
            if getattr(hparams, key) != default_config[key]:
                utils.print_out("# Updating hparams.%s: %s -> %s" %
                                (key, str(getattr(
                                    hparams, key)), str(default_config[key])))
                setattr(hparams, key, default_config[key])
    return hparams
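
For reference, a toy sketch of how the tf.contrib.training.HParams object manipulated above behaves (TF 1.x); values() returns a plain dict of name -> value:

hparams = tf.contrib.training.HParams(learning_rate=0.001)
hparams.add_hparam("beam_width", 10)   # add a new field, as in the loop above
assert hparams.values() == {"learning_rate": 0.001, "beam_width": 10}
setattr(hparams, "beam_width", 5)      # override, as when override_loaded_hparams is set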
Example #13
File: train.py Project: buzzf/NLP
def before_train(loaded_train_model, train_model, train_sess, global_step,
                 hparams, log_f):
    """Misc tasks to do before training."""
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "learning_rate":
        loaded_train_model.learning_rate.eval(session=train_sess)
    }
    start_train_time = time.time()
    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, info["learning_rate"], time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    return stats, info, start_train_time
Example #14
File: nmt.py Project: buzzf/NLP
def run_main(flags,
             default_hparams,
             train_fn,
             inference_fn,
             target_session=""):
    """Run main."""
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    ## Train / Decode
    out_dir = flags.out_dir
    if not tf.gfile.Exists(out_dir): tf.gfile.MakeDirs(out_dir)

    # Load hparams. If new parameters have been added, merge them in and
    # cache the result locally.
    hparams = create_or_load_hparams(out_dir,
                                     default_hparams,
                                     flags.hparams_path,
                                     save_hparams=(jobid == 0))

    if flags.inference_input_file:
        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            hparams.inference_indices = [
                int(token) for token in flags.inference_list.split(",")
            ]

        # Inference
        trans_file = flags.inference_output_file
        ckpt = flags.ckpt
        if not ckpt:
            ckpt = tf.train.latest_checkpoint(out_dir)
        inference_fn(ckpt, flags.inference_input_file, trans_file, hparams,
                     num_workers, jobid)

        # Evaluation
        ref_file = flags.inference_ref_file
        if ref_file and tf.gfile.Exists(trans_file):
            for metric in hparams.metrics:
                score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                                  hparams.subword_option)
                utils.print_out("  %s: %.1f" % (metric, score))
    else:
        # Train
        train_fn(hparams, target_session=target_session)
Example #15
    def _build_encoder(self, hparams):
        """Build a GNMT encoder."""
        if hparams.encoder_type == "uni" or hparams.encoder_type == "bi":
            return super(GNMTModel, self)._build_encoder(hparams)

        if hparams.encoder_type != "gnmt":
            raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)

        # Build GNMT encoder.
        num_bi_layers = 1
        num_uni_layers = self.num_encoder_layers - num_bi_layers
        utils.print_out("  num_bi_layers = %d" % num_bi_layers)
        utils.print_out("  num_uni_layers = %d" % num_uni_layers)

        iterator = self.iterator
        source = iterator.source
        if self.time_major:
            source = tf.transpose(source)

        with tf.variable_scope("encoder") as scope:
            dtype = scope.dtype

            # Look up embedding, emp_inp: [max_time, batch_size, num_units]
            #   when time_major = True
            encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder,
                                                     source)

            # Execute _build_bidirectional_rnn from Model class
            bi_encoder_outputs, bi_encoder_state = self._build_bidirectional_rnn(
                inputs=encoder_emb_inp,
                sequence_length=iterator.source_sequence_length,
                dtype=dtype,
                hparams=hparams,
                num_bi_layers=num_bi_layers,
                num_bi_residual_layers=0,  # no residual connection
            )

            uni_cell = model_helper.create_rnn_cell(
                unit_type=hparams.unit_type,
                num_units=hparams.num_units,
                num_layers=num_uni_layers,
                num_residual_layers=self.num_encoder_residual_layers,
                forget_bias=hparams.forget_bias,
                dropout=hparams.dropout,
                num_gpus=self.num_gpus,
                base_gpu=1,
                mode=self.mode,
                single_cell_fn=self.single_cell_fn)

            # encoder_outputs: size [max_time, batch_size, num_units]
            #   when time_major = True
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                uni_cell,
                bi_encoder_outputs,
                dtype=dtype,
                sequence_length=iterator.source_sequence_length,
                time_major=self.time_major)

            # Pass all encoder states except the first bi-directional layer's
            # state to the decoder.
            encoder_state = (bi_encoder_state[1], ) + (
                (encoder_state, ) if num_uni_layers == 1 else encoder_state)

        return encoder_outputs, encoder_state
Example #16
def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        num_translations_per_input=1,
                        decode=True):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s." % trans_file)

        start_time = time.time()
        num_sentences = 0
        with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                      mode="wb")) as trans_f:
            trans_f.write("")  # Write empty string to ensure file is created.

            num_translations_per_input = max(
                min(num_translations_per_input, beam_width), 1)
            while True:
                try:
                    nmt_outputs, _ = model.decode(sess)
                    if beam_width == 0:
                        nmt_outputs = np.expand_dims(nmt_outputs, 0)

                    batch_size = nmt_outputs.shape[1]
                    num_sentences += batch_size

                    for sent_id in range(batch_size):
                        for beam_id in range(num_translations_per_input):
                            translation = get_translation(
                                nmt_outputs[beam_id],
                                sent_id,
                                tgt_eos=tgt_eos,
                                subword_option=subword_option)
                            trans_f.write(
                                (translation + b"\n").decode("utf-8"))
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d, num translations per input %d"
                        % (num_sentences, num_translations_per_input),
                        start_time)
                    break

    # Evaluation
    evaluation_scores = {}
    if ref_file and tf.gfile.Exists(trans_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(ref_file,
                                              trans_file,
                                              metric,
                                              subword_option=subword_option)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
Example #17
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers, jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    # Split data to multiple workers
    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(
            infer_model.iterator.initializer, {
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })
        # Decode
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)

        # Change file name to indicate the file writing is completed.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0: return

        # Now write all translations
        with codecs.getwriter("utf-8")(tf.gfile.GFile(final_output_infer,
                                                      mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." %
                                    worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(tf.gfile.GFile(
                        worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)

            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                tf.gfile.Remove(worker_infer_done)
Example #18
def _single_cell(unit_type,
                 num_units,
                 forget_bias,
                 dropout,
                 mode,
                 residual_connection=False,
                 device_str=None,
                 residual_fn=None):
    """Create an instance of a single RNN cell."""
    # dropout (= 1 - keep_prob) is set to 0 during eval and infer
    dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0

    # Cell Type
    if unit_type == "lstm":
        utils.print_out("  LSTM, forget_bias=%g" % forget_bias, new_line=False)
        single_cell = tf.contrib.rnn.BasicLSTMCell(num_units,
                                                   forget_bias=forget_bias)
    elif unit_type == "gru":
        utils.print_out("  GRU", new_line=False)
        single_cell = tf.contrib.rnn.GRUCell(num_units)
    elif unit_type == "layer_norm_lstm":
        utils.print_out("  Layer Normalized LSTM, forget_bias=%g" %
                        forget_bias,
                        new_line=False)
        single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
            num_units, forget_bias=forget_bias, layer_norm=True)
    elif unit_type == "nas":
        utils.print_out("  NASCell", new_line=False)
        single_cell = tf.contrib.rnn.NASCell(num_units)
    else:
        raise ValueError("Unknown unit type %s!" % unit_type)

    # Dropout (= 1 - keep_prob)
    if dropout > 0.0:
        single_cell = tf.contrib.rnn.DropoutWrapper(cell=single_cell,
                                                    input_keep_prob=(1.0 -
                                                                     dropout))
        utils.print_out("  %s, dropout=%g " %
                        (type(single_cell).__name__, dropout),
                        new_line=False)

    # Residual
    if residual_connection:
        single_cell = tf.contrib.rnn.ResidualWrapper(single_cell,
                                                     residual_fn=residual_fn)
        utils.print_out("  %s" % type(single_cell).__name__, new_line=False)

    # Device Wrapper
    if device_str:
        single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)
        utils.print_out("  %s, device=%s" %
                        (type(single_cell).__name__, device_str),
                        new_line=False)

    return single_cell
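
A minimal usage sketch (TF 1.x; the sizes and device string are illustrative): one dropout-wrapped LSTM cell pinned to a device for training.

cell = _single_cell(unit_type="lstm",
                    num_units=512,
                    forget_bias=1.0,
                    dropout=0.2,  # applied only because mode is TRAIN
                    mode=tf.contrib.learn.ModeKeys.TRAIN,
                    device_str="/gpu:0")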
Example #19
def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
                    global_step_name):
    """Average the last N checkpoints in the model_dir."""
    checkpoint_state = tf.train.get_checkpoint_state(model_dir)
    if not checkpoint_state:
        utils.print_out("# No checkpoint file found in directory: %s" %
                        model_dir)
        return None

    # Checkpoints are ordered from oldest to newest.
    checkpoints = (
        checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])

    if len(checkpoints) < num_last_checkpoints:
        utils.print_out(
            "# Skipping checkpoint averaging because not enough checkpoints "
            "are available.")
        return None

    avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
    if not tf.gfile.Exists(avg_model_dir):
        utils.print_out(
            "# Creating new directory %s for saving averaged checkpoints." %
            avg_model_dir)
        tf.gfile.MakeDirs(avg_model_dir)

    utils.print_out("# Reading and averaging variables in checkpoints:")
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    var_values, var_dtypes = {}, {}
    for (name, shape) in var_list:
        if name != global_step_name:
            var_values[name] = np.zeros(shape)

    for checkpoint in checkpoints:
        utils.print_out("    %s" % checkpoint)
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor

    for name in var_values:
        var_values[name] /= len(checkpoints)

    # Build a graph with same variables in the checkpoints, and save the averaged
    # variables into the avg_model_dir.
    with tf.Graph().as_default():
        tf_vars = [
            tf.get_variable(v,
                            shape=var_values[v].shape,
                            dtype=var_dtypes[v]) for v in var_values
        ]

        placeholders = [
            tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars
        ]
        assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
        global_step_var = tf.Variable(global_step,
                                      name=global_step_name,
                                      trainable=False)
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                                   six.iteritems(var_values)):
                sess.run(assign_op, {p: value})

            # Use the built saver to save the averaged checkpoint. Only keep 1
            # checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
            saver.save(sess, os.path.join(avg_model_dir, "translate.ckpt"))

    return avg_model_dir
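
A usage sketch with hypothetical paths: average the last five checkpoints and point inference at the result. global_step_name must match the name of the checkpoints' global-step variable.

avg_dir = avg_checkpoints("/tmp/nmt_model",
                          num_last_checkpoints=5,
                          global_step=20000,
                          global_step_name="global_step")
if avg_dir:  # None when there are no or too few checkpoints
    ckpt = tf.train.latest_checkpoint(avg_dir)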
Example #20
File: nmt.py Project: buzzf/NLP
def extend_hparams(hparams):
    """Extend training hparams."""
    assert hparams.num_encoder_layers and hparams.num_decoder_layers
    if hparams.num_encoder_layers != hparams.num_decoder_layers:
        hparams.pass_hidden_state = False
        utils.print_out(
            "Num encoder layers %d differs from num decoder layers %d, "
            "so setting pass_hidden_state to False" %
            (hparams.num_encoder_layers, hparams.num_decoder_layers))

    # Sanity checks
    if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0:
        raise ValueError("For bi, num_encoder_layers %d should be even" %
                         hparams.num_encoder_layers)
    if (hparams.attention_architecture in ["gnmt"]
            and hparams.num_encoder_layers < 2):
        raise ValueError("For gnmt attention architecture, "
                         "num_encoder_layers %d should be >= 2" %
                         hparams.num_encoder_layers)

    # Set residual layers
    num_encoder_residual_layers = 0
    num_decoder_residual_layers = 0
    if hparams.residual:
        if hparams.num_encoder_layers > 1:
            num_encoder_residual_layers = hparams.num_encoder_layers - 1
        if hparams.num_decoder_layers > 1:
            num_decoder_residual_layers = hparams.num_decoder_layers - 1

        if hparams.encoder_type == "gnmt":
            # The first unidirectional layer (after the bi-directional layer) in
            # the GNMT encoder can't have a residual connection because its input
            # is the concatenation of the fw_cell's and bw_cell's outputs.
            num_encoder_residual_layers = hparams.num_encoder_layers - 2

            # Compatible for GNMT models
            if hparams.num_encoder_layers == hparams.num_decoder_layers:
                num_decoder_residual_layers = num_encoder_residual_layers
    hparams.add_hparam("num_encoder_residual_layers",
                       num_encoder_residual_layers)
    hparams.add_hparam("num_decoder_residual_layers",
                       num_decoder_residual_layers)

    if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
        raise ValueError("subword option must be either spm, or bpe")

    # Flags
    utils.print_out("# hparams:")
    utils.print_out("  src=%s" % hparams.src)
    utils.print_out("  tgt=%s" % hparams.tgt)
    utils.print_out("  train_prefix=%s" % hparams.train_prefix)
    utils.print_out("  dev_prefix=%s" % hparams.dev_prefix)
    utils.print_out("  test_prefix=%s" % hparams.test_prefix)
    utils.print_out("  out_dir=%s" % hparams.out_dir)

    ## Vocab
    # Get vocab file names first
    if hparams.vocab_prefix:
        src_vocab_file = hparams.vocab_prefix + "." + hparams.src
        tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
    else:
        raise ValueError("hparams.vocab_prefix must be provided.")

    # Source vocab
    src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
        src_vocab_file,
        hparams.out_dir,
        check_special_token=hparams.check_special_token,
        sos=hparams.sos,
        eos=hparams.eos,
        unk=vocab_utils.UNK)

    # Target vocab
    if hparams.share_vocab:
        utils.print_out("  using source vocab for target")
        tgt_vocab_file = src_vocab_file
        tgt_vocab_size = src_vocab_size
    else:
        tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
            tgt_vocab_file,
            hparams.out_dir,
            check_special_token=hparams.check_special_token,
            sos=hparams.sos,
            eos=hparams.eos,
            unk=vocab_utils.UNK)
    hparams.add_hparam("src_vocab_size", src_vocab_size)
    hparams.add_hparam("tgt_vocab_size", tgt_vocab_size)
    hparams.add_hparam("src_vocab_file", src_vocab_file)
    hparams.add_hparam("tgt_vocab_file", tgt_vocab_file)

    # Pretrained Embeddings:
    hparams.add_hparam("src_embed_file", "")
    hparams.add_hparam("tgt_embed_file", "")
    if hparams.embed_prefix:
        src_embed_file = hparams.embed_prefix + "." + hparams.src
        tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt

        if tf.gfile.Exists(src_embed_file):
            hparams.src_embed_file = src_embed_file

        if tf.gfile.Exists(tgt_embed_file):
            hparams.tgt_embed_file = tgt_embed_file

    # Check out_dir
    if not tf.gfile.Exists(hparams.out_dir):
        utils.print_out("# Creating output directory %s ..." % hparams.out_dir)
        tf.gfile.MakeDirs(hparams.out_dir)

    # Evaluation
    for metric in hparams.metrics:
        hparams.add_hparam("best_" + metric, 0)  # larger is better
        best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric)
        hparams.add_hparam("best_" + metric + "_dir", best_metric_dir)
        tf.gfile.MakeDirs(best_metric_dir)

        if hparams.avg_ckpts:
            hparams.add_hparam("avg_best_" + metric, 0)  # larger is better
            best_metric_dir = os.path.join(hparams.out_dir,
                                           "avg_best_" + metric)
            hparams.add_hparam("avg_best_" + metric + "_dir", best_metric_dir)
            tf.gfile.MakeDirs(best_metric_dir)

    return hparams
Example #21
File: train.py Project: buzzf/NLP
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
Example #22
def create_emb_for_encoder_and_decoder(share_vocab,
                                       src_vocab_size,
                                       tgt_vocab_size,
                                       src_embed_size,
                                       tgt_embed_size,
                                       dtype=tf.float32,
                                       num_partitions=0,
                                       src_vocab_file=None,
                                       tgt_vocab_file=None,
                                       src_embed_file=None,
                                       tgt_embed_file=None,
                                       scope=None):
    """Create embedding matrix for both encoder and decoder.

  Args:
    share_vocab: A boolean. Whether to share embedding matrix for both
      encoder and decoder.
    src_vocab_size: An integer. The source vocab size.
    tgt_vocab_size: An integer. The target vocab size.
    src_embed_size: An integer. The embedding dimension for the encoder's
      embedding.
    tgt_embed_size: An integer. The embedding dimension for the decoder's
      embedding.
    dtype: dtype of the embedding matrix. Default to float32.
    num_partitions: number of partitions used for the embedding vars.
    scope: VariableScope for the created subgraph. Default to "embedding".

  Returns:
    embedding_encoder: Encoder's embedding matrix.
    embedding_decoder: Decoder's embedding matrix.

  Raises:
    ValueError: if share_vocab is set but the source and target vocab sizes
      differ.
  """

    if num_partitions <= 1:
        partitioner = None
    else:
        # Note: num_partitions > 1 is required for distributed training because
        # embedding_lookup tries to colocate a single-partition embedding
        # variable with the lookup ops, which may place embedding variables on
        # worker jobs.
        partitioner = tf.fixed_size_partitioner(num_partitions)

    if (src_embed_file or tgt_embed_file) and partitioner:
        raise ValueError(
            "Can't set num_partitions > 1 when using pretrained embedding")

    with tf.variable_scope(scope or "embeddings",
                           dtype=dtype,
                           partitioner=partitioner) as scope:
        # Share embedding
        if share_vocab:
            if src_vocab_size != tgt_vocab_size:
                raise ValueError(
                    "Share embedding but different src/tgt vocab sizes"
                    " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
            assert src_embed_size == tgt_embed_size
            utils.print_out("# Use the same embedding for source and target")
            vocab_file = src_vocab_file or tgt_vocab_file
            embed_file = src_embed_file or tgt_embed_file

            embedding_encoder = _create_or_load_embed("embedding_share",
                                                      vocab_file, embed_file,
                                                      src_vocab_size,
                                                      src_embed_size, dtype)
            embedding_decoder = embedding_encoder
        else:
            with tf.variable_scope("encoder", partitioner=partitioner):
                embedding_encoder = _create_or_load_embed(
                    "embedding_encoder", src_vocab_file, src_embed_file,
                    src_vocab_size, src_embed_size, dtype)

            with tf.variable_scope("decoder", partitioner=partitioner):
                embedding_decoder = _create_or_load_embed(
                    "embedding_decoder", tgt_vocab_file, tgt_embed_file,
                    tgt_vocab_size, tgt_embed_size, dtype)

    return embedding_encoder, embedding_decoder
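
A minimal usage sketch with a shared vocabulary (vocab and embedding sizes are illustrative); with share_vocab=True both returned matrices are the same variable:

embedding_encoder, embedding_decoder = create_emb_for_encoder_and_decoder(
    share_vocab=True,
    src_vocab_size=32000,
    tgt_vocab_size=32000,
    src_embed_size=512,
    tgt_embed_size=512)
assert embedding_decoder is embedding_encoder  # one shared matrix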