Example 1
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)
        # Number of examples to skip at the start of the training files,
        # fed at iterator initialization time (e.g. when resuming mid-epoch).
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(hparams,
                                  iterator=iterator,
                                  mode=tf.contrib.learn.ModeKeys.TRAIN,
                                  source_vocab_table=src_vocab_table,
                                  target_vocab_table=tgt_vocab_table,
                                  scope=scope,
                                  extra_args=extra_args)

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
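
The returned TrainModel bundles the graph with the handles needed at session time. A minimal usage sketch (TF 1.x), assuming a model class such as AttentionModel and a fully populated hparams, both of which come from outside this snippet:

train_model = create_train_model(AttentionModel, hparams)
with tf.Session(graph=train_model.graph) as sess:
    sess.run(tf.tables_initializer())
    # Feed 0 to start from the beginning of the training files.
    sess.run(train_model.iterator.initializer,
             feed_dict={train_model.skip_count_placeholder: 0})
    # ... run training steps on train_model.model ...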
Example 2
def create_test_iterator(hparams, mode):
  """Create test iterator."""
  src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant([hparams.eos, "a", "b", "c", "d"]))
  tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
  tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
  if mode == tf.contrib.learn.ModeKeys.INFER:
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
        tgt_vocab_mapping)

  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["a a b b c", "a b b"]))

  if mode != tf.contrib.learn.ModeKeys.INFER:
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["a b c b c", "a b c b"]))
    return (
        iterator_utils.get_iterator(
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            src_vocab_table=src_vocab_table,
            tgt_vocab_table=tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets),
        src_vocab_table,
        tgt_vocab_table)
  else:
    return (
        iterator_utils.get_infer_iterator(
            src_dataset=src_dataset,
            src_vocab_table=src_vocab_table,
            eos=hparams.eos,
            batch_size=hparams.batch_size),
        src_vocab_table,
        tgt_vocab_table,
        reverse_tgt_vocab_table)
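
In INFER mode the function additionally returns reverse_tgt_vocab_table, which maps predicted ids back to token strings. A self-contained sketch of that lookup, with "<s>"/"</s>" standing in for hparams.sos/hparams.eos and purely illustrative ids:

import tensorflow as tf
from tensorflow.python.ops import lookup_ops

tgt_vocab_mapping = tf.constant(["<s>", "</s>", "a", "b", "c"])
reverse_table = lookup_ops.index_to_string_table_from_tensor(tgt_vocab_mapping)
# Ids follow the mapping order: 2 -> "a", 3 -> "b", 1 -> "</s>".
sample_words = reverse_table.lookup(tf.constant([[2, 3, 1]], dtype=tf.int64))
with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    print(sess.run(sample_words))  # [[b"a" b"b" b"</s>"]]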
Example 3
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create train graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(hparams,
                              iterator=iterator,
                              mode=tf.contrib.learn.ModeKeys.EVAL,
                              source_vocab_table=src_vocab_table,
                              target_vocab_table=tgt_vocab_table,
                              scope=scope,
                              extra_args=extra_args)
    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_file_placeholder,
                     tgt_file_placeholder=tgt_file_placeholder,
                     iterator=iterator)
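
The src/tgt file placeholders let the same eval graph be re-initialized on different dev or test sets. A hypothetical usage sketch (the model class and file paths are stand-ins, not values from this snippet):

eval_model = create_eval_model(AttentionModel, hparams)
with tf.Session(graph=eval_model.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(eval_model.iterator.initializer,
             feed_dict={
                 eval_model.src_file_placeholder: "dev.src",
                 eval_model.tgt_file_placeholder: "dev.tgt",
             })
    # ... compute eval loss with eval_model.model ...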
Example 4
    def testGetIterator(self):
        tf.set_random_seed(1)
        tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
            tf.constant(["a", "b", "c", "eos", "sos"]))
        src_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["f e a g", "c c a", "d", "c a"]))
        tgt_dataset = tf.data.Dataset.from_tensor_slices(
            tf.constant(["c c", "a b", "", "b c"]))
        hparams = tf.contrib.training.HParams(random_seed=3,
                                              num_buckets=5,
                                              eos="eos",
                                              sos="sos")
        batch_size = 2
        src_max_len = 3
        iterator = iterator_utils.get_iterator(src_dataset=src_dataset,
                                               tgt_dataset=tgt_dataset,
                                               src_vocab_table=src_vocab_table,
                                               tgt_vocab_table=tgt_vocab_table,
                                               batch_size=batch_size,
                                               sos=hparams.sos,
                                               eos=hparams.eos,
                                               random_seed=hparams.random_seed,
                                               num_buckets=hparams.num_buckets,
                                               src_max_len=src_max_len,
                                               reshuffle_each_iteration=False)
        table_initializer = tf.tables_initializer()
        source = iterator.source
        target_input = iterator.target_input
        target_output = iterator.target_output
        src_seq_len = iterator.source_sequence_length
        tgt_seq_len = iterator.target_sequence_length
        self.assertEqual([None, None], source.shape.as_list())
        self.assertEqual([None, None], target_input.shape.as_list())
        self.assertEqual([None, None], target_output.shape.as_list())
        self.assertEqual([None], src_seq_len.shape.as_list())
        self.assertEqual([None], tgt_seq_len.shape.as_list())
        with self.test_session() as sess:
            sess.run(table_initializer)
            sess.run(iterator.initializer)

            (source_v, src_len_v, target_input_v, target_output_v,
             tgt_len_v) = (sess.run((source, src_seq_len, target_input,
                                     target_output, tgt_seq_len)))
            self.assertAllEqual(
                [
                    [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
                    [2, 0, 3],    # c a eos -- eos is padding
                ],
                source_v)
            self.assertAllEqual([3, 2], src_len_v)
            self.assertAllEqual(
                [
                    [4, 2, 2],  # sos c c
                    [4, 1, 2],  # sos b c
                ],
                target_input_v)
            self.assertAllEqual(
                [
                    [2, 2, 3],  # c c eos
                    [1, 2, 3],  # b c eos
                ],
                target_output_v)
            self.assertAllEqual([3, 3], tgt_len_v)

            (source_v, src_len_v, target_input_v, target_output_v,
             tgt_len_v) = (sess.run((source, src_seq_len, target_input,
                                     target_output, tgt_seq_len)))
            self.assertAllEqual(
                [[2, 2, 0]],  # c c a
                source_v)
            self.assertAllEqual([3], src_len_v)
            self.assertAllEqual(
                [[4, 0, 1]],  # sos a b
                target_input_v)
            self.assertAllEqual(
                [[0, 1, 3]],  # a b eos
                target_output_v)
            self.assertAllEqual([3], tgt_len_v)

            with self.assertRaisesOpError("End of sequence"):
                sess.run(source)
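
The integer constants in the assertions follow from how index_table_from_tensor assigns ids: in tensor order, with a default of -1 for out-of-vocabulary tokens. A small standalone sketch that makes the mapping explicit:

import tensorflow as tf
from tensorflow.python.ops import lookup_ops

vocab_table = lookup_ops.index_table_from_tensor(
    tf.constant(["a", "b", "c", "eos", "sos"]))
ids = vocab_table.lookup(tf.constant(["a", "b", "c", "eos", "sos", "f"]))
with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    print(sess.run(ids))  # [0 1 2 3 4 -1]; "f" is OOV and maps to -1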