Code Example #1
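
All of these snippets target TensorFlow 1.x (they rely on tf.contrib and placeholders) and appear to come from the same NMT codebase. A minimal sketch of the module-level imports they assume; the iterator_utils/vocab_utils import paths are assumptions, since only the module names appear in the snippets:

# Assumed imports for the snippets below (TF 1.x only).
import tensorflow as tf
from tensorflow.python.ops import lookup_ops

# Project-local modules from the surrounding NMT codebase; paths assumed.
from utils import iterator_utils
from utils import vocab_utils
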
def create_test_iterator(hparams, mode, trie_excludes=None):
    """Create test iterator."""
    src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant([hparams.eos, "a", "b", "c", "d"]))
    tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
    tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)

    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
        tgt_vocab_mapping)

    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["a a b b c", "a b b"]))

    ctx_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c b c b a", "b c b a"]))

    trie_excludes = trie_excludes or []
    trie_excludes = " {} ".format(hparams.eos).join(trie_excludes)
    tex_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant([trie_excludes, trie_excludes]))

    if mode != tf.contrib.learn.ModeKeys.INFER:
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["a b c b c", "a b c b"]))
        return (iterator_utils.get_iterator(hparams=hparams,
                                            src_dataset=src_dataset,
                                            tgt_dataset=tgt_dataset,
                                            ctx_dataset=ctx_dataset,
                                            annot_dataset=None,
                                            src_vocab_table=src_vocab_table,
                                            tgt_vocab_table=tgt_vocab_table,
                                            batch_size=hparams.batch_size,
                                            sos=hparams.sos,
                                            eos=hparams.eos,
                                            random_seed=hparams.random_seed,
                                            num_buckets=hparams.num_buckets),
                src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
    else:
        return (iterator_utils.get_infer_iterator(
            hparams=hparams,
            src_dataset=src_dataset,
            ctx_dataset=ctx_dataset,
            annot_dataset=None,
            trie_exclude_dataset=tex_dataset,
            src_vocab_table=src_vocab_table,
            tgt_vocab_table=tgt_vocab_table,
            eos=hparams.eos,
            batch_size=hparams.batch_size),
                src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
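
A usage sketch for the INFER branch; the HParams fields are inferred from the attributes the function reads (plus context_feed/server_mode, which the test HParams further down this page also set), so the exact field list is an assumption:

# Hypothetical driver for create_test_iterator (INFER mode).
hparams = tf.contrib.training.HParams(
    sos="sos", eos="eos", batch_size=2, random_seed=1, num_buckets=1,
    context_feed="", server_mode=False)
iterator, src_table, tgt_table, reverse_tgt_table = create_test_iterator(
    hparams, mode=tf.contrib.learn.ModeKeys.INFER, trie_excludes=["c b"])
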
Code Example #2
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       graph=None,
                       extra_args=None,
                       trie=None,
                       use_placeholders=False):
  """Create train graph, model, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    annot_placeholder = None
    src_placeholder = None
    tgt_placeholder = None
    annot_dataset = None
    ctx_dataset = None
    if use_placeholders:
      src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)

      tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_placeholder)

      if hparams.use_rl:
        annot_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        annot_dataset = tf.data.Dataset.from_tensor_slices(annot_placeholder)
    else:
      src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
      tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
      ctx_file = None
      if hparams.ctx is not None:
        ctx_file = "%s.%s" % (hparams.train_prefix, hparams.ctx)

      src_dataset = tf.data.TextLineDataset(src_file)
      tgt_dataset = tf.data.TextLineDataset(tgt_file)

      if hparams.train_annotations is not None:
        annot_dataset = tf.data.TextLineDataset(hparams.train_annotations)

      if ctx_file is not None:
        ctx_dataset = tf.data.TextLineDataset(ctx_file)

    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        ctx_dataset=ctx_dataset,
        annot_dataset=annot_dataset,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          reverse_target_vocab_table=reverse_tgt_vocab_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModel(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      annot_placeholder=annot_placeholder,
      skip_count_placeholder=skip_count_placeholder)
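
A usage sketch, assuming an NMTModel class compatible with the model_creator signature used above (the class name is an assumption):

# Build the train model, then initialize tables and the iterator.
train_model = create_train_model(NMTModel, hparams)
with tf.Session(graph=train_model.graph) as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(tf.tables_initializer())
  sess.run(train_model.iterator.initializer,
           feed_dict={train_model.skip_count_placeholder: 0})
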
Code Example #3
def create_eval_model(model_creator,
                      hparams,
                      scope=None,
                      graph=None,
                      extra_args=None,
                      trie=None):
  """Create train graph, model, src/tgt file holders, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "eval"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)
    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

    ctx_file_placeholder = None
    if hparams.ctx is not None:
      ctx_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

    annot_file_placeholder = None
    if hparams.dev_annotations is not None:
      annot_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)

    ctx_dataset = None
    if ctx_file_placeholder is not None:
      ctx_dataset = tf.data.TextLineDataset(ctx_file_placeholder)

    annot_dataset = None
    if annot_file_placeholder is not None:
      annot_dataset = tf.data.TextLineDataset(annot_file_placeholder)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        ctx_dataset=ctx_dataset,
        annot_dataset=annot_dataset,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer)
    model = model_creator(
        hparams=hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        source_vocab_table=src_vocab_table,
        target_vocab_table=tgt_vocab_table,
        reverse_target_vocab_table=reverse_tgt_vocab_table,
        scope=scope,
        extra_args=extra_args,
        trie=trie)

  return EvalModel(
      graph=graph,
      model=model,
      src_file_placeholder=src_file_placeholder,
      tgt_file_placeholder=tgt_file_placeholder,
      ctx_file_placeholder=ctx_file_placeholder,
      annot_file_placeholder=annot_file_placeholder,
      iterator=iterator)
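
A usage sketch: the eval iterator is re-initialized for each evaluation run by feeding dev file paths into the scalar string placeholders (NMTModel and the file names are illustrative assumptions):

# Point the eval iterator at a dev set.
eval_model = create_eval_model(NMTModel, hparams)
with tf.Session(graph=eval_model.graph) as sess:
  sess.run(tf.tables_initializer())
  feed = {eval_model.src_file_placeholder: "dev.src",
          eval_model.tgt_file_placeholder: "dev.tgt"}
  # ctx_file_placeholder / annot_file_placeholder must also be fed
  # when hparams.ctx / hparams.dev_annotations are set.
  sess.run(eval_model.iterator.initializer, feed_dict=feed)
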
Code Example #4
def create_train_model_for_server(model_creator,
                                  hparams,
                                  scope=None,
                                  num_workers=1,
                                  jobid=0,
                                  graph=None,
                                  extra_args=None,
                                  trie=None):
  """Create graph, model, and iterator when running the NMT in server mode.

  This is different from the standard training model, because the input arrives
  via RPC and thus has to be fed using placeholders."""
  assert hparams.num_buckets == 1, "No bucketing when in server mode."

  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)

    tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_placeholder)

    wgt_placeholder = tf.placeholder(shape=[None], dtype=tf.float32)
    wgt_dataset = tf.data.Dataset.from_tensor_slices(wgt_placeholder)

    ctx_placeholder = None
    if hparams.ctx is not None:
      ctx_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    ctx_dataset = None
    if ctx_placeholder is not None:
      ctx_dataset = tf.data.Dataset.from_tensor_slices(ctx_placeholder)

    batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        wgt_dataset=wgt_dataset,
        ctx_dataset=ctx_dataset,
        annot_dataset=None,
        batch_size=batch_size_placeholder,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          reverse_target_vocab_table=reverse_tgt_vocab_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModelForServer(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      wgt_placeholder=wgt_placeholder,
      batch_size_placeholder=batch_size_placeholder,
      skip_count_placeholder=skip_count_placeholder)
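
A usage sketch of the server-mode feeding contract: every input arrives through placeholders, including per-example weights and a dynamic batch size (NMTModel and the toy strings are assumptions):

# Feed one RPC-style batch into the server-mode training graph.
server_model = create_train_model_for_server(NMTModel, hparams)
with tf.Session(graph=server_model.graph) as sess:
  sess.run(tf.tables_initializer())
  sess.run(server_model.iterator.initializer,
           feed_dict={
               server_model.src_placeholder: ["c a", "b b c"],
               server_model.tgt_placeholder: ["b c", "a b"],
               server_model.wgt_placeholder: [1.0, 1.0],
               server_model.batch_size_placeholder: 2,
               server_model.skip_count_placeholder: 0,
           })
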
Code Example #5
  def testGetIteratorWithSkipCount(self):
    tf.set_random_seed(1)
    tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c a", "c c a", "d", "f e a g"]))
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["b c", "a b", "", "c c"]))
    hparams = tf.contrib.training.HParams(
        random_seed=1,
        num_buckets=5,
        eos="eos",
        sos="sos",
        context_feed="",
        server_mode=False)
    batch_size = 2
    src_max_len = 3
    skip_count = tf.placeholder(shape=(), dtype=tf.int64)
    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=src_max_len,
        skip_count=skip_count,
        reshuffle_each_iteration=False)
    table_initializer = tf.tables_initializer()
    source = iterator.source
    target_input = iterator.target_input
    target_output = iterator.target_output
    src_seq_len = iterator.source_sequence_length
    tgt_seq_len = iterator.target_sequence_length
    self.assertEqual([None, None], source.shape.as_list())
    self.assertEqual([None, None], target_input.shape.as_list())
    self.assertEqual([None, None], target_output.shape.as_list())
    self.assertEqual([None], src_seq_len.shape.as_list())
    self.assertEqual([None], tgt_seq_len.shape.as_list())
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer, feed_dict={skip_count: 3})

      (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
          sess.run((source, src_seq_len, target_input, target_output,
                    tgt_seq_len)))
      self.assertAllEqual(
          [[-1, -1, 0]],  # "f" == unknown, "e" == unknown, a
          source_v)
      self.assertAllEqual([3], src_len_v)
      self.assertAllEqual(
          [[4, 2, 2]],  # sos c c
          target_input_v)
      self.assertAllEqual(
          [[2, 2, 3]],  # c c eos
          target_output_v)
      self.assertAllEqual([3], tgt_len_v)

      with self.assertRaisesOpError("End of sequence"):
        sess.run(source)

      # Re-init iterator with skip_count=0.
      sess.run(iterator.initializer, feed_dict={skip_count: 0})

      (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
          sess.run((source, src_seq_len, target_input, target_output,
                    tgt_seq_len)))
      self.assertAllEqual(
          [
              [2, 0, 3],  # c a eos -- eos is padding
              [-1, -1, 0]  # "f" == unknown, "e" == unknown, a
          ],
          source_v)
      self.assertAllEqual([2, 3], src_len_v)
      self.assertAllEqual(
          [
              [4, 1, 2],  # sos b c
              [4, 2, 2],  # sos c c
          ],
          target_input_v)
      self.assertAllEqual(
          [
              [1, 2, 3],  # b c eos
              [2, 2, 3],  # c c eos
          ],
          target_output_v)
      self.assertAllEqual([3, 3], tgt_len_v)

      (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
          sess.run((source, src_seq_len, target_input, target_output,
                    tgt_seq_len)))
      self.assertAllEqual(
          [[2, 2, 0]],  # c c a
          source_v)
      self.assertAllEqual([3], src_len_v)
      self.assertAllEqual(
          [[4, 0, 1]],  # sos a b
          target_input_v)
      self.assertAllEqual(
          [[0, 1, 3]],  # a b eos
          target_output_v)
      self.assertAllEqual([3], tgt_len_v)

      with self.assertRaisesOpError("End of sequence"):
        sess.run(source)
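
For reference, the -1 ids asserted in this test are the lookup table's out-of-vocabulary default; a minimal illustration:

# index_table_from_tensor assigns ids in vocabulary order ("a"=0, "b"=1,
# "c"=2, "eos"=3, "sos"=4) and maps OOV tokens to default_value=-1.
table = lookup_ops.index_table_from_tensor(
    tf.constant(["a", "b", "c", "eos", "sos"]))
ids = table.lookup(tf.constant(["c", "a", "f"]))
# evaluates to [2, 0, -1] once tables are initialized in a session
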
Code Example #6
  def testGetIteratorWithAnnotations(self):
    tf.set_random_seed(1)
    tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["f e a g", "c c a", "d", "c a"]))
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c", "a b", "", "b c"]))
    annot_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["1\t1", "2\t1", "3\t1", "4\t1"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        num_buckets=5,
        eos="eos",
        sos="sos",
        context_feed="",
        server_mode=False)
    batch_size = 2
    src_max_len = 3
    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        annot_dataset=annot_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=src_max_len,
        reshuffle_each_iteration=False)
    table_initializer = tf.tables_initializer()
    source = iterator.source
    target_input = iterator.target_input
    target_output = iterator.target_output
    src_seq_len = iterator.source_sequence_length
    tgt_seq_len = iterator.target_sequence_length
    annotation = iterator.annotation
    self.assertEqual([None, None], source.shape.as_list())
    self.assertEqual([None, None], target_input.shape.as_list())
    self.assertEqual([None, None], target_output.shape.as_list())
    self.assertEqual([None], src_seq_len.shape.as_list())
    self.assertEqual([None], tgt_seq_len.shape.as_list())
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)

      # First batch.
      (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v,
       annot_v) = (
           sess.run((source, src_seq_len, target_input, target_output,
                     tgt_seq_len, annotation)))
      self.assertAllEqual(
          [
              [2, 0, 3],  # c a eos
              [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
          ],
          source_v)
      self.assertAllEqual([2, 3], src_len_v)
      self.assertAllEqual(
          [
              [4, 1, 2],  # sos b c
              [4, 2, 2],  # sos c c
          ],
          target_input_v)
      self.assertAllEqual(
          [
              [1, 2, 3],  # b c eos
              [2, 2, 3],  # c c eos
          ],
          target_output_v)
      self.assertAllEqual([3, 3], tgt_len_v)
      self.assertAllEqual(["4", "1"], annot_v)

      # Second batch.
      (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v,
       annot_v) = (
           sess.run((source, src_seq_len, target_input, target_output,
                     tgt_seq_len, annotation)))
      self.assertAllEqual(
          [[2, 2, 0]],  # c c a
          source_v)
      self.assertAllEqual([3], src_len_v)
      self.assertAllEqual(
          [[4, 0, 1]],  # sos a b
          target_input_v)
      self.assertAllEqual(
          [[0, 1, 3]],  # a b eos
          target_output_v)
      self.assertAllEqual([3], tgt_len_v)
      self.assertAllEqual(["2"], annot_v)

      with self.assertRaisesOpError("End of sequence"):
        sess.run(source)
Code Example #7
  def testGetIteratorWithMaxLengthServerMode(self):
    tf.set_random_seed(1)
    tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    src_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["f e a g", "c c a", "d", "c a"]))
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c c c", "a b c c", "d", "b"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        num_buckets=1,
        eos="eos",
        sos="sos",
        context_feed="",
        server_mode=True)
    batch_size = 2
    # With tgt_max_len=3 in server mode, target_input/target_output are each
    # capped at 3 tokens, counting the added sos/eos (see assertions below).
    max_len = 3
    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=max_len,
        tgt_max_len=max_len,
        reshuffle_each_iteration=False)
    table_initializer = tf.tables_initializer()
    source = iterator.source
    target_input = iterator.target_input
    target_output = iterator.target_output
    src_seq_len = iterator.source_sequence_length
    tgt_seq_len = iterator.target_sequence_length
    self.assertEqual([None, None], source.shape.as_list())
    self.assertEqual([None, None], target_input.shape.as_list())
    self.assertEqual([None, None], target_output.shape.as_list())
    self.assertEqual([None], src_seq_len.shape.as_list())
    self.assertEqual([None], tgt_seq_len.shape.as_list())
    with self.test_session() as sess:
      sess.run(table_initializer)
      sess.run(iterator.initializer)

      (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
          sess.run((source, src_seq_len, target_input, target_output,
                    tgt_seq_len)))
      self.assertAllEqual(
          [
              [2, 0],  # "c" "a"
              [-1, 3],  # "d" == unknown, eos (padding)
          ],
          source_v)
      self.assertAllEqual([2, 1], src_len_v)
      self.assertAllEqual(
          [
              [4, 1],  # sos b
              [4, -1],  # sos unk
          ],
          target_input_v)
      self.assertAllEqual(
          [
              [1, 3],  # b eos
              [-1, 3],  # unk eos
          ],
          target_output_v)
      self.assertAllEqual([2, 2], tgt_len_v)

      (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
          sess.run((source, src_seq_len, target_input, target_output,
                    tgt_seq_len)))
      self.assertAllEqual(
          [
              [-1, -1, 0],  # unk unk a
              [2, 2, 0],  # c c a
          ],
          source_v)
      self.assertAllEqual([3, 3], src_len_v)
      self.assertAllEqual(
          [
              [4, 2, 2],  # sos c c
              [4, 0, 1],  # sos a b
          ],
          target_input_v)
      self.assertAllEqual(
          [
              [2, 2, 2],  # c c c
              [0, 1, 2],  # a b c
          ],
          target_output_v)
      self.assertAllEqual([3, 3], tgt_len_v)

      with self.assertRaisesOpError("End of sequence"):
        sess.run(source)