import tensorflow as tf


def create_train_model():
    train_graph = tf.Graph()
    mode = tf.contrib.learn.ModeKeys.TRAIN
    train_model = Model(mode, hyper_parameters)
    dataset_iterator = DatasetIterator(hyper_parameters)
    train = Train(mode, hyper_parameters)

    with train_graph.as_default():
        # Build the vocab lookup tables and the raw text datasets inside this graph.
        source_vocab_table, target_vocab_table = dataset_iterator.get_tables(
            share_vocab=False)
        source_dataset, target_dataset = dataset_iterator.get_datasets()
        train_iterator = dataset_iterator.get_iterator(
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            source_dataset=source_dataset,
            target_dataset=target_dataset,
            source_max_len=hyper_parameters["source_max_len_train"],
            target_max_len=hyper_parameters["target_max_len_train"],
            # skip_count=skip_count_place_holder  # TODO: probably needed to skip already-seen examples when resuming
        )
        logits, loss, final_context_state, sample_id = train_model.build_model(
            train_iterator, target_vocab_table)
        train.configure_train_eval_infer(
            iterator=train_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state)

    return train_graph, train_iterator, train
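
# A minimal usage sketch for the train graph, under stated assumptions: that
# the iterator returned by get_iterator exposes an `initializer` op (as in the
# TF NMT BatchedInput pattern), and that Train exposes `update` (the train op)
# and `loss` after configure_train_eval_infer. Those attribute names and
# `num_train_steps` are hypothetical, not confirmed by this codebase.
#
# train_graph, train_iterator, train = create_train_model()
# num_train_steps = 12000  # hypothetical config value
# with tf.Session(graph=train_graph) as sess:
#     sess.run(tf.global_variables_initializer())
#     sess.run(tf.tables_initializer())     # vocab lookup tables need explicit init
#     sess.run(train_iterator.initializer)  # start the input pipeline
#     for step in range(num_train_steps):
#         _, step_loss = sess.run([train.update, train.loss])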
def create_eval_model():
    eval_graph = tf.Graph()
    mode = tf.contrib.learn.ModeKeys.EVAL
    eval_model = Model(mode, hyper_parameters)
    dataset_iterator = DatasetIterator(hyper_parameters)
    evaluator = Train(mode, hyper_parameters)  # renamed from `eval` to avoid shadowing the builtin

    with eval_graph.as_default():
        target_vocab_file = hyper_parameters["target_vocab_file"]
        source_vocab_table, target_vocab_table = dataset_iterator.get_tables(
            share_vocab=False)
        # Maps predicted ids back to words when decoding during evaluation.
        reverse_target_vocab_table = tf.contrib.lookup.index_to_string_table_from_file(
            target_vocab_file, default_value="UNK")

        # The eval files are fed at run time, so one graph can score any file pair.
        source_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        target_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        source_dataset = tf.data.TextLineDataset(source_file_placeholder)
        target_dataset = tf.data.TextLineDataset(target_file_placeholder)

        eval_iterator = dataset_iterator.get_iterator(
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            source_dataset=source_dataset,
            target_dataset=target_dataset,
            source_max_len=hyper_parameters["source_max_len_infer"],
            target_max_len=hyper_parameters["target_max_len_infer"])
        logits, loss, final_context_state, sample_id = eval_model.build_model(
            eval_iterator, target_vocab_table)
        evaluator.configure_train_eval_infer(
            iterator=eval_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state,
            reverse_target_vocab_table=reverse_target_vocab_table)

    return (eval_graph, evaluator, eval_iterator,
            source_file_placeholder, target_file_placeholder)
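
# A sketch of running evaluation, assuming the same iterator/Train attributes
# as above (`initializer`, `loss`). The "dev_source_file"/"dev_target_file"
# hyper-parameter keys are hypothetical placeholders for whatever this config
# actually calls its dev files. In practice you would also restore trained
# weights with a tf.train.Saver before scoring.
#
# eval_graph, evaluator, eval_iterator, src_ph, tgt_ph = create_eval_model()
# with tf.Session(graph=eval_graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(eval_iterator.initializer,
#              feed_dict={src_ph: hyper_parameters["dev_source_file"],
#                         tgt_ph: hyper_parameters["dev_target_file"]})
#     total_loss = 0.0
#     while True:
#         try:
#             total_loss += sess.run(evaluator.loss)
#         except tf.errors.OutOfRangeError:
#             break  # one full pass over the dev files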
def create_infer_model():
    infer_graph = tf.Graph()
    mode = tf.contrib.learn.ModeKeys.INFER
    infer_model = Model(mode, hyper_parameters)
    dataset_iterator = DatasetIterator(hyper_parameters)
    infer = Train(mode, hyper_parameters)

    with infer_graph.as_default():
        target_vocab_file = hyper_parameters["target_vocab_file"]
        # Maps predicted ids back to words so decoded outputs are readable text.
        reverse_target_vocab_table = tf.contrib.lookup.index_to_string_table_from_file(
            target_vocab_file, default_value="UNK")

        # Raw sentences are fed directly as a batch of strings at run time.
        source_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
        source_dataset = tf.data.Dataset.from_tensor_slices(source_placeholder)

        source_vocab_table, target_vocab_table = dataset_iterator.get_tables(
            share_vocab=False)
        # NOTE: batch_size_placeholder is returned to the caller but is not wired
        # into get_infer_iterator here; if the iterator is meant to batch
        # dynamically, it likely needs to accept it (e.g. batch_size=batch_size_placeholder).
        infer_iterator = dataset_iterator.get_infer_iterator(
            source_dataset=source_dataset,
            source_vocab_table=source_vocab_table,
            source_max_len=hyper_parameters["source_max_len_infer"])
        logits, loss, final_context_state, sample_id = infer_model.build_model(
            infer_iterator, target_vocab_table)
        infer.configure_train_eval_infer(
            iterator=infer_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state,
            reverse_target_vocab_table=reverse_target_vocab_table)

    return (infer_graph, infer, infer_iterator,
            source_placeholder, batch_size_placeholder)
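
# A sketch of decoding a few raw sentences with the infer graph.
# `infer.sample_words` is an assumed attribute (the decoded word tensor in the
# TF NMT pattern); substitute whatever this Train class actually exposes.
# Feeding batch_size_ph only matters once it is wired into get_infer_iterator
# (see the NOTE above).
#
# infer_graph, infer, infer_iterator, src_ph, batch_size_ph = create_infer_model()
# sentences = ["how are you ?", "good morning"]
# with tf.Session(graph=infer_graph) as sess:
#     sess.run(tf.tables_initializer())
#     sess.run(infer_iterator.initializer,
#              feed_dict={src_ph: sentences, batch_size_ph: len(sentences)})
#     translations = sess.run(infer.sample_words)  # assumed attribute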