def create_test_iterator(hparams, mode, trie_excludes=None):
  """Create test iterator."""
  src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant([hparams.eos, "a", "b", "c", "d"]))
  tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
  tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
  reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
      tgt_vocab_mapping)

  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["a a b b c", "a b b"]))

  ctx_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c b c b a", "b c b a"]))

  trie_excludes = trie_excludes or []
  trie_excludes = " {} ".format(hparams.eos).join(trie_excludes)
  tex_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant([trie_excludes, trie_excludes]))

  if mode != tf.contrib.learn.ModeKeys.INFER:
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["a b c b c", "a b c b"]))
    return (
        iterator_utils.get_iterator(
            hparams=hparams,
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            ctx_dataset=ctx_dataset,
            annot_dataset=None,
            src_vocab_table=src_vocab_table,
            tgt_vocab_table=tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets),
        src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
  else:
    return (
        iterator_utils.get_infer_iterator(
            hparams=hparams,
            src_dataset=src_dataset,
            ctx_dataset=ctx_dataset,
            annot_dataset=None,
            trie_exclude_dataset=tex_dataset,
            src_vocab_table=src_vocab_table,
            tgt_vocab_table=tgt_vocab_table,
            eos=hparams.eos,
            batch_size=hparams.batch_size),
        src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
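# A minimal, hypothetical usage sketch (not part of the original file): it
# drives create_test_iterator in TRAIN mode and pulls one batch. It assumes
# `hparams` carries the eos/sos/batch_size/random_seed/num_buckets fields the
# helper reads, and that get_iterator returns the usual BatchedInput tuple
# with `initializer`, `source`, and `source_sequence_length` fields.
def _example_run_test_iterator(hparams):
  iterator, _, _, _ = create_test_iterator(
      hparams, mode=tf.contrib.learn.ModeKeys.TRAIN)
  with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(iterator.initializer)
    # One batch of source ids and their lengths.
    return sess.run((iterator.source, iterator.source_sequence_length))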
def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       graph=None,
                       extra_args=None,
                       trie=None,
                       use_placeholders=False):
  """Create train graph, model, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    annot_placeholder = None
    src_placeholder = None
    tgt_placeholder = None
    annot_dataset = None
    ctx_dataset = None

    if use_placeholders:
      src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
      tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_placeholder)
      if hparams.use_rl:
        annot_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        annot_dataset = tf.data.Dataset.from_tensor_slices(annot_placeholder)
    else:
      src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
      tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
      ctx_file = None
      if hparams.ctx is not None:
        ctx_file = "%s.%s" % (hparams.train_prefix, hparams.ctx)

      src_dataset = tf.data.TextLineDataset(src_file)
      tgt_dataset = tf.data.TextLineDataset(tgt_file)

      if hparams.train_annotations is not None:
        annot_dataset = tf.data.TextLineDataset(hparams.train_annotations)
      if ctx_file is not None:
        ctx_dataset = tf.data.TextLineDataset(ctx_file)

    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        ctx_dataset=ctx_dataset,
        annot_dataset=annot_dataset,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          reverse_target_vocab_table=reverse_tgt_vocab_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModel(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      annot_placeholder=annot_placeholder,
      skip_count_placeholder=skip_count_placeholder)
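# Hypothetical sketch, assuming file-based inputs (use_placeholders=False) and
# a `model_creator` class matching the signature above: build the TrainModel,
# then initialize the lookup tables and the iterator inside its graph. Feeding
# skip_count=0 means "start from the first training example".
def _example_init_train_model(model_creator, hparams):
  train_model = create_train_model(model_creator, hparams)
  with tf.Session(graph=train_model.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(train_model.iterator.initializer,
             feed_dict={train_model.skip_count_placeholder: 0})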
def create_eval_model(model_creator,
                      hparams,
                      scope=None,
                      graph=None,
                      extra_args=None,
                      trie=None):
  """Create eval graph, model, src/tgt file placeholders, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "eval"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    ctx_file_placeholder = None
    if hparams.ctx is not None:
      ctx_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    annot_file_placeholder = None
    if hparams.dev_annotations is not None:
      annot_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
    ctx_dataset = None
    if ctx_file_placeholder is not None:
      ctx_dataset = tf.data.TextLineDataset(ctx_file_placeholder)
    annot_dataset = None
    if annot_file_placeholder is not None:
      annot_dataset = tf.data.TextLineDataset(annot_file_placeholder)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        ctx_dataset=ctx_dataset,
        annot_dataset=annot_dataset,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer)

    model = model_creator(
        hparams=hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        source_vocab_table=src_vocab_table,
        target_vocab_table=tgt_vocab_table,
        reverse_target_vocab_table=reverse_tgt_vocab_table,
        scope=scope,
        extra_args=extra_args,
        trie=trie)

  return EvalModel(
      graph=graph,
      model=model,
      src_file_placeholder=src_file_placeholder,
      tgt_file_placeholder=tgt_file_placeholder,
      ctx_file_placeholder=ctx_file_placeholder,
      annot_file_placeholder=annot_file_placeholder,
      iterator=iterator)
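# Hypothetical sketch: evaluation feeds dev-file names through the EvalModel
# placeholders. The dev-file naming (hparams.dev_prefix + "." + hparams.src)
# mirrors the train-file naming in create_train_model and is an assumption
# here, not something this module defines.
def _example_init_eval_model(model_creator, hparams):
  eval_model = create_eval_model(model_creator, hparams)
  src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  with tf.Session(graph=eval_model.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(eval_model.iterator.initializer,
             feed_dict={
                 eval_model.src_file_placeholder: src_file,
                 eval_model.tgt_file_placeholder: tgt_file,
             })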
def create_train_model_for_server(model_creator,
                                  hparams,
                                  scope=None,
                                  num_workers=1,
                                  jobid=0,
                                  graph=None,
                                  extra_args=None,
                                  trie=None):
  """Create graph, model, and iterator when running the NMT in server mode.

  This is different from the standard training model, because the input
  arrives via RPC and thus has to be fed using placeholders.
  """
  assert hparams.num_buckets == 1, "No bucketing when in server mode."

  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
    tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_placeholder)
    wgt_placeholder = tf.placeholder(shape=[None], dtype=tf.float32)
    wgt_dataset = tf.data.Dataset.from_tensor_slices(wgt_placeholder)

    ctx_placeholder = None
    if hparams.ctx is not None:
      ctx_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    ctx_dataset = None
    if ctx_placeholder is not None:
      ctx_dataset = tf.data.Dataset.from_tensor_slices(ctx_placeholder)

    batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        wgt_dataset=wgt_dataset,
        ctx_dataset=ctx_dataset,
        annot_dataset=None,
        batch_size=batch_size_placeholder,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          reverse_target_vocab_table=reverse_tgt_vocab_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModelForServer(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      wgt_placeholder=wgt_placeholder,
      batch_size_placeholder=batch_size_placeholder,
      skip_count_placeholder=skip_count_placeholder)
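# Hypothetical sketch: in server mode every input arrives via placeholders, so
# each (re)initialization of the iterator feeds raw sentence batches, their
# per-example weights, an explicit batch size, and a skip count. The example
# sentences and weight below are made up; hparams.ctx is assumed to be None.
def _example_init_server_iterator(model_creator, hparams):
  m = create_train_model_for_server(model_creator, hparams)
  with tf.Session(graph=m.graph) as sess:
    sess.run(tf.tables_initializer())
    sess.run(m.iterator.initializer,
             feed_dict={
                 m.src_placeholder: ["a b c"],
                 m.tgt_placeholder: ["b c"],
                 m.wgt_placeholder: [1.0],
                 m.batch_size_placeholder: 1,
                 m.skip_count_placeholder: 0,
             })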
def testGetIteratorWithSkipCount(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c a", "c c a", "d", "f e a g"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["b c", "a b", "", "c c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=1,
      num_buckets=5,
      eos="eos",
      sos="sos",
      context_feed="",
      server_mode=False)
  batch_size = 2
  src_max_len = 3
  skip_count = tf.placeholder(shape=(), dtype=tf.int64)
  iterator = iterator_utils.get_iterator(
      hparams=hparams,
      src_dataset=src_dataset,
      tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table,
      tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size,
      sos=hparams.sos,
      eos=hparams.eos,
      random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets,
      src_max_len=src_max_len,
      skip_count=skip_count,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  source = iterator.source
  target_input = iterator.target_input
  target_output = iterator.target_output
  src_seq_len = iterator.source_sequence_length
  tgt_seq_len = iterator.target_sequence_length
  self.assertEqual([None, None], source.shape.as_list())
  self.assertEqual([None, None], target_input.shape.as_list())
  self.assertEqual([None, None], target_output.shape.as_list())
  self.assertEqual([None], src_seq_len.shape.as_list())
  self.assertEqual([None], tgt_seq_len.shape.as_list())

  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer, feed_dict={skip_count: 3})

    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [[-1, -1, 0]],  # "f" == unknown, "e" == unknown, a
        source_v)
    self.assertAllEqual([3], src_len_v)
    self.assertAllEqual(
        [[4, 2, 2]],  # sos c c
        target_input_v)
    self.assertAllEqual(
        [[2, 2, 3]],  # c c eos
        target_output_v)
    self.assertAllEqual([3], tgt_len_v)

    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)

    # Re-init iterator with skip_count=0.
    sess.run(iterator.initializer, feed_dict={skip_count: 0})

    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [
            [2, 0, 3],  # c a eos -- eos is padding
            [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
        ],
        source_v)
    self.assertAllEqual([2, 3], src_len_v)
    self.assertAllEqual(
        [
            [4, 1, 2],  # sos b c
            [4, 2, 2],  # sos c c
        ],
        target_input_v)
    self.assertAllEqual(
        [
            [1, 2, 3],  # b c eos
            [2, 2, 3],  # c c eos
        ],
        target_output_v)
    self.assertAllEqual([3, 3], tgt_len_v)

    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [[2, 2, 0]],  # c c a
        source_v)
    self.assertAllEqual([3], src_len_v)
    self.assertAllEqual(
        [[4, 0, 1]],  # sos a b
        target_input_v)
    self.assertAllEqual(
        [[0, 1, 3]],  # a b eos
        target_output_v)
    self.assertAllEqual([3], tgt_len_v)

    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)
def testGetIteratorWithAnnotations(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["f e a g", "c c a", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c", "a b", "", "b c"]))
  annot_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["1\t1", "2\t1", "3\t1", "4\t1"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3,
      num_buckets=5,
      eos="eos",
      sos="sos",
      context_feed="",
      server_mode=False)
  batch_size = 2
  src_max_len = 3
  iterator = iterator_utils.get_iterator(
      hparams=hparams,
      src_dataset=src_dataset,
      tgt_dataset=tgt_dataset,
      annot_dataset=annot_dataset,
      src_vocab_table=src_vocab_table,
      tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size,
      sos=hparams.sos,
      eos=hparams.eos,
      random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets,
      src_max_len=src_max_len,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  source = iterator.source
  target_input = iterator.target_input
  target_output = iterator.target_output
  src_seq_len = iterator.source_sequence_length
  tgt_seq_len = iterator.target_sequence_length
  annotation = iterator.annotation
  self.assertEqual([None, None], source.shape.as_list())
  self.assertEqual([None, None], target_input.shape.as_list())
  self.assertEqual([None, None], target_output.shape.as_list())
  self.assertEqual([None], src_seq_len.shape.as_list())
  self.assertEqual([None], tgt_seq_len.shape.as_list())

  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)

    # First batch.
    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v,
     annot_v) = (
         sess.run((source, src_seq_len, target_input, target_output,
                   tgt_seq_len, annotation)))
    self.assertAllEqual(
        [
            [2, 0, 3],  # c a eos
            [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
        ],
        source_v)
    self.assertAllEqual([2, 3], src_len_v)
    self.assertAllEqual(
        [
            [4, 1, 2],  # sos b c
            [4, 2, 2],  # sos c c
        ],
        target_input_v)
    self.assertAllEqual(
        [
            [1, 2, 3],  # b c eos
            [2, 2, 3],  # c c eos
        ],
        target_output_v)
    self.assertAllEqual([3, 3], tgt_len_v)
    self.assertAllEqual(["4", "1"], annot_v)

    # Second batch.
    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v,
     annot_v) = (
         sess.run((source, src_seq_len, target_input, target_output,
                   tgt_seq_len, annotation)))
    self.assertAllEqual(
        [[2, 2, 0]],  # c c a
        source_v)
    self.assertAllEqual([3], src_len_v)
    self.assertAllEqual(
        [[4, 0, 1]],  # sos a b
        target_input_v)
    self.assertAllEqual(
        [[0, 1, 3]],  # a b eos
        target_output_v)
    self.assertAllEqual([3], tgt_len_v)
    self.assertAllEqual(["2"], annot_v)

    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)
def testGetIteratorWithMaxLengthServerMode(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["f e a g", "c c a", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c c c", "a b c c", "d", "b"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3,
      num_buckets=1,
      eos="eos",
      sos="sos",
      context_feed="",
      server_mode=True)
  batch_size = 2
  # Target length is effectively limited to 4 tokens (3 + sos/eos).
  max_len = 3
  iterator = iterator_utils.get_iterator(
      hparams=hparams,
      src_dataset=src_dataset,
      tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table,
      tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size,
      sos=hparams.sos,
      eos=hparams.eos,
      random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets,
      src_max_len=max_len,
      tgt_max_len=max_len,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  source = iterator.source
  target_input = iterator.target_input
  target_output = iterator.target_output
  src_seq_len = iterator.source_sequence_length
  tgt_seq_len = iterator.target_sequence_length
  self.assertEqual([None, None], source.shape.as_list())
  self.assertEqual([None, None], target_input.shape.as_list())
  self.assertEqual([None, None], target_output.shape.as_list())
  self.assertEqual([None], src_seq_len.shape.as_list())
  self.assertEqual([None], tgt_seq_len.shape.as_list())

  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)

    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [
            [2, 0],  # "c" "a"
            [-1, 3],  # "d" == unknown, eos (padding)
        ],
        source_v)
    self.assertAllEqual([2, 1], src_len_v)
    self.assertAllEqual(
        [
            [4, 1],  # sos b
            [4, -1],  # sos unk
        ],
        target_input_v)
    self.assertAllEqual(
        [
            [1, 3],  # b eos
            [-1, 3],  # unk eos
        ],
        target_output_v)
    self.assertAllEqual([2, 2], tgt_len_v)

    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [
            [-1, -1, 0],  # unk unk a
            [2, 2, 0],  # c c a
        ],
        source_v)
    self.assertAllEqual([3, 3], src_len_v)
    self.assertAllEqual(
        [
            [4, 2, 2],  # sos c c
            [4, 0, 1],  # sos a b
        ],
        target_input_v)
    self.assertAllEqual(
        [
            [2, 2, 2],  # c c c
            [0, 1, 2],  # a b c
        ],
        target_output_v)
    self.assertAllEqual([3, 3], tgt_len_v)

    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)