def _cell_list(unit_type,
               num_units,
               num_layers,
               num_residual_layers,
               forget_bias,
               dropout,
               mode,
               num_gpus,
               base_gpu=0,
               single_cell_fn=None,
               residual_fn=None):
  """Create a list of RNN cells."""
  if not single_cell_fn:
    single_cell_fn = _single_cell

  # Multi-GPU
  cell_list = []
  for i in range(num_layers):
    utils.print_out("  cell %d" % i, new_line=False)
    single_cell = single_cell_fn(
        unit_type=unit_type,
        num_units=num_units,
        forget_bias=forget_bias,
        dropout=dropout,
        mode=mode,
        residual_connection=(i >= num_layers - num_residual_layers),
        device_str=get_device_str(i + base_gpu, num_gpus),
        residual_fn=residual_fn)
    utils.print_out("")
    cell_list.append(single_cell)

  return cell_list
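A minimal sketch (not taken from this code base; the function name is hypothetical) of how a cell list like the one above is typically wrapped into a single multi-layer RNN cell with TF 1.x:

def _example_create_rnn_cell(unit_type, num_units, num_layers,
                             num_residual_layers, forget_bias, dropout,
                             mode, num_gpus):
  # Build one cell per layer, then combine them; a single layer needs no wrapper.
  cell_list = _cell_list(unit_type=unit_type,
                         num_units=num_units,
                         num_layers=num_layers,
                         num_residual_layers=num_residual_layers,
                         forget_bias=forget_bias,
                         dropout=dropout,
                         mode=mode,
                         num_gpus=num_gpus)
  if len(cell_list) == 1:
    return cell_list[0]
  return tf.contrib.rnn.MultiRNNCell(cell_list)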
Example #2
    def _get_learning_rate_decay(self, hparams):
        """Get learning rate decay."""
        if hparams.decay_scheme == 'luong10':
            start_decay_step = int(hparams.num_train_steps / 2)
            remain_steps = hparams.num_train_steps - start_decay_step
            decay_steps = int(remain_steps / 10)  # decay 10 times
            decay_factor = 0.5
        elif hparams.decay_scheme == 'luong234':
            start_decay_step = int(hparams.num_train_steps * 2 / 3)
            remain_steps = hparams.num_train_steps - start_decay_step
            decay_steps = int(remain_steps / 4)  # decay 4 times
            decay_factor = 0.5
        elif not hparams.decay_scheme:  # no decay
            start_decay_step = hparams.num_train_steps
            decay_steps = 0
            decay_factor = 1.0
        elif hparams.decay_scheme:
            raise ValueError('Unknown decay scheme %s' % hparams.decay_scheme)
        utils.print_out(
            '  decay_scheme=%s, start_decay_step=%d, decay_steps %d, '
            'decay_factor %g' % (hparams.decay_scheme, start_decay_step,
                                 decay_steps, decay_factor))

        return tf.cond(self.global_step < start_decay_step,
                       lambda: self.learning_rate,
                       lambda: tf.train.exponential_decay(self.learning_rate, (
                           self.global_step - start_decay_step),
                                                          decay_steps,
                                                          decay_factor,
                                                          staircase=True),
                       name='learning_rate_decay_cond')
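Worked through with a hypothetical num_train_steps of 12,000, the 'luong10' branch above behaves as follows:

# Hypothetical value, only to illustrate the 'luong10' scheme.
num_train_steps = 12000
start_decay_step = num_train_steps // 2             # 6000: full lr for the first half
remain_steps = num_train_steps - start_decay_step   # 6000
decay_steps = remain_steps // 10                    # 600: lr halves every 600 steps
decay_factor = 0.5                                  # after 10 decays, lr is ~1/1024 of the start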
def compute_perplexity(hparams, model, sess, name):
  """Compute perplexity of the output of the model.

  Args:
    hparams: holds the parameters.
    model: model used to compute perplexity.
    sess: tensorflow session to use.
    name: name of the batch.

  Returns:
    The perplexity of the eval outputs.
  """
  total_loss = 0
  total_predict_count = 0
  start_time = time.time()
  step = 0
  start_time_step = time.time()
  while True:
    try:
      loss, _, _, predict_count, batch_size = model.eval(sess)
      total_loss += loss * batch_size
      total_predict_count += predict_count
      if step % hparams.steps_per_stats == 0:
        # print_time does not print decimal places for time.
        utils.print_out("  computing perplexity %s, step %d, time %.3f" %
                        (name, step, time.time() - start_time_step))
      step += 1
      start_time_step = time.time()
    except tf.errors.OutOfRangeError:
      break

  perplexity = utils.safe_exp(total_loss / total_predict_count)
  utils.print_time("  eval %s: perplexity %.2f" % (name, perplexity),
                   start_time)
  return perplexity
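The perplexity returned above is the exponential of the average per-token loss; with made-up totals:

import math

# Made-up totals, only to show the formula used by compute_perplexity.
total_loss = 6931.5          # sum of per-batch losses, each weighted by batch_size
total_predict_count = 1000   # number of predicted target tokens
perplexity = math.exp(total_loss / total_predict_count)  # exp(6.9315) ~= 1024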
def print_step_info(prefix, global_step, info, result_summary, log_f):
    """Print all info at the current global step."""
    utils.print_out(
        "%sstep %d lr %g step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s, %s" %
        (prefix, global_step, info["learning_rate"], info["avg_step_time"],
         info["speed"], info["train_ppl"], info["avg_grad_norm"],
         result_summary, time.ctime()), log_f)
def load_model(model, ckpt, session, name):
    start_time = time.time()
    model.saver.restore(session, ckpt)
    session.run(tf.tables_initializer())
    utils.print_out("  loaded %s model parameters from %s, time %.2fs" %
                    (name, ckpt, time.time() - start_time))
    return model
def run_internal_eval(eval_model,
                      sess,
                      hparams,
                      summary_writer,
                      use_test_set=True):
    """Compute internal evaluation (perplexity) for both dev / test."""

    utils.print_out(
        "Computing internal evaluation (perplexity) for both dev / test.")

    with eval_model.graph.as_default():
        global_step = model_helper.get_global_step(eval_model.model, sess)

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    dev_ctx_file = None
    if hparams.ctx is not None:
        dev_ctx_file = "%s.%s" % (hparams.dev_prefix, hparams.ctx)

    dev_iterator_feed_dict = {
        eval_model.src_file_placeholder: dev_src_file,
        eval_model.tgt_file_placeholder: dev_tgt_file,
    }
    if dev_ctx_file is not None:
        dev_iterator_feed_dict[eval_model.ctx_file_placeholder] = dev_ctx_file

    if hparams.dev_annotations is not None:
        dev_iterator_feed_dict[
            eval_model.annot_file_placeholder] = hparams.dev_annotations

    dev_ppl = _internal_eval(hparams, eval_model.model, global_step, sess,
                             eval_model.iterator, dev_iterator_feed_dict,
                             summary_writer, "dev")
    test_ppl = None
    if use_test_set and hparams.test_prefix:
        test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
        test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        test_ctx_file = None
        if hparams.ctx is not None:
            test_ctx_file = "%s.%s" % (hparams.test_prefix, hparams.ctx)

        test_iterator_feed_dict = {
            eval_model.src_file_placeholder: test_src_file,
            eval_model.tgt_file_placeholder: test_tgt_file,
        }
        if test_ctx_file is not None:
            test_iterator_feed_dict[
                eval_model.ctx_file_placeholder] = test_ctx_file

        if hparams.test_annotations is not None:
            test_iterator_feed_dict[
                eval_model.annot_file_placeholder] = hparams.test_annotations

        test_ppl = _internal_eval(hparams, eval_model.model, global_step, sess,
                                  eval_model.iterator, test_iterator_feed_dict,
                                  summary_writer, "test")
    return dev_ppl, test_ppl
Example #7
def _external_eval(model,
                   global_step,
                   sess,
                   hparams,
                   iterator,
                   iterator_feed_dict,
                   tgt_file,
                   label,
                   summary_writer,
                   save_on_best,
                   avg_ckpts=False):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir

    if avg_ckpts:
        label = "avg_" + label

    utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = nmt_utils.decode_and_evaluate(
        label,
        model,
        sess,
        output,
        ref_file=tgt_file,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        hparams=hparams,
        decode=True)
    # Save on best metrics
    if global_step > 0:
        for metric in hparams.metrics:
            if avg_ckpts:
                best_metric_label = "avg_best_" + metric
            else:
                best_metric_label = "best_" + metric

            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(
                    hparams, best_metric_label):
                setattr(hparams, best_metric_label, scores[metric])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             "best_" + metric + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
        utils.save_hparams(out_dir, hparams)
    return scores
def _internal_eval(hparams, model, global_step, sess, iterator,
                   iterator_feed_dict, summary_writer, label):
    """Computing perplexity."""

    utils.print_out("# Internal evaluation (perplexity), global step %d" %
                    global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    ppl = model_helper.compute_perplexity(hparams, model, sess, label)
    utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl)
    return ppl
Example #9
    def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
        """Maximum decoding steps at inference time."""
        if hparams.tgt_max_len_infer:
            maximum_iterations = hparams.tgt_max_len_infer
            utils.print_out('  decoding maximum_iterations %d' %
                            maximum_iterations)
        else:
            decoding_length_factor = 2.0
            max_encoder_length = tf.reduce_max(source_sequence_length)
            maximum_iterations = tf.to_int32(
                tf.round(
                    tf.to_float(max_encoder_length) * decoding_length_factor))
        return maximum_iterations
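When tgt_max_len_infer is unset, the decoding budget is simply decoding_length_factor times the longest source sentence in the batch; with illustrative lengths:

# Illustrative only: the longest source sentence has 23 tokens,
# so the decoder is allowed round(23 * 2.0) = 46 steps.
source_sequence_length = [12, 23, 7]
decoding_length_factor = 2.0
maximum_iterations = int(round(max(source_sequence_length) * decoding_length_factor))  # 46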
Example #10
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_context_file, inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    if inference_context_file is not None:
        infer_context = load_data(inference_context_file, hparams)
    else:
        infer_context = None

    infer_feed_dict = {
        infer_model.src_placeholder: infer_data,
        infer_model.batch_size_placeholder: hparams.infer_batch_size
    }
    if infer_context is not None:
        infer_feed_dict[infer_model.ctx_placeholder] = infer_context

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer, feed_dict=infer_feed_dict)
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                hparams=hparams,
                num_translations_per_input=hparams.num_translations_per_input)
Example #11
def create_or_load_model(model, model_dir, session, name):
    """Create translation model and initialize or load parameters in session."""
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        utils.print_out(
            "  created %s model with fresh parameters, time %.2fs" %
            (name, time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step
def process_stats(stats, info, global_step, steps_per_stats, log_f):
    """Update info and check for overflow."""
    # Update info
    info["avg_step_time"] = stats["step_time"] / steps_per_stats
    info["avg_grad_norm"] = stats["grad_norm"] / steps_per_stats
    info["train_ppl"] = utils.safe_exp(stats["loss"] / stats["predict_count"])
    info["speed"] = stats["total_count"] / (1000 * stats["step_time"])

    # Check for overflow
    is_overflow = False
    train_ppl = info["train_ppl"]
    if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20:
        utils.print_out("  step %d overflow, stop early" % global_step, log_f)
        is_overflow = True

    return is_overflow
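The "speed" field computed above is words per second in thousands, which is what print_step_info formats as "wps %.2fK"; with made-up stats:

# Made-up stats: 500,000 source+target words in 100s of accumulated step time.
stats = {"step_time": 100.0, "total_count": 500000}
speed = stats["total_count"] / (1000 * stats["step_time"])  # 5.0 -> printed as "wps 5.00K"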
Example #13
def check_vocab(vocab_file,
                out_dir,
                check_special_token=True,
                sos=None,
                eos=None,
                unk=None,
                context_delimiter=None):
    """Check if vocab_file doesn't exist, create from corpus_file."""
    if tf.gfile.Exists(vocab_file):
        utils.print_out("# Vocab file %s exists" % vocab_file)
        vocab = []
        with codecs.getreader("utf-8")(tf.gfile.GFile(vocab_file, "rb")) as f:
            vocab_size = 0
            for word in f:
                word = word.rstrip("\n").rsplit("\t", 1)[0]
                vocab_size += 1
                vocab.append(word)
        # Add the context delimiter if it is not in the vocab yet.
        if context_delimiter is not None and context_delimiter not in vocab:
            vocab += [context_delimiter]
            vocab_size += 1
            utils.print_out("Context delimiter {} was missing; added to vocab".format(
                context_delimiter))
        elif context_delimiter is not None and context_delimiter in vocab:
            utils.print_out(
                "Context delimiter {} already exists in vocab".format(
                    context_delimiter))

        if check_special_token:
            # Verify if the vocab starts with unk, sos, eos
            # If not, prepend those tokens & generate a new vocab file
            if not unk:
                unk = UNK
            if not sos:
                sos = SOS
            if not eos:
                eos = EOS
            assert len(vocab) >= 3
            if vocab[0] != unk or vocab[1] != sos or vocab[2] != eos:
                utils.print_out("The first 3 vocab words [%s, %s, %s]"
                                " are not [%s, %s, %s]" %
                                (vocab[0], vocab[1], vocab[2], unk, sos, eos))
                vocab = [unk, sos, eos] + vocab
                vocab_size += 3
                new_vocab_file = os.path.join(out_dir,
                                              os.path.basename(vocab_file))
                with codecs.getwriter("utf-8")(tf.gfile.GFile(
                        new_vocab_file, "wb")) as f:
                    for word in vocab:
                        f.write("%s\n" % word)
                vocab_file = new_vocab_file
    else:
        raise ValueError("vocab_file '%s' does not exist." % vocab_file)

    vocab_size = len(vocab)
    return vocab_size, vocab_file
def create_or_load_model(model, model_dir, session, name):
  """Create translation model and initialize or load parameters in session."""
  # Hard-coded checkpoint path, used here instead of
  # tf.train.latest_checkpoint(model_dir).
  ckpt_path = 'C:/Users/aanamika/Documents/QuestionGeneration/active-qa-master/tmp/active-qa/temp/translate.ckpt-1460356'
  print('ckpt_path:', ckpt_path)
  # load_checkpoint returns a CheckpointReader (truthy) when the checkpoint exists.
  latest_ckpt = tf.train.load_checkpoint(ckpt_path)
  print('latest_ckpt:', latest_ckpt)
  if latest_ckpt:
    print('---------- Loaded the checkpoint ----------')
    model = load_model(model, ckpt_path, session, name)
  else:
    start_time = time.time()
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    utils.print_out("  created %s model with fresh parameters, time %.2fs" %
                    (name, time.time() - start_time))

  global_step = model.global_step.eval(session=session)
  return model, global_step
Example #15
    def train(self, sources, annotations):
        """Trains the reformulator with the given sources.

    Args:
      sources: A list of strings representing the questions.
      annotations: A list of strings representing the document ids.

    Returns:
      Training loss.
      Rewards.
      Rewrites.
    """
        tokenized_sources = self.tokenize(sources, prefix=self.source_prefix)

        iterator_feed_dict = {
            self.train_model.src_placeholder: tokenized_sources,
            self.train_model.tgt_placeholder: tokenized_sources,  # Unused.
            self.train_model.annot_placeholder: annotations,
            self.train_model.skip_count_placeholder: 0
        }

        self.sess.run(self.train_model.iterator.initializer,
                      feed_dict=iterator_feed_dict)
        train_result = self.train_model.model.train(self.sess)
        global_step = self.train_model.model.global_step.eval(
            session=self.sess)
        self.summary_writer.add_summary(train_result[4], global_step)

        # Regularly save a checkpoint.
        if global_step - self.last_save_step >= self.hparams.steps_per_save:
            misc_utils.print_out("Save at step: {}".format(global_step))
            self.last_save_step = global_step
            self.train_model.model.saver.save(self.sess,
                                              self.checkpoint_path,
                                              global_step=global_step)
            if self.trie:
                self.trie.save_to_file(self.trie_save_path +
                                       ".{}.pkl".format(global_step))

        rewrites = [rewrite.decode("utf-8") for rewrite in train_result[10]]

        return train_result[1], train_result[2], rewrites
Example #16
def load_embed_txt(embed_file):
    """Load embed_file into a python dictionary.

  Note: the embed_file should be a Glove/word2vec formatted txt file. Here is
  an example assuming embed_size=5:

  the -0.071549 0.093459 0.023738 -0.090339 0.056123
  to 0.57346 0.5417 -0.23477 -0.3624 0.4037
  and 0.20327 0.47348 0.050877 0.002103 0.060547

  For word2vec format, the first line will be: <num_words> <emb_size>.

  Args:
    embed_file: file path to the embedding file.
  Returns:
    a dictionary that maps word to vector, and the size of embedding dimensions.
  """
    emb_dict = dict()
    emb_size = None

    is_first_line = True
    with codecs.getreader("utf-8")(tf.gfile.GFile(embed_file, "rb")) as f:
        for line in f:
            tokens = line.rstrip().split(" ")
            if is_first_line:
                is_first_line = False
                if len(tokens) == 2:  # header line
                    emb_size = int(tokens[1])
                    continue
            word = tokens[0]
            vec = list(map(float, tokens[1:]))
            emb_dict[word] = vec
            if emb_size:
                if emb_size != len(vec):
                    utils.print_out(
                        "Ignoring %s since embeding size is inconsistent." %
                        word)
                    del emb_dict[word]
            else:
                emb_size = len(vec)
    return emb_dict, emb_size
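A small usage sketch, assuming a hypothetical three-line GloVe-style file:

# "tiny.glove.txt" is a made-up path containing lines such as
#   the -0.071549 0.093459 0.023738 -0.090339 0.056123
emb_dict, emb_size = load_embed_txt("tiny.glove.txt")
print(emb_size)         # 5
print(emb_dict["the"])  # [-0.071549, 0.093459, 0.023738, -0.090339, 0.056123]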
Example #17
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix, inference_indices,
                              tgt_eos, subword_option):
    """Decoding only a specific set of sentences."""
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(tf.gfile.GFile(output_infer,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            _, infer_summary, _, nmt_outputs, _ = model.infer(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(
                    decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                with tf.gfile.GFile(image_file, mode="w") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            trans_f.write("%s\n" % translation)
            utils.print_out(translation + b"\n")
    utils.print_time("  done", start_time)
def _create_pretrained_emb_from_txt(vocab_file,
                                    embed_file,
                                    num_trainable_tokens=3,
                                    dtype=tf.float32,
                                    scope=None):
  """Load pretrain embeding from embed_file, and return an embedding matrix.

  Args:
    embed_file: Path to a Glove formatted embedding txt file.
    num_trainable_tokens: Make the first n tokens in the vocab file as trainable
      variables. Default is 3, which is "<unk>", "<s>" and "</s>".
  """
  vocab, _ = vocab_utils.load_vocab(vocab_file)
  trainable_tokens = vocab[:num_trainable_tokens]

  utils.print_out("# Using pretrained embedding: %s." % embed_file)
  utils.print_out("  with trainable tokens: ")

  emb_dict, emb_size = vocab_utils.load_embed_txt(embed_file)
  for token in trainable_tokens:
    utils.print_out("    %s" % token)
    if token not in emb_dict:
      emb_dict[token] = [0.0] * emb_size

  emb_mat = np.array(
      [emb_dict[token] for token in vocab], dtype=dtype.as_numpy_dtype())
  emb_mat = tf.constant(emb_mat)
  emb_mat_const = tf.slice(emb_mat, [num_trainable_tokens, 0], [-1, -1])
  with tf.variable_scope(scope or "pretrain_embeddings", dtype=dtype) as scope:
    with tf.device(_get_embed_device(num_trainable_tokens)):
      emb_mat_var = tf.get_variable("emb_mat_var",
                                    [num_trainable_tokens, emb_size])
  return tf.concat([emb_mat_var, emb_mat_const], 0)
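Shape-wise, the concatenation above produces the full [vocab_size, emb_size] embedding matrix, whose first num_trainable_tokens rows are variables while the remaining rows stay frozen; with hypothetical sizes:

# Illustrative shapes only (vocab_size=10000, emb_size=300, num_trainable_tokens=3):
#   emb_mat_var    : [3, 300]     trainable rows for "<unk>", "<s>", "</s>"
#   emb_mat_const  : [9997, 300]  frozen pretrained vectors
#   tf.concat(...) : [10000, 300] matrix actually used for lookups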
Example #19
def ensure_compatible_hparams(hparams, default_hparams, hparams_path):
  """Make sure the loaded hparams is compatible with new changes."""
  default_hparams = utils.maybe_parse_standard_hparams(default_hparams,
                                                       hparams_path)

  # For compatibility: if there are new fields in default_hparams,
  #   add them to the current hparams.
  default_config = default_hparams.values()
  config = hparams.values()
  for key in default_config:
    if key not in config:
      hparams.add_hparam(key, default_config[key])

  # Update all hparams' keys if override_loaded_hparams=True
  if default_hparams.override_loaded_hparams:
    for key in default_config:
      if getattr(hparams, key) != default_config[key]:
        utils.print_out(
            "# Updating hparams.%s: %s -> %s" % (key, str(
                getattr(hparams, key)), str(default_config[key])))
        setattr(hparams, key, default_config[key])
  return hparams
def before_train(loaded_train_model, train_model, train_sess, global_step,
                 hparams, log_f):
    """Misc tasks to do before training."""
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "learning_rate":
        loaded_train_model.learning_rate.eval(session=train_sess)
    }
    start_train_time = time.time()
    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, info["learning_rate"], time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    return stats, info, start_train_time
Example #21
    def _get_learning_rate_warmup(self, hparams):
        """Get learning rate warmup."""
        warmup_steps = hparams.warmup_steps
        warmup_scheme = hparams.warmup_scheme
        utils.print_out(
            '  learning_rate={}, warmup_steps={}, warmup_scheme={}'.format(
                hparams.learning_rate, warmup_steps, warmup_scheme))

        # Apply inverse decay if global steps less than warmup steps.
        # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
        # When step < warmup_steps,
        #   learning_rate *= warmup_factor ** (warmup_steps - step)
        if warmup_scheme == 't2t':
            # 0.01^(1/warmup_steps): we start with a learning rate 100x smaller.
            warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
            inv_decay = warmup_factor**(tf.to_float(warmup_steps -
                                                    self.global_step))
        else:
            raise ValueError('Unknown warmup scheme {}'.format(warmup_scheme))

        return tf.cond(self.global_step < hparams.warmup_steps,
                       lambda: inv_decay * self.learning_rate,
                       lambda: self.learning_rate,
                       name='learning_rate_warump_cond')
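Under the 't2t' scheme the warmup factor is chosen so the learning rate starts 100x smaller and reaches its configured value exactly at warmup_steps; with hypothetical numbers:

import math

# Hypothetical warmup_steps and learning_rate, only to illustrate the schedule.
warmup_steps = 1000
learning_rate = 1.0
warmup_factor = math.exp(math.log(0.01) / warmup_steps)  # ~= 0.99540
for step in (0, 500, 1000):
    lr = learning_rate * warmup_factor ** (warmup_steps - step)
    print(step, lr)  # 0 -> 0.01, 500 -> 0.1, 1000 -> 1.0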
Example #22
def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""):
  """Run main."""
  # Job
  jobid = flags.jobid
  num_workers = flags.num_workers
  utils.print_out("# Job id %d" % jobid)

  # Random
  random_seed = flags.random_seed
  if random_seed is not None and random_seed > 0:
    utils.print_out("# Set random seed to %d" % random_seed)
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)

  ## Train / Decode
  out_dir = flags.out_dir
  if not tf.gfile.Exists(out_dir):
    tf.gfile.MakeDirs(out_dir)

  # Load hparams.
  hparams = create_or_load_hparams(
      out_dir, default_hparams, flags.hparams_path, save_hparams=(jobid == 0))

  if flags.inference_input_file:
    # Inference indices
    hparams.inference_indices = None
    if flags.inference_list:
      (hparams.inference_indices) = ([
          int(token) for token in flags.inference_list.split(",")
      ])

    # Inference
    trans_file = flags.inference_output_file
    ckpt = flags.ckpt
    if not ckpt:
      ckpt = tf.train.latest_checkpoint(out_dir)
    inference_fn(ckpt, flags.inference_input_file, trans_file, hparams,
                 num_workers, jobid)

    # Evaluation
    ref_file = flags.inference_ref_file
    if ref_file and tf.gfile.Exists(trans_file):
      for metric in hparams.metrics:
        score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                          hparams.subword_option)
        utils.print_out("  %s: %.1f" % (metric, score))
  else:
    # Train
    train_fn(hparams, target_session=target_session)
Example #23
    def _build_encoder(self, hparams):
        """Build an encoder."""
        num_layers = self.num_encoder_layers
        num_residual_layers = self.num_encoder_residual_layers

        iterator = self.iterator

        source = iterator.source
        # Make shape [max_time, batch_size].
        source = tf.transpose(source)

        with tf.variable_scope('encoder') as scope:
            dtype = scope.dtype
            # Look up embedding, emp_inp: [max_time, batch_size, num_units]
            encoder_emb_inp = tf.nn.embedding_lookup(self.embedding_encoder,
                                                     source)

            # encoder_outputs: [max_time, batch_size, num_units]
            if hparams.encoder_type == 'uni':
                utils.print_out(
                    '  num_layers={}, num_residual_layers={}'.format(
                        num_layers, num_residual_layers))
                cell = self._build_encoder_cell(hparams, num_layers,
                                                num_residual_layers)

                encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                    cell,
                    encoder_emb_inp,
                    dtype=dtype,
                    sequence_length=iterator.source_sequence_length,
                    time_major=True,
                    swap_memory=True)
            elif hparams.encoder_type == 'bi':
                num_bi_layers = int(num_layers / 2)
                num_bi_residual_layers = int(num_residual_layers / 2)
                utils.print_out(
                    '  num_bi_layers={}, num_bi_residual_layers={}'.format(
                        num_bi_layers, num_bi_residual_layers))

                encoder_outputs, bi_encoder_state = (
                    ### fw, bw bidirectional (_build_encoder_cell)
                    self._build_bidirectional_rnn(
                        inputs=encoder_emb_inp,
                        sequence_length=iterator.source_sequence_length,
                        dtype=dtype,
                        hparams=hparams,
                        num_bi_layers=num_bi_layers,
                        num_bi_residual_layers=num_bi_residual_layers))

                if num_bi_layers == 1:
                    encoder_state = bi_encoder_state
                else:
                    # alternatively concat forward and backward states
                    encoder_state = []
                    for layer_id in range(num_bi_layers):
                        encoder_state.append(
                            bi_encoder_state[0][layer_id])  # forward
                        encoder_state.append(
                            bi_encoder_state[1][layer_id])  # backward
                    encoder_state = tuple(encoder_state)
            else:
                raise ValueError('Unknown encoder_type {}'.format(
                    hparams.encoder_type))
        return encoder_outputs, encoder_state
Example #24
    def __init__(self,
                 hparams,
                 mode,
                 iterator,
                 source_vocab_table,
                 target_vocab_table,
                 reverse_target_vocab_table=None,
                 scope=None,
                 extra_args=None,
                 trie=None):
        """Create the model.

    Args:
      hparams: Hyperparameter configurations.
      mode: TRAIN | EVAL | INFER
      iterator: Dataset Iterator that feeds data.
      source_vocab_table: Lookup table mapping source words to ids.
      target_vocab_table: Lookup table mapping target words to ids.
      reverse_target_vocab_table: Lookup table mapping ids to target words. Only
        required in INFER mode. Defaults to None.
      scope: scope of the model.
      extra_args: model_helper.ExtraArgs, for passing customizable functions.
      trie: pygtrie.Trie to decode into

    """
        assert isinstance(iterator, iterator_utils.BatchedInput)
        self.iterator = iterator
        self.mode = mode
        self.src_vocab_table = source_vocab_table
        self.tgt_vocab_table = target_vocab_table

        self.src_vocab_size = hparams.src_vocab_size
        self.tgt_vocab_size = hparams.tgt_vocab_size
        self.num_gpus = hparams.num_gpus
        self.reverse_target_vocab_table = reverse_target_vocab_table
        # extra_args: to make it flexible for adding external customizable code
        self.single_cell_fn = None
        if extra_args:
            self.single_cell_fn = extra_args.single_cell_fn

        # Set num layers
        self.num_encoder_layers = hparams.num_encoder_layers
        self.num_decoder_layers = hparams.num_decoder_layers
        assert self.num_encoder_layers
        assert self.num_decoder_layers

        self.trie = trie

        # Set num residual layers
        if hasattr(hparams,
                   'num_residual_layers'):  # for compatibility with common_test_utils
            self.num_encoder_residual_layers = hparams.num_residual_layers
            self.num_decoder_residual_layers = hparams.num_residual_layers
        else:
            self.num_encoder_residual_layers = hparams.num_encoder_residual_layers
            self.num_decoder_residual_layers = hparams.num_decoder_residual_layers

        # Initializer
        initializer = model_helper.get_initializer(hparams.init_op,
                                                   hparams.random_seed,
                                                   hparams.init_weight)
        tf.get_variable_scope().set_initializer(initializer)

        # Embeddings
        self.init_embeddings(hparams, scope)
        self.batch_size = tf.size(self.iterator.source_sequence_length)

        # Projection
        with tf.variable_scope(scope or 'build_network', reuse=tf.AUTO_REUSE):
            with tf.variable_scope('decoder/output_projection'):
                self.output_layer = layers_core.Dense(hparams.tgt_vocab_size,
                                                      use_bias=False,
                                                      name='output_projection')

        if hparams.use_rl:
            # Create environment function
            self._environment_reward_fn = (
                environment_client.make_environment_reward_fn(
                    hparams.environment_server, mode=hparams.environment_mode))

        ## Train graph
        res = self.build_graph(hparams, scope=scope)

        (self.loss, self.rewards, self.logits, self.final_context_state,
         self.sample_id, self.sample_words, self.sample_strings,
         train_summaries) = res
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.word_count = tf.reduce_sum(
                self.iterator.source_sequence_length) + tf.reduce_sum(
                    self.iterator.target_sequence_length)

        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            ## Count the number of predicted words for compute ppl.
            self.predict_count = tf.reduce_sum(
                self.iterator.target_sequence_length)

        self.global_step = tf.Variable(0, trainable=False)
        params = tf.trainable_variables()

        # Gradients and SGD update operation for training the model.
        # Arrange for the embedding vars to appear at the beginning.
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.learning_rate = tf.constant(hparams.learning_rate)
            # warm-up
            self.learning_rate = self._get_learning_rate_warmup(hparams)
            # decay
            self.learning_rate = self._get_learning_rate_decay(hparams)

            # Optimizer
            if hparams.optimizer == 'sgd':
                opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            elif hparams.optimizer == 'adam':
                opt = tf.train.AdamOptimizer(self.learning_rate)
            elif hparams.optimizer == 'rmsprop':
                opt = tf.train.RMSPropOptimizer(self.learning_rate)
            elif hparams.optimizer == 'adadelta':
                opt = tf.train.AdadeltaOptimizer(self.learning_rate)

            # Gradients
            ## http://blog.naver.com/PostView.nhn?blogId=atelierjpro&logNo=220978930368&categoryNo=0&parentCategoryNo=0&viewDate=&currentPage=1&postListTopCurrentPage=1&from=postView
            gradients = tf.gradients(self.loss,
                                     params,
                                     colocate_gradients_with_ops=hparams.
                                     colocate_gradients_with_ops)

            (clipped_gradients, gradients_norm_summary,
             gradients_norm) = model_helper.gradient_clip(
                 gradients, max_gradient_norm=hparams.max_gradient_norm)
            self.gradients_norm = gradients_norm

            self.update = opt.apply_gradients(zip(clipped_gradients, params),
                                              global_step=self.global_step)

            # Summary
            self.train_summary = tf.summary.merge([
                tf.summary.scalar('lr', self.learning_rate),
                tf.summary.scalar('train_loss', self.loss),
            ] + train_summaries + gradients_norm_summary)

        if self.mode == tf.contrib.learn.ModeKeys.INFER:
            self.infer_summary = self._get_infer_summary(hparams)

        # Saver
        self.saver = OptimisticRestoreSaver(max_to_keep=hparams.num_keep_ckpts,
                                            init_uninitialized_variables=True)

        # Print trainable variables
        utils.print_out('# Trainable variables')
        for param in params:
            utils.print_out('  {}, {}, {}'.format(param.name,
                                                  param.get_shape(),
                                                  param.op.device))
Example #25
    def build_graph(self, hparams, scope=None):
        """Subclass must implement this method.

    Creates a sequence-to-sequence model with dynamic RNN decoder API.
    Args:
      hparams: Hyperparameter configurations.
      scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

    Returns:
      A tuple of the form (logits, loss, final_context_state),
      where:
        logits: float32 Tensor [batch_size x num_decoder_symbols].
        loss: the total loss / batch_size.
        final_context_state: The final state of decoder RNN.

    Raises:
      ValueError: if encoder_type differs from uni and bi, or
        attention_option is not (luong | scaled_luong |
        bahdanau | normed_bahdanau).
    """
        utils.print_out('# creating {} graph ...'.format(self.mode))
        dtype = tf.float32
        train_summaries = []

        with tf.variable_scope(scope or 'dynamic_seq2seq',
                               dtype=dtype,
                               reuse=tf.AUTO_REUSE):
            ## Context
            if hparams.ctx is not None:
                vector_size = hparams.num_units
                if (hparams.encoder_type == 'bi'
                        and hparams.context_feed == 'encoder_output'):
                    vector_size *= 2

                ### Create the context vector
                context_vector = context_encoder.get_context_vector(
                    self.mode, self.iterator, hparams, vector_size=vector_size)

            ### Build the encoder
            encoder_outputs, encoder_state = self._build_encoder(hparams)

            ## Feed the Context to encoder_output, encoder_states
            if hparams.ctx is not None:
                encoder_outputs, encoder_state = context_encoder.feed(
                    context_vector, encoder_outputs, encoder_state, hparams)

            ## Build the decoder
            logits, sample_ids, final_context_state = self._build_decoder(
                encoder_outputs, encoder_state, hparams)

            sample_words = self.reverse_target_vocab_table.lookup(
                tf.to_int64(sample_ids))

            # Make output shape = [batch_size, time] or [beam_width, batch_size, time]
            # when using beam search.
            ## sample id to string
            sample_ids = tf.transpose(sample_ids)
            sample_words = tf.transpose(sample_words)
            sample_strings = tf.py_func(self.tokens_to_strings,
                                        [sample_words, hparams.eos],
                                        (tf.string), 'TokensToStrings')

            if hparams.server_mode:
                rewards = self.iterator.weights
            elif hparams.use_rl:
                # Compute rewards when in TRAIN or EVAL mode
                iterator = self.iterator
                doc_ids = iterator.annotation

                ## Compute the rewards
                rewards, _ = self.compute_rewards(questions=sample_strings,
                                                  doc_ids=doc_ids)

                train_summaries.append(
                    tf.summary.scalar('train_avg_reward',
                                      tf.reduce_mean(rewards)))
            else:
                # tf only accepts tensors as return values
                rewards = tf.constant([1.0])

            ## Loss
            if self.mode != tf.contrib.learn.ModeKeys.INFER:

                with tf.device(
                        model_helper.get_device_str(
                            self.num_encoder_layers - 1, self.num_gpus)):
                    # encoder_outputs.shape = [max_time, batch_size, embeddings_dim]
                    question_embeddings = tf.reduce_mean(encoder_outputs, 0)
                    loss = self._compute_loss(
                        hparams=hparams,
                        logits=logits,
                        sample_ids=sample_ids,
                        sample_words=sample_words,
                        rewards=rewards,
                        question_embeddings=question_embeddings,
                        train_summaries=train_summaries)
            else:
                loss = None

            return (loss, rewards, logits, final_context_state, sample_ids,
                    sample_words, sample_strings, train_summaries)
Example #26
def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        hparams,
                        num_translations_per_input=1,
                        decode=True):
  """Decode a test set and compute a score according to the metrics.

    Args:
      name: name of the set being evaluated.
      model: model
      sess: session
      trans_file: name of the file that the translations will be written to.
      ref_file: ground-truth file to compare against the generated translations.
      metrics: a list of metrics that the model will be evaluated on. Valid
          options are: "f1", "bleu", "rouge", and "accuracy".
      subword_option: either "bpe", "spm", or "".
      beam_width: beam search width.
      tgt_eos: end-of-sentence token of the target translations.
      hparams: parameters object.
      num_translations_per_input: number of translations to be generated per
          input. It is upper-bounded by beam_width.
      decode: if True, generate translations using the model. Otherwise, compute
          metrics using the translations in trans_file.
    Returns:
      A dict mapping each computed metric name to its score.
  """

  all_rewards = []

  # Decode
  if decode:
    utils.print_out("  decoding to output %s." % trans_file)

    start_time = time.time()
    start_time_step = time.time()
    step = 0
    num_sentences = 0
    with tf.gfile.GFile(trans_file, mode="wb") as trans_f:

      num_translations_per_input = max(
          min(num_translations_per_input, beam_width), 1)
      while True:
        try:
          _, _, _, nmt_outputs, rewards = model.infer(sess)

          all_rewards.extend(rewards.flatten().tolist())

          if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

          batch_size = nmt_outputs.shape[1]
          num_sentences += batch_size

          for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
              translation = get_translation(
                  nmt_outputs[beam_id],
                  sent_id,
                  tgt_eos=tgt_eos,
                  subword_option=subword_option)
              trans_f.write((translation + b"\n").decode("utf-8"))

          if step % hparams.steps_per_stats == 0:
            # print_time does not print decimal places for time.
            utils.print_out("  external evaluation, step %d, time %.2f" %
                            (step, time.time() - start_time_step))
          step += 1
          start_time_step = time.time()
        except tf.errors.OutOfRangeError:
          utils.print_time(
              "  done, num sentences %d, num translations per input %d" %
              (num_sentences, num_translations_per_input), start_time)
          break

  # Evaluation
  evaluation_scores = {}

  # We treat F1 scores differently because they don't need ground truth
  # sentences and they are expensive to compute due to environment calls.
  if "f1" in metrics:
    f1_score = np.mean(all_rewards)
    evaluation_scores["f1"] = f1_score
    utils.print_out("  f1 %s: %.1f" % (name, f1_score))

  for metric in metrics:
    if metric != "f1" and ref_file:
      if not tf.gfile.Exists(trans_file):
        raise IOError("%s: translation file not found" % trans_file)
      score = evaluation_utils.evaluate(
          ref_file, trans_file, metric, subword_option=subword_option)
      evaluation_scores[metric] = score
      utils.print_out("  %s %s: %.1f" % (metric, name, score))

  return evaluation_scores
def avg_checkpoints(model_dir, num_last_checkpoints, global_step,
                    global_step_name):
  """Average the last N checkpoints in the model_dir."""
  checkpoint_state = tf.train.get_checkpoint_state(model_dir)
  if not checkpoint_state:
    utils.print_out("# No checkpoint file found in directory: %s" % model_dir)
    return None

  # Checkpoints are ordered from oldest to newest.
  checkpoints = (
      checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])

  if len(checkpoints) < num_last_checkpoints:
    utils.print_out(
        "# Skipping averaging checkpoints because not enough checkpoints is "
        "avaliable.")
    return None

  avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
  if not tf.gfile.Exists(avg_model_dir):
    utils.print_out(
        "# Creating new directory %s for saving averaged checkpoints." %
        avg_model_dir)
    tf.gfile.MakeDirs(avg_model_dir)

  utils.print_out("# Reading and averaging variables in checkpoints:")
  var_list = tf.contrib.framework.list_variables(checkpoints[0])
  var_values, var_dtypes = {}, {}
  for (name, shape) in var_list:
    if name != global_step_name:
      var_values[name] = np.zeros(shape)

  for checkpoint in checkpoints:
    utils.print_out("    %s" % checkpoint)
    reader = tf.contrib.framework.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = reader.get_tensor(name)
      var_dtypes[name] = tensor.dtype
      var_values[name] += tensor

  for name in var_values:
    var_values[name] /= len(checkpoints)

  # Build a graph with same variables in the checkpoints, and save the averaged
  # variables into the avg_model_dir.
  with tf.Graph().as_default():
    tf_vars = [
        tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
        for v in var_values
    ]

    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
    global_step_var = tf.Variable(
        global_step, name=global_step_name, trainable=False)
    saver = tf.train.Saver(tf.all_variables())

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                             six.iteritems(var_values)):
        sess.run(assign_op, {p: value})

      # Use the built saver to save the averaged checkpoint. Only keep 1
      # checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
      saver.save(sess, os.path.join(avg_model_dir, "translate.ckpt"))

  return avg_model_dir
Example #28
def extend_hparams(hparams):
  """Extend training hparams."""
  assert hparams.num_encoder_layers and hparams.num_decoder_layers
  if hparams.num_encoder_layers != hparams.num_decoder_layers:
    hparams.pass_hidden_state = False
    utils.print_out("Num encoder layer %d is different from num decoder layer"
                    " %d, so set pass_hidden_state to False" %
                    (hparams.num_encoder_layers, hparams.num_decoder_layers))

  # Sanity checks
  if hparams.encoder_type == "bi" and hparams.num_encoder_layers % 2 != 0:
    raise ValueError("For bi, num_encoder_layers %d should be even" %
                     hparams.num_encoder_layers)
  if (hparams.attention_architecture in ["gnmt"] and
      hparams.num_encoder_layers < 2):
    raise ValueError(
        "For gnmt attention architecture, "
        "num_encoder_layers %d should be >= 2" % hparams.num_encoder_layers)

  # Set residual layers
  num_encoder_residual_layers = 0
  num_decoder_residual_layers = 0
  if hparams.residual:
    if hparams.num_encoder_layers > 1:
      num_encoder_residual_layers = hparams.num_encoder_layers - 1
    if hparams.num_decoder_layers > 1:
      num_decoder_residual_layers = hparams.num_decoder_layers - 1

    if hparams.encoder_type == "gnmt":
      # The first unidirectional layer (after the bi-directional layer) in
      # the GNMT encoder can't have a residual connection because its input is
      # the concatenation of the fw_cell and bw_cell outputs.
      num_encoder_residual_layers = hparams.num_encoder_layers - 2

      # Compatible for GNMT models
      if hparams.num_encoder_layers == hparams.num_decoder_layers:
        num_decoder_residual_layers = num_encoder_residual_layers
  hparams.add_hparam("num_encoder_residual_layers", num_encoder_residual_layers)
  hparams.add_hparam("num_decoder_residual_layers", num_decoder_residual_layers)

  if hparams.subword_option and hparams.subword_option not in ["spm", "bpe"]:
    raise ValueError("subword option must be either spm, or bpe")

  # Context sanity checks
  # Make sure if one is set, everything should be set
  if hparams.ctx is not None and (not hparams.context_vector or
                                  not hparams.context_feed):
    raise ValueError("If ctx file is provided, "
                     "both context_vector and context_feed have to be set")
  if hparams.ctx is None and (hparams.context_vector or hparams.context_feed):
    raise ValueError("ctx file must be provided when "
                     "context_vector or context_feed is set")
  if ((hparams.context_vector == "append" or hparams.context_feed == "append")
      and (hparams.context_vector != hparams.context_feed)):
    raise ValueError("context_vector and context_feed must both be set to append")
  if (hparams.context_vector == "last_state" and
      hparams.context_feed != "decoder_hidden_state"):
    raise ValueError("context_feed must be set to decoder_hidden_state "
                     "when using last_state as context_vector")
  if (hparams.context_vector == "bilstm_all" and
      hparams.context_feed != "encoder_output"):
    raise ValueError("context_feed must be set to encoder_output "
                     "when using bilstm_all as context_vector")

  # Flags
  utils.print_out("# hparams:")
  utils.print_out("  src=%s" % hparams.src)
  utils.print_out("  tgt=%s" % hparams.tgt)
  utils.print_out("  train_prefix=%s" % hparams.train_prefix)
  utils.print_out("  dev_prefix=%s" % hparams.dev_prefix)
  utils.print_out("  test_prefix=%s" % hparams.test_prefix)
  utils.print_out("  train_annotations=%s" % hparams.train_annotations)
  utils.print_out("  dev_annotations=%s" % hparams.dev_annotations)
  utils.print_out("  test_annotations=%s" % hparams.test_annotations)
  utils.print_out("  out_dir=%s" % hparams.out_dir)

  ## Vocab
  # Get vocab file names first
  if hparams.vocab_prefix:
    src_vocab_file = hparams.vocab_prefix + "." + hparams.src
    tgt_vocab_file = hparams.vocab_prefix + "." + hparams.tgt
  else:
    raise ValueError("hparams.vocab_prefix must be provided.")

  # Source vocab
  src_vocab_size, src_vocab_file = vocab_utils.check_vocab(
      src_vocab_file,
      hparams.out_dir,
      check_special_token=hparams.check_special_token,
      sos=hparams.sos,
      eos=hparams.eos,
      unk=vocab_utils.UNK)

  # Target vocab
  if hparams.share_vocab:
    utils.print_out("  using source vocab for target")
    tgt_vocab_file = src_vocab_file
    tgt_vocab_size = src_vocab_size
  else:
    tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab(
        tgt_vocab_file,
        hparams.out_dir,
        check_special_token=hparams.check_special_token,
        sos=hparams.sos,
        eos=hparams.eos,
        unk=vocab_utils.UNK)
  hparams.add_hparam("src_vocab_size", src_vocab_size)
  hparams.add_hparam("tgt_vocab_size", tgt_vocab_size)
  hparams.add_hparam("src_vocab_file", src_vocab_file)
  hparams.add_hparam("tgt_vocab_file", tgt_vocab_file)

  # Pretrained Embeddings:
  hparams.add_hparam("src_embed_file", "")
  hparams.add_hparam("tgt_embed_file", "")
  if hparams.embed_prefix:
    src_embed_file = hparams.embed_prefix + "." + hparams.src
    tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt

    if tf.gfile.Exists(src_embed_file):
      hparams.src_embed_file = src_embed_file

    if tf.gfile.Exists(tgt_embed_file):
      hparams.tgt_embed_file = tgt_embed_file

  # Check out_dir
  if not tf.gfile.Exists(hparams.out_dir):
    utils.print_out("# Creating output directory %s ..." % hparams.out_dir)
    tf.gfile.MakeDirs(hparams.out_dir)

  # Evaluation
  for metric in hparams.metrics:
    hparams.add_hparam("best_" + metric, 0)  # larger is better
    best_metric_dir = os.path.join(hparams.out_dir, "best_" + metric)
    hparams.add_hparam("best_" + metric + "_dir", best_metric_dir)
    tf.gfile.MakeDirs(best_metric_dir)

    if hparams.avg_ckpts:
      hparams.add_hparam("avg_best_" + metric, 0)  # larger is better
      best_metric_dir = os.path.join(hparams.out_dir, "avg_best_" + metric)
      hparams.add_hparam("avg_best_" + metric + "_dir", best_metric_dir)
      tf.gfile.MakeDirs(best_metric_dir)

  return hparams
def _single_cell(unit_type,
                 num_units,
                 forget_bias,
                 dropout,
                 mode,
                 residual_connection=False,
                 device_str=None,
                 residual_fn=None):
  """Create an instance of a single RNN cell."""
  # dropout (= 1 - keep_prob) is set to 0 during eval and infer
  dropout = dropout if mode == tf.contrib.learn.ModeKeys.TRAIN else 0.0

  # Cell Type
  if unit_type == "lstm":
    utils.print_out("  LSTM, forget_bias=%g" % forget_bias, new_line=False)
    single_cell = tf.contrib.rnn.BasicLSTMCell(
        num_units, forget_bias=forget_bias)
  elif unit_type == "gru":
    utils.print_out("  GRU", new_line=False)
    single_cell = tf.contrib.rnn.GRUCell(num_units)
  elif unit_type == "layer_norm_lstm":
    utils.print_out(
        "  Layer Normalized LSTM, forget_bias=%g" % forget_bias, new_line=False)
    single_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
        num_units, forget_bias=forget_bias, layer_norm=True)
  elif unit_type == "nas":
    utils.print_out("  NASCell", new_line=False)
    single_cell = tf.contrib.rnn.NASCell(num_units)
  else:
    raise ValueError("Unknown unit type %s!" % unit_type)

  # Dropout (= 1 - keep_prob)
  if dropout > 0.0:
    single_cell = tf.contrib.rnn.DropoutWrapper(
        cell=single_cell, input_keep_prob=(1.0 - dropout))
    utils.print_out(
        "  %s, dropout=%g " % (type(single_cell).__name__, dropout),
        new_line=False)

  # Residual
  if residual_connection:
    single_cell = tf.contrib.rnn.ResidualWrapper(
        single_cell, residual_fn=residual_fn)
    utils.print_out("  %s" % type(single_cell).__name__, new_line=False)

  # Device Wrapper
  if device_str:
    single_cell = tf.contrib.rnn.DeviceWrapper(single_cell, device_str)
    utils.print_out(
        "  %s, device=%s" % (type(single_cell).__name__, device_str),
        new_line=False)

  return single_cell
def create_emb_for_encoder_and_decoder(share_vocab,
                                       src_vocab_size,
                                       tgt_vocab_size,
                                       src_embed_size,
                                       tgt_embed_size,
                                       dtype=tf.float32,
                                       num_partitions=0,
                                       src_vocab_file=None,
                                       tgt_vocab_file=None,
                                       src_embed_file=None,
                                       tgt_embed_file=None,
                                       scope=None):
  """Create embedding matrix for both encoder and decoder.

  Args:
    share_vocab: A boolean. Whether to share embedding matrix for both
      encoder and decoder.
    src_vocab_size: An integer. The source vocab size.
    tgt_vocab_size: An integer. The target vocab size.
    src_embed_size: An integer. The embedding dimension for the encoder's
      embedding.
    tgt_embed_size: An integer. The embedding dimension for the decoder's
      embedding.
    dtype: dtype of the embedding matrix. Default to float32.
    num_partitions: number of partitions used for the embedding vars.
    scope: VariableScope for the created subgraph. Default to "embedding".

  Returns:
    embedding_encoder: Encoder's embedding matrix.
    embedding_decoder: Decoder's embedding matrix.

  Raises:
    ValueError: if use share_vocab but source and target have different vocab
      size.
  """

  if num_partitions <= 1:
    partitioner = None
  else:
    # Note: num_partitions > 1 is required for distributed training because
    # embedding_lookup tries to colocate a single-partition embedding variable
    # with the lookup ops, which may cause embedding variables to be placed on
    # worker jobs.
    partitioner = tf.fixed_size_partitioner(num_partitions)

  if (src_embed_file or tgt_embed_file) and partitioner:
    raise ValueError(
        "Can't set num_partitions > 1 when using pretrained embedding")

  with tf.variable_scope(
      scope or "embeddings",
      dtype=dtype,
      partitioner=partitioner,
      reuse=tf.AUTO_REUSE) as scope:
    # Share embedding
    if share_vocab:
      if src_vocab_size != tgt_vocab_size:
        raise ValueError("Share embedding but different src/tgt vocab sizes"
                         " %d vs. %d" % (src_vocab_size, tgt_vocab_size))
      assert src_embed_size == tgt_embed_size
      utils.print_out("# Use the same embedding for source and target")
      vocab_file = src_vocab_file or tgt_vocab_file
      embed_file = src_embed_file or tgt_embed_file

      embedding_encoder = _create_or_load_embed("embedding_share", vocab_file,
                                                embed_file, src_vocab_size,
                                                src_embed_size, dtype)
      embedding_decoder = embedding_encoder
    else:
      with tf.variable_scope(
          "encoder", partitioner=partitioner, reuse=tf.AUTO_REUSE):
        embedding_encoder = _create_or_load_embed(
            "embedding_encoder", src_vocab_file, src_embed_file, src_vocab_size,
            src_embed_size, dtype)

      with tf.variable_scope(
          "decoder", partitioner=partitioner, reuse=tf.AUTO_REUSE):
        embedding_decoder = _create_or_load_embed(
            "embedding_decoder", tgt_vocab_file, tgt_embed_file, tgt_vocab_size,
            tgt_embed_size, dtype)

  return embedding_encoder, embedding_decoder