def create_train_model(model_creator,
                       hparams,
                       scope=None,
                       num_workers=1,
                       jobid=0,
                       extra_args=None):
  """Create graph, model, and iterator for training."""
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "train"):
    vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
    data_dataset = tf.data.TextLineDataset(hparams.train_data)
    kb_dataset = tf.data.TextLineDataset(hparams.train_kb)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
        hparams.vocab_file, default_value=vocab_utils.UNK)
    # This is the actual train iterator.
    train_iterator = iterator_utils.get_iterator(
        data_dataset,
        kb_dataset,
        vocab_table,
        batch_size=hparams.batch_size,
        t1=hparams.t1,
        t2=hparams.t2,
        eod=hparams.eod,
        len_action=hparams.len_action,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        max_dialogue_len=hparams.max_dialogue_len,
        skip_count=skip_count_placeholder,
        num_shards=num_workers,
        shard_index=jobid)
    # This is the placeholder iterator. One can use this placeholder iterator
    # to switch between training and evaluation.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types, train_iterator.output_shapes)
    batched_iterator = iterator_utils.get_batched_iterator(iterator)

    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams,
          iterator=batched_iterator,
          handle=handle,
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          vocab_table=vocab_table,
          scope=scope,
          extra_args=extra_args,
          reverse_vocab_table=reverse_vocab_table)
  return TrainModel(
      graph=graph,
      model=model,
      placeholder_iterator=iterator,
      train_iterator=train_iterator,
      placeholder_handle=handle,
      skip_count_placeholder=skip_count_placeholder)
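# Usage sketch (illustrative, not part of the original pipeline): how the
# string-handle indirection above is typically driven in a TF1 session. It
# assumes the object returned by iterator_utils.get_iterator exposes the
# standard tf.data.Iterator `initializer` and `string_handle()` members;
# `_example_init_train` itself is a hypothetical helper.
def _example_init_train(train_model, sess):
  """Initialize the training graph and return the train iterator's handle."""
  sess.run(tf.global_variables_initializer())
  sess.run(tf.tables_initializer())
  # Skip no lines on the first pass over the training file.
  sess.run(train_model.train_iterator.initializer,
           feed_dict={train_model.skip_count_placeholder: 0})
  # Feeding this handle into train_model.placeholder_handle makes the shared
  # placeholder iterator yield training batches.
  return sess.run(train_model.train_iterator.string_handle())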
def create_selfplay_model(model_creator,
                          is_mutable,
                          num_workers,
                          jobid,
                          hparams,
                          scope=None,
                          extra_args=None):
  """Create self-play models."""
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "selfplay"):
    vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
    reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
        hparams.vocab_file, default_value=vocab_utils.UNK)
    if is_mutable:
      mutable_index = 0
    else:
      mutable_index = 1
    # Get a list of iterators and placeholders.
    iterators, placeholders = self_play_iterator_creator(
        hparams, num_workers, jobid)
    (train_iterator, self_play_fulltext_iterator,
     self_play_structured_iterator) = iterators
    (data_placeholder, kb_placeholder, batch_size_placeholder,
     skip_count_placeholder) = placeholders

    # Get an iterator handle.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types, train_iterator.output_shapes)
    batched_iterator = iterator_utils.get_batched_iterator(iterator)

    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(
          hparams,
          iterator=batched_iterator,
          handle=handle,
          mode=[
              dialogue_utils.mode_self_play_mutable,
              dialogue_utils.mode_self_play_immutable
          ][mutable_index],
          vocab_table=vocab_table,
          reverse_vocab_table=reverse_vocab_table,
          scope=scope,
          extra_args=extra_args)
  return SelfplayModel(
      graph=graph,
      model=model,
      placeholder_iterator=iterator,
      placeholder_handle=handle,
      train_iterator=train_iterator,
      self_play_ft_iterator=self_play_fulltext_iterator,
      self_play_st_iterator=self_play_structured_iterator,
      data_placeholder=data_placeholder,
      kb_placeholder=kb_placeholder,
      skip_count_placeholder=skip_count_placeholder,
      batch_size_placeholder=batch_size_placeholder)
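# Usage sketch (an assumption, not in the original code): because the three
# self-play iterators all feed one placeholder iterator, a single model can be
# pointed at any of them by swapping the handle fed to placeholder_handle.
# `_example_selfplay_handles` is a hypothetical helper; it assumes the
# iterators expose the standard tf.data.Iterator `string_handle()` member.
def _example_selfplay_handles(selfplay_model, sess):
  """Return handles for the train, full-text, and structured iterators."""
  return sess.run([
      selfplay_model.train_iterator.string_handle(),
      selfplay_model.self_play_ft_iterator.string_handle(),
      selfplay_model.self_play_st_iterator.string_handle(),
  ])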
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
  """Create inference model."""
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "infer"):
    vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
    reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
        hparams.vocab_file, default_value=vocab_utils.UNK)
    data_src_placeholder = tf.placeholder(
        shape=[None], dtype=tf.string, name="src_ph")
    kb_placeholder = tf.placeholder(
        shape=[None], dtype=tf.string, name="kb_ph")
    batch_size_placeholder = tf.placeholder(
        shape=[], dtype=tf.int64, name="bs_ph")
    data_src_dataset = tf.data.Dataset.from_tensor_slices(data_src_placeholder)
    kb_dataset = tf.data.Dataset.from_tensor_slices(kb_placeholder)
    # This is the actual infer iterator.
    infer_iterator = iterator_utils.get_infer_iterator(
        data_src_dataset,
        kb_dataset,
        vocab_table,
        batch_size=batch_size_placeholder,
        eod=hparams.eod,
        len_action=hparams.len_action)
    # This is the placeholder infer iterator.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, infer_iterator.output_types, infer_iterator.output_shapes)
    batched_iterator = iterator_utils.get_batched_iterator(iterator)

    model = model_creator(
        hparams,
        iterator=batched_iterator,
        handle=handle,
        mode=tf.contrib.learn.ModeKeys.INFER,
        vocab_table=vocab_table,
        reverse_vocab_table=reverse_vocab_table,
        scope=scope,
        extra_args=extra_args)
  return InferModel(
      graph=graph,
      model=model,
      placeholder_iterator=iterator,
      placeholder_handle=handle,
      infer_iterator=infer_iterator,
      data_src_placeholder=data_src_placeholder,
      kb_placeholder=kb_placeholder,
      batch_size_placeholder=batch_size_placeholder)
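# Usage sketch (illustrative assumption): unlike the train graph, the infer
# graph reads in-memory strings, so the data itself is fed at initialization
# time. `src_lines` and `kb_lines` are hypothetical lists of serialized
# examples, one string per dialogue; `_example_init_infer` is not part of the
# original file.
def _example_init_infer(infer_model, sess, src_lines, kb_lines):
  """Initialize the infer iterator on in-memory data and return its handle."""
  sess.run(tf.tables_initializer())
  sess.run(
      infer_model.infer_iterator.initializer,
      feed_dict={
          infer_model.data_src_placeholder: src_lines,
          infer_model.kb_placeholder: kb_lines,
          infer_model.batch_size_placeholder: len(src_lines),
      })
  return sess.run(infer_model.infer_iterator.string_handle())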
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
  """Create eval graph, model, data/kb file placeholders, and iterator."""
  vocab_file = hparams.vocab_file
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "eval"):
    vocab_table = vocab_utils.create_vocab_tables(vocab_file)
    data_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    kb_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    data_dataset = tf.data.TextLineDataset(data_file_placeholder)
    kb_dataset = tf.data.TextLineDataset(kb_file_placeholder)
    # This is the actual eval iterator.
    eval_iterator = iterator_utils.get_iterator(
        data_dataset,
        kb_dataset,
        vocab_table,
        batch_size=hparams.batch_size,
        t1=hparams.t1,
        t2=hparams.t2,
        eod=hparams.eod,
        len_action=hparams.len_action,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        max_dialogue_len=hparams.max_dialogue_len)
    # This is the placeholder iterator.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, eval_iterator.output_types, eval_iterator.output_shapes)
    batched_iterator = iterator_utils.get_batched_iterator(iterator)

    model = model_creator(
        hparams,
        iterator=batched_iterator,
        handle=handle,
        mode=tf.contrib.learn.ModeKeys.EVAL,
        vocab_table=vocab_table,
        scope=scope,
        extra_args=extra_args)
  return EvalModel(
      graph=graph,
      model=model,
      placeholder_iterator=iterator,
      placeholder_handle=handle,
      eval_iterator=eval_iterator,
      data_file_placeholder=data_file_placeholder,
      kb_file_placeholder=kb_file_placeholder)
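# Usage sketch (illustrative assumption): the eval graph reads whole files, so
# only file names are fed when (re)initializing the iterator, which allows the
# same graph to be re-pointed at different dev/test sets. `dev_data_file` and
# `dev_kb_file` are hypothetical paths; `_example_init_eval` is not part of
# the original file.
def _example_init_eval(eval_model, sess, dev_data_file, dev_kb_file):
  """Initialize the eval iterator on the given files and return its handle."""
  sess.run(tf.tables_initializer())
  sess.run(
      eval_model.eval_iterator.initializer,
      feed_dict={
          eval_model.data_file_placeholder: dev_data_file,
          eval_model.kb_file_placeholder: dev_kb_file,
      })
  return sess.run(eval_model.eval_iterator.string_handle())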