def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       graph=None,
                       extra_args=None,
                       trie=None,
                       use_placeholders=False):
  """Create train graph, model, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    annot_placeholder = None
    src_placeholder = None
    tgt_placeholder = None
    annot_dataset = None
    ctx_dataset = None
    if use_placeholders:
      src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
      tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_placeholder)
      if hparams.use_rl:
        annot_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        annot_dataset = tf.data.Dataset.from_tensor_slices(annot_placeholder)
    else:
      src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
      tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
      ctx_file = None
      if hparams.ctx is not None:
        ctx_file = "%s.%s" % (hparams.train_prefix, hparams.ctx)

      src_dataset = tf.data.TextLineDataset(src_file)
      tgt_dataset = tf.data.TextLineDataset(tgt_file)
      if hparams.train_annotations is not None:
        annot_dataset = tf.data.TextLineDataset(hparams.train_annotations)
      if ctx_file is not None:
        ctx_dataset = tf.data.TextLineDataset(ctx_file)

    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        ctx_dataset=ctx_dataset,
        annot_dataset=annot_dataset,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          reverse_target_vocab_table=reverse_tgt_vocab_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModel(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      annot_placeholder=annot_placeholder,
      skip_count_placeholder=skip_count_placeholder)
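
# Usage sketch (illustrative only; assumes an `hparams` object populated with
# the fields read above and a `model_creator` class compatible with the
# keyword arguments passed to it):
#
#   train_model = create_train_model(model_creator, hparams)
#   with tf.Session(graph=train_model.graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(train_model.iterator.initializer,
#              feed_dict={train_model.skip_count_placeholder: 0})
#     # Training ops on train_model.model can then be run batch by batch.
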
def create_eval_model(model_creator,
                      hparams,
                      scope=None,
                      graph=None,
                      extra_args=None,
                      trie=None):
  """Create eval graph, model, src/tgt file holders, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "eval"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    ctx_file_placeholder = None
    if hparams.ctx is not None:
      ctx_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    annot_file_placeholder = None
    if hparams.dev_annotations is not None:
      annot_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
    ctx_dataset = None
    if ctx_file_placeholder is not None:
      ctx_dataset = tf.data.TextLineDataset(ctx_file_placeholder)
    annot_dataset = None
    if annot_file_placeholder is not None:
      annot_dataset = tf.data.TextLineDataset(annot_file_placeholder)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        ctx_dataset=ctx_dataset,
        annot_dataset=annot_dataset,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer)

    model = model_creator(
        hparams=hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        source_vocab_table=src_vocab_table,
        target_vocab_table=tgt_vocab_table,
        reverse_target_vocab_table=reverse_tgt_vocab_table,
        scope=scope,
        extra_args=extra_args,
        trie=trie)

  return EvalModel(
      graph=graph,
      model=model,
      src_file_placeholder=src_file_placeholder,
      tgt_file_placeholder=tgt_file_placeholder,
      ctx_file_placeholder=ctx_file_placeholder,
      annot_file_placeholder=annot_file_placeholder,
      iterator=iterator)
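
# Usage sketch (illustrative only): the eval iterator is re-initialized by
# feeding dev file paths through the scalar string placeholders. The
# `hparams.dev_prefix` field below is an assumption, by analogy with the
# `hparams.train_prefix` used in create_train_model:
#
#   eval_model = create_eval_model(model_creator, hparams)
#   with tf.Session(graph=eval_model.graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(eval_model.iterator.initializer,
#              feed_dict={
#                  eval_model.src_file_placeholder:
#                      "%s.%s" % (hparams.dev_prefix, hparams.src),
#                  eval_model.tgt_file_placeholder:
#                      "%s.%s" % (hparams.dev_prefix, hparams.tgt)})
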
def create_infer_model(model_creator,
                       hparams,
                       scope=None,
                       graph=None,
                       extra_args=None,
                       trie=None):
  """Create inference model."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "infer"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    ctx_placeholder = None
    if hparams.ctx is not None:
      ctx_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    annot_placeholder = None
    if hparams.use_rl:
      annot_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

    src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
    ctx_dataset = None
    if ctx_placeholder is not None:
      ctx_dataset = tf.data.Dataset.from_tensor_slices(ctx_placeholder)
    annot_dataset = None
    if annot_placeholder is not None:
      annot_dataset = tf.data.Dataset.from_tensor_slices(annot_placeholder)

    trie_exclude_placeholder = None
    trie_exclude_dataset = None
    if hparams.infer_mode.startswith("trie_"):
      trie_exclude_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
      trie_exclude_dataset = tf.data.Dataset.from_tensor_slices(
          trie_exclude_placeholder)

    iterator = iterator_utils.get_infer_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        ctx_dataset=ctx_dataset,
        annot_dataset=annot_dataset,
        trie_exclude_dataset=trie_exclude_dataset,
        batch_size=batch_size_placeholder,
        eos=hparams.eos,
        src_max_len=hparams.src_max_len_infer)

    model = model_creator(
        hparams=hparams,
        iterator=iterator,
        mode=tf.contrib.learn.ModeKeys.INFER,
        source_vocab_table=src_vocab_table,
        target_vocab_table=tgt_vocab_table,
        reverse_target_vocab_table=reverse_tgt_vocab_table,
        scope=scope,
        extra_args=extra_args,
        trie=trie)

  return InferModel(
      graph=graph,
      model=model,
      src_placeholder=src_placeholder,
      annot_placeholder=annot_placeholder,
      trie_exclude_placeholder=trie_exclude_placeholder,
      ctx_placeholder=ctx_placeholder,
      batch_size_placeholder=batch_size_placeholder,
      iterator=iterator)
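
# Usage sketch (illustrative only): inference feeds raw source sentences and a
# batch size directly, since there are no input files at serving time. Extra
# placeholders (ctx, annot, trie_exclude) must also be fed when the
# corresponding hparams enable them:
#
#   infer_model = create_infer_model(model_creator, hparams)
#   with tf.Session(graph=infer_model.graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(infer_model.iterator.initializer,
#              feed_dict={infer_model.src_placeholder: ["a b c"],
#                         infer_model.batch_size_placeholder: 1})
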
def create_train_model_for_server(model_creator,
                                  hparams,
                                  scope=None,
                                  num_workers=1,
                                  jobid=0,
                                  graph=None,
                                  extra_args=None,
                                  trie=None):
  """Create graph, model, and iterator when running the NMT in server mode.

  This is different from the standard training model, because the input
  arrives via RPC and thus has to be fed using placeholders.
  """
  assert hparams.num_buckets == 1, "No bucketing when in server mode."
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  if not graph:
    graph = tf.Graph()

  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)

    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
    tgt_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    tgt_dataset = tf.data.Dataset.from_tensor_slices(tgt_placeholder)
    wgt_placeholder = tf.placeholder(shape=[None], dtype=tf.float32)
    wgt_dataset = tf.data.Dataset.from_tensor_slices(wgt_placeholder)

    ctx_placeholder = None
    if hparams.ctx is not None:
      ctx_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    ctx_dataset = None
    if ctx_placeholder is not None:
      ctx_dataset = tf.data.Dataset.from_tensor_slices(ctx_placeholder)

    batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

    iterator = iterator_utils.get_iterator(
        hparams=hparams,
        src_dataset=src_dataset,
        tgt_dataset=tgt_dataset,
        src_vocab_table=src_vocab_table,
        tgt_vocab_table=tgt_vocab_table,
        wgt_dataset=wgt_dataset,
        ctx_dataset=ctx_dataset,
        annot_dataset=None,
        batch_size=batch_size_placeholder,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams=hparams,
          iterator=iterator,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          source_vocab_table=src_vocab_table,
          target_vocab_table=tgt_vocab_table,
          reverse_target_vocab_table=reverse_tgt_vocab_table,
          scope=scope,
          extra_args=extra_args,
          trie=trie)

  return TrainModelForServer(
      graph=graph,
      model=model,
      iterator=iterator,
      src_placeholder=src_placeholder,
      tgt_placeholder=tgt_placeholder,
      wgt_placeholder=wgt_placeholder,
      batch_size_placeholder=batch_size_placeholder,
      skip_count_placeholder=skip_count_placeholder)
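
# Usage sketch (illustrative only): in server mode every example arrives via
# RPC, so source/target sentences, per-example weights, the batch size, and
# the skip count are all fed through placeholders when (re-)initializing the
# iterator:
#
#   server_model = create_train_model_for_server(model_creator, hparams)
#   with tf.Session(graph=server_model.graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(server_model.iterator.initializer,
#              feed_dict={server_model.src_placeholder: ["a b c"],
#                         server_model.tgt_placeholder: ["d e f"],
#                         server_model.wgt_placeholder: [1.0],
#                         server_model.batch_size_placeholder: 1,
#                         server_model.skip_count_placeholder: 0})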