def create_train_model(model_creator, hparams, scope=None, num_workers=1,
                       jobid=0, extra_args=None):
    """Build the training graph with its model and input iterator.

    Args:
        model_creator: Callable that builds the model. It is called with
            `hparams` plus the keyword arguments `iterator`, `mode`,
            `source_vocab_table`, `target_vocab_table`, `scope` and
            `extra_args`.
        hparams: Hyperparameter object. The attributes read here are
            `train_prefix`, `src`, `tgt`, `src_vocab_file`,
            `tgt_vocab_file`, `share_vocab`, `batch_size`, `sos`, `eos`,
            `random_seed`, `num_buckets`, `src_max_len` and `tgt_max_len`.
        scope: Optional name. It is used as the container name (falling
            back to "train") and is forwarded to `model_creator`.
        num_workers: Forwarded to the iterator as `num_shards`.
        jobid: Forwarded to the iterator as `shard_index`.
        extra_args: Optional object. When truthy, its `model_device_fn`
            attribute is handed to `tf.device`; it is always forwarded
            to `model_creator`.

    Returns:
        A `TrainModel` holding the graph, the model, the iterator and the
        scalar int64 placeholder that is fed to the iterator as
        `skip_count`.
    """
    train_prefix = hparams.train_prefix
    src_file = "%s.%s" % (train_prefix, hparams.src)
    tgt_file = "%s.%s" % (train_prefix, hparams.tgt)

    graph = tf.Graph()
    container_name = scope or "train"

    with graph.as_default(), tf.container(container_name):
        vocab_tables = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file,
            hparams.tgt_vocab_file,
            hparams.share_vocab)
        src_vocab_table, tgt_vocab_table = vocab_tables

        src_dataset = tf.data.TextLineDataset(src_file)
        tgt_dataset = tf.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator_options = {
            "batch_size": hparams.batch_size,
            "sos": hparams.sos,
            "eos": hparams.eos,
            "random_seed": hparams.random_seed,
            "num_buckets": hparams.num_buckets,
            "src_max_len": hparams.src_max_len,
            "tgt_max_len": hparams.tgt_max_len,
            "skip_count": skip_count_placeholder,
            "num_shards": num_workers,
            "shard_index": jobid,
        }
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            **iterator_options)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = extra_args.model_device_fn if extra_args else None

        with tf.device(model_device_fn):
            model = model_creator(
                hparams,
                iterator=iterator,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                source_vocab_table=src_vocab_table,
                target_vocab_table=tgt_vocab_table,
                scope=scope,
                extra_args=extra_args)

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder)
def create_test_iterator(hparams, mode):
    """Build a small in-memory iterator and vocab tables for tests.

    Args:
        hparams: Hyperparameter object. `eos` and `sos` seed the vocab
            mappings. `batch_size` is always read; `random_seed` and
            `num_buckets` are read only outside inference mode.
        mode: One of the `tf.contrib.learn.ModeKeys` values. `INFER`
            selects the source-only inference iterator; any other value
            selects the paired source/target iterator.

    Returns:
        In inference mode, the 4-tuple `(iterator, src_vocab_table,
        tgt_vocab_table, reverse_tgt_vocab_table)`. Otherwise the 3-tuple
        `(iterator, src_vocab_table, tgt_vocab_table)`.
    """
    infer_mode = mode == tf.contrib.learn.ModeKeys.INFER

    src_vocab_mapping = tf.constant([hparams.eos, "a", "b", "c", "d"])
    src_vocab_table = lookup_ops.index_table_from_tensor(src_vocab_mapping)

    tgt_vocab_mapping = tf.constant(
        [hparams.sos, hparams.eos, "a", "b", "c"])
    tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)

    # The id -> string table is only built (and returned) for inference.
    if infer_mode:
        reverse_tgt_vocab_table = (
            lookup_ops.index_to_string_table_from_tensor(tgt_vocab_mapping))

    src_sentences = tf.constant(["a a b b c", "a b b"])
    src_dataset = tf.data.Dataset.from_tensor_slices(src_sentences)

    if infer_mode:
        infer_iterator = iterator_utils.get_infer_iterator(
            src_dataset=src_dataset,
            src_vocab_table=src_vocab_table,
            eos=hparams.eos,
            batch_size=hparams.batch_size)
        return (infer_iterator,
                src_vocab_table,
                tgt_vocab_table,
                reverse_tgt_vocab_table)

    tgt_sentences = tf.constant(["a b c b c", "a b c b"])
    tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_sentences)
    paired_iterator = iterator_utils.get_iterator(
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets)
    return (paired_iterator, src_vocab_table, tgt_vocab_table)
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Build the eval graph, model, src/tgt file placeholders and iterator.

    Args:
        model_creator: Callable that builds the model. It is called with
            `hparams` plus the keyword arguments `iterator`, `mode`,
            `source_vocab_table`, `target_vocab_table`, `scope` and
            `extra_args`.
        hparams: Hyperparameter object. The attributes read here are
            `src_vocab_file`, `tgt_vocab_file`, `share_vocab`,
            `batch_size`, `sos`, `eos`, `random_seed`, `num_buckets`,
            `src_max_len_infer` and `tgt_max_len_infer`.
        scope: Optional name. It is used as the container name (falling
            back to "eval") and is forwarded to `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `EvalModel` holding the graph, the model, the two scalar
        string placeholders that name the source/target files, and the
        iterator reading from them.
    """
    graph = tf.Graph()
    container_name = scope or "eval"

    with graph.as_default(), tf.container(container_name):
        vocab_tables = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file,
            hparams.tgt_vocab_file,
            hparams.share_vocab)
        src_vocab_table, tgt_vocab_table = vocab_tables

        # File names are fed at run time through these placeholders.
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)

        iterator_options = {
            "sos": hparams.sos,
            "eos": hparams.eos,
            "random_seed": hparams.random_seed,
            "num_buckets": hparams.num_buckets,
            "src_max_len": hparams.src_max_len_infer,
            "tgt_max_len": hparams.tgt_max_len_infer,
        }
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            hparams.batch_size,
            **iterator_options)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            scope=scope,
            extra_args=extra_args)

    return EvalModel(
        graph=graph,
        model=model,
        src_file_placeholder=src_file_placeholder,
        tgt_file_placeholder=tgt_file_placeholder,
        iterator=iterator)
def testGetIterator(self):
    """Check the batches `get_iterator` yields for a four-sentence corpus."""
    tf.set_random_seed(1)

    # Source and target share one lookup table.
    vocab = tf.constant(["a", "b", "c", "eos", "sos"])
    shared_vocab_table = lookup_ops.index_table_from_tensor(vocab)
    src_vocab_table = shared_vocab_table
    tgt_vocab_table = shared_vocab_table

    src_sentences = tf.constant(["f e a g", "c c a", "d", "c a"])
    src_dataset = tf.data.Dataset.from_tensor_slices(src_sentences)
    tgt_sentences = tf.constant(["c c", "a b", "", "b c"])
    tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_sentences)

    hparams = tf.contrib.training.HParams(
        random_seed=3,
        num_buckets=5,
        eos="eos",
        sos="sos")
    batch_size = 2
    src_max_len = 3

    iterator = iterator_utils.get_iterator(
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=src_max_len,
        reshuffle_each_iteration=False)
    table_initializer = tf.tables_initializer()

    source = iterator.source
    target_input = iterator.target_input
    target_output = iterator.target_output
    src_seq_len = iterator.source_sequence_length
    tgt_seq_len = iterator.target_sequence_length

    # Static shapes are fully unknown: rank 2 for tokens, rank 1 for lengths.
    for token_tensor in (source, target_input, target_output):
        self.assertEqual([None, None], token_tensor.shape.as_list())
    for length_tensor in (src_seq_len, tgt_seq_len):
        self.assertEqual([None], length_tensor.shape.as_list())

    fetches = (source, src_seq_len, target_input, target_output, tgt_seq_len)

    with self.test_session() as sess:
        sess.run(table_initializer)
        sess.run(iterator.initializer)

        # First batch: two sentence pairs.
        (source_v, src_len_v, target_input_v, target_output_v,
         tgt_len_v) = sess.run(fetches)

        expected_source = [
            [-1, -1, 0],  # "f" == unknown, "e" == unknown, a
            [2, 0, 3],    # c a eos -- eos is padding
        ]
        self.assertAllEqual(expected_source, source_v)
        self.assertAllEqual([3, 2], src_len_v)

        expected_target_input = [
            [4, 2, 2],  # sos c c
            [4, 1, 2],  # sos b c
        ]
        self.assertAllEqual(expected_target_input, target_input_v)

        expected_target_output = [
            [2, 2, 3],  # c c eos
            [1, 2, 3],  # b c eos
        ]
        self.assertAllEqual(expected_target_output, target_output_v)
        self.assertAllEqual([3, 3], tgt_len_v)

        # Second batch: a single remaining sentence pair.
        (source_v, src_len_v, target_input_v, target_output_v,
         tgt_len_v) = sess.run(fetches)

        self.assertAllEqual([[2, 2, 0]], source_v)  # c c a
        self.assertAllEqual([3], src_len_v)
        self.assertAllEqual([[4, 0, 1]], target_input_v)  # sos a b
        self.assertAllEqual([[0, 1, 3]], target_output_v)  # a b eos
        self.assertAllEqual([3], tgt_len_v)

        # Nothing is left after the second batch.
        with self.assertRaisesOpError("End of sequence"):
            sess.run(source)