def create_train_model(model_creator, hparams, scope=None, num_workers=1,
                       jobid=0, extra_args=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder,
            num_shards=num_workers,
            shard_index=jobid)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        model_device_fn = None
        if extra_args:
            model_device_fn = extra_args.model_device_fn
        with tf.device(model_device_fn):
            model = model_creator(
                hparams,
                iterator=iterator,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                source_vocab_table=src_vocab_table,
                target_vocab_table=tgt_vocab_table,
                scope=scope,
                extra_args=extra_args)

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder)
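
# Usage sketch (assumption): a minimal training loop over the TrainModel
# returned above. It assumes the BatchedInput returned by
# iterator_utils.get_iterator exposes an `initializer` op and that the object
# built by `model_creator` exposes a `train(sess)` step, as in the TensorFlow
# NMT reference code; adapt both names to whatever your model actually provides.
def _train_sketch(model_creator, hparams, num_steps=100):
    train_model = create_train_model(model_creator, hparams)
    with tf.Session(graph=train_model.graph) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        # Initialize the dataset iterator; skip_count=0 starts reading from
        # the top of the training files.
        sess.run(train_model.iterator.initializer,
                 feed_dict={train_model.skip_count_placeholder: 0})
        for _ in range(num_steps):
            try:
                train_model.model.train(sess)  # assumed step method
            except tf.errors.OutOfRangeError:
                # One epoch finished; re-initialize and keep training.
                sess.run(train_model.iterator.initializer,
                         feed_dict={train_model.skip_count_placeholder: 0})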
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create eval graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    lbl_vocab_file = hparams.lbl_vocab_file

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        src_vocab_table, tgt_vocab_table, lbl_vocab_table = \
            vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file,
                                            lbl_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        reverse_lbl_vocab_table = lookup_ops.index_to_string_table_from_file(
            lbl_vocab_file, default_value=vocab_utils.UNK)

        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        lbl_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
        lbl_dataset = tf.data.TextLineDataset(lbl_file_placeholder)

        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            lbl_dataset,
            src_vocab_table,
            tgt_vocab_table,
            lbl_vocab_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            label_vocab_table=lbl_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            reverse_target_intent_vocab_table=reverse_lbl_vocab_table,
            scope=scope,
            extra_args=extra_args)

    return EvalModel(
        graph=graph,
        model=model,
        src_file_placeholder=src_file_placeholder,
        tgt_file_placeholder=tgt_file_placeholder,
        lbl_file_placeholder=lbl_file_placeholder,
        iterator=iterator)
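
# Usage sketch (assumption): evaluating one src/tgt/lbl file triple with the
# EvalModel above. It assumes the iterator exposes an `initializer` op and
# that the model object provides an `eval(sess)` step returning a per-batch
# loss; both names are assumptions, not guaranteed by this file. Restoring
# trained weights (e.g. via tf.train.Saver) is omitted for brevity.
def _eval_sketch(model_creator, hparams, src_file, tgt_file, lbl_file):
    eval_model = create_eval_model(model_creator, hparams)
    with tf.Session(graph=eval_model.graph) as sess:
        sess.run(tf.tables_initializer())
        sess.run(eval_model.iterator.initializer,
                 feed_dict={
                     eval_model.src_file_placeholder: src_file,
                     eval_model.tgt_file_placeholder: tgt_file,
                     eval_model.lbl_file_placeholder: lbl_file,
                 })
        total_loss, num_batches = 0.0, 0
        while True:
            try:
                total_loss += eval_model.model.eval(sess)  # assumed step method
                num_batches += 1
            except tf.errors.OutOfRangeError:
                break
        return total_loss / max(num_batches, 1)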
def main(_):
    ano_data_set = os.path.join(cfg.data_set, cfg.ano_data_set)
    vocab_file = os.path.join(ano_data_set, cfg.tgt_vocab_file)

    with tf.Graph().as_default():
        vocab_size, vocab_file = vocab_utils.check_vocab(
            vocab_file, out_dir=cfg.out_dir, sos=cfg.sos, eos=cfg.eos,
            unk=cfg.unk)
        tgt_vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        reverse_tgt_vocab_table = vocab_utils.index_to_string_table_from_file(
            vocab_file, default_value=cfg.unk)
        tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(cfg.sos)),
                             tf.int32)
        tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(cfg.eos)),
                             tf.int32)
        iter, batch_input = get_iterator(cfg.vaild_tf_filename,
                                         tgt_vocab_table,
                                         tgt_sos_id, tgt_eos_id)
        lookUpTgt = reverse_tgt_vocab_table.lookup(
            tf.to_int64(batch_input.target_output))

        sess = tf.Session()
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
        sess.run(iter)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        step = 0
        try:
            while True:
                try:
                    while not coord.should_stop():
                        src, tgt_output, src_seq_len, tgt_seq_len = \
                            sess.run([batch_input.source, lookUpTgt,
                                      batch_input.source_sequence_length,
                                      batch_input.target_sequence_length])
                        if np.isnan(np.max(src)) or np.isnan(np.min(src)):
                            print('got a NaN')
                            exit(1)
                        if np.any(np.less(src, 0.)):
                            print('got a negative value')
                            exit(1)
                        print('run one')
                        step += 1
                except tf.errors.OutOfRangeError:
                    print('check finished')
                    exit(1)
                    sess.run(iter)  # unreachable after exit(1); kept from original
        except KeyboardInterrupt:
            print('interrupt')
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()
def create_train_model(model_creator, get_iterator, hparams, scope=None):
    """Create the training graph, model, and iterator."""
    # Build the file names by joining the train prefix with the src/tgt suffixes.
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    vocab_file = hparams.vocab_file

    # Define the graph
    graph = tf.Graph()
    with graph.as_default():
        vocab_table = vocab_utils.create_vocab_tables(vocab_file)

        # Create datasets from file
        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)

        # The number of elements of this dataset that should be skipped
        # to form the new dataset.
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        # Iterator
        iterator = get_iterator(
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            vocab_table=vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_reverse=hparams.src_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_placeholder)

        # Model. We don't pass ids_to_words because it isn't needed for training.
        model = model_creator(
            hparams=hparams,
            mode=tf.contrib.learn.ModeKeys.TRAIN,
            iterator=iterator,
            vocab_table=vocab_table,
            scope=scope)

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder)
def create_eval_model(model_creator, get_iterator, hparams, scope=None):
    """Create eval graph, model, src/tgt file holders, and iterator."""
    vocab_file = hparams.vocab_file

    # Define the graph
    graph = tf.Graph()
    with graph.as_default():
        vocab_table = vocab_utils.create_vocab_tables(vocab_file)

        # Create placeholders for the file locations
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

        # Create the datasets from file
        src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)

        # Create the iterator for the dataset. We do not use skip_count here
        # because we evaluate on the full file.
        iterator = get_iterator(
            src_dataset=src_dataset,
            tgt_dataset=tgt_dataset,
            vocab_table=vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_reverse=hparams.src_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)

        # Create the model
        model = model_creator(
            hparams=hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            vocab_table=vocab_table,
            scope=scope)

    return EvalModel(
        graph=graph,
        model=model,
        src_file_placeholder=src_file_placeholder,
        tgt_file_placeholder=tgt_file_placeholder,
        iterator=iterator)
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create inference model."""
    graph = tf.Graph()
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    lbl_vocab_file = hparams.lbl_vocab_file

    with graph.as_default(), tf.container(scope or "infer"):
        src_vocab_table, tgt_vocab_table, lbl_vocab_table = \
            vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file,
                                            lbl_vocab_file, hparams.share_vocab)
        reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=vocab_utils.UNK)
        reverse_lbl_vocab_table = lookup_ops.index_to_string_table_from_file(
            lbl_vocab_file, default_value=vocab_utils.UNK)

        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
        src_dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)

        iterator = iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size_placeholder,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)

        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            label_vocab_table=lbl_vocab_table,
            reverse_target_vocab_table=reverse_tgt_vocab_table,
            reverse_target_intent_vocab_table=reverse_lbl_vocab_table,
            scope=scope,
            extra_args=extra_args)

    return InferModel(
        graph=graph,
        model=model,
        src_placeholder=src_placeholder,
        batch_size_placeholder=batch_size_placeholder,
        iterator=iterator)
def create_infer_model(model_creator, get_infer_iterator, hparams,
                       verbose=True, scope=None):
    """Create the inference model."""
    graph = tf.Graph()
    vocab_file = hparams.vocab_file

    with graph.as_default():
        # Create the lookup tables
        vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        ids_to_words = lookup_ops.index_to_string_table_from_file(
            vocabulary_file=vocab_file, default_value=vocab_utils.UNK)

        # Define data placeholders
        src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        # Create the dataset and iterator
        src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
            src_placeholder)
        iterator = get_infer_iterator(
            dataset=src_dataset,
            vocab_table=vocab_table,
            batch_size=batch_size_placeholder,
            src_reverse=hparams.src_reverse,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)

        # Create the model
        model = model_creator(
            hparams=hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            vocab_table=vocab_table,
            verbose=verbose,
            ids_to_words=ids_to_words,
            scope=scope)

    return InferModel(
        graph=graph,
        model=model,
        src_placeholder=src_placeholder,
        batch_size_placeholder=batch_size_placeholder,
        iterator=iterator)
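
# Usage sketch (assumption): running inference on a few raw source strings
# with the InferModel above. It assumes the infer iterator exposes an
# `initializer` op and that the model object has a `decode(sess)` method;
# both are assumptions about code outside this file, as is restoring trained
# weights, which is omitted here.
def _infer_sketch(model_creator, get_infer_iterator, hparams, sentences):
    infer_model = create_infer_model(model_creator, get_infer_iterator, hparams)
    with tf.Session(graph=infer_model.graph) as sess:
        sess.run(tf.tables_initializer())
        # Feed the raw strings and the batch size through the placeholders.
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: sentences,
                     infer_model.batch_size_placeholder: len(sentences),
                 })
        return infer_model.model.decode(sess)  # assumed decode method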
def create_eval_model(model_creator, hparams, scope=None, single_cell_fn=None):
    """Create eval graph, model, src/tgt file holders, and iterator."""
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file
    graph = tf.Graph()

    with graph.as_default():
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)
        src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
        iterator = iterator_utils.get_iterator(
            src_dataset,
            tgt_dataset,
            src_vocab_table,
            tgt_vocab_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)
        model = model_creator(
            hparams,
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_vocab_table,
            target_vocab_table=tgt_vocab_table,
            scope=scope,
            single_cell_fn=single_cell_fn)

    return EvalModel(
        graph=graph,
        model=model,
        src_file_placeholder=src_file_placeholder,
        tgt_file_placeholder=tgt_file_placeholder,
        iterator=iterator)
def __init__(self, is_training=True, checkPoint_path=None):
    self.graph = tf.Graph()
    self.is_training = is_training
    with self.graph.as_default():
        ano_data_set = os.path.join(cfg.data_set, cfg.ano_data_set)
        vocab_file = os.path.join(ano_data_set, cfg.tgt_vocab_file)
        vocab_size, vocab_file = vocab_utils.check_vocab(
            vocab_file, out_dir=cfg.out_dir, sos=cfg.sos, eos=cfg.eos,
            unk=cfg.unk)
        self.tgt_vocab_table = vocab_utils.create_vocab_tables(vocab_file)
        self.reverse_tgt_vocab_table = vocab_utils.index_to_string_table_from_file(
            vocab_file, default_value=cfg.unk)
        self.tgt_sos_id = tf.cast(
            self.tgt_vocab_table.lookup(tf.constant(cfg.sos)), tf.int32)
        self.tgt_eos_id = tf.cast(
            self.tgt_vocab_table.lookup(tf.constant(cfg.eos)), tf.int32)

        if is_training:
            # train_src_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.train_src_dataset))
            # train_tgt_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.train_tgt_dataset))
            self.init_iter_train, self.iterator_train = get_iterator(
                cfg.train_tf_filename, self.tgt_vocab_table,
                self.tgt_sos_id, self.tgt_eos_id, augment=True)
            # vaild_src_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.vaild_src_dataset))
            # vaild_tgt_dataset = tf.contrib.data.TextLineDataset(os.path.join(ano_data_set, cfg.vaild_tgt_dataset))
            self.init_iter_vaild, self.iterator_vaild = get_iterator(
                cfg.vaild_tf_filename, self.tgt_vocab_table,
                self.tgt_sos_id, self.tgt_eos_id)
        else:
            self.source = tf.placeholder(tf.float32, (None, None),
                                         name='source')
            batch_source = tf.expand_dims(
                tf.expand_dims(self.source, axis=0), axis=-1)
            iterator_source = normalize_input_img(batch_source)
            # tf.shape already returns int32; wrapping a tensor in tf.constant
            # is invalid, so take the width dimension directly.
            self.source_sequence_length = tf.shape(iterator_source)[2]
            self.iterator = BatchedInput(
                source=iterator_source,
                target_input=None,
                target_output=None,
                source_sequence_length=self.source_sequence_length,
                target_sequence_length=None)

        self.featureCNN = FeatureCNN()
        self.gru_att_cov = GRU_Att_Cov(vocab_size)  # vocabulary size

        if is_training:
            if cfg.outer_batch_size:
                outer_loss = 0
                with tf.variable_scope('outer_batch_size') as scope:
                    for i in range(cfg.outer_batch_size):
                        if i > 0:
                            scope.reuse_variables()
                        self.cnn_out_train = self.featureCNN(
                            self.iterator_train.source, True, False)
                        self.logits_train, _, self.attn_dists_train = self.gru_att_cov(
                            self.cnn_out_train, self.iterator_train, True,
                            self.tgt_sos_id)
                        outer_loss += self._loss(self.logits_train,
                                                 self.iterator_train)
                    self.loss_train = outer_loss / cfg.outer_batch_size
            else:
                self.cnn_out_train = self.featureCNN(
                    self.iterator_train.source, True, False)
                self.logits_train, _, self.attn_dists_train = self.gru_att_cov(
                    self.cnn_out_train, self.iterator_train, True,
                    self.tgt_sos_id)
                self.loss_train = self._loss(self.logits_train,
                                             self.iterator_train)

            self.global_step = tf.Variable(0, name='global_step',
                                           trainable=False)
            self.learning_rate = tf.train.exponential_decay(
                cfg.startLr, self.global_step, cfg.decay_steps, cfg.decay_rate)
            optimizer = tf.train.AdadeltaOptimizer(self.learning_rate)
            self.train_op = optimizer.minimize(
                self.loss_train, global_step=self.global_step)

            self.cnn_out_vaild = self.featureCNN(
                self.iterator_vaild.source, True)
            self.logits_vaild, _, _ = self.gru_att_cov(
                self.cnn_out_vaild, self.iterator_vaild, True, self.tgt_sos_id)
            self.loss_vaild = self._loss(self.logits_vaild,
                                         self.iterator_vaild)
            self.cnn_out_vaild_infer = self.featureCNN(
                self.iterator_vaild.source, False)
            _, self.infer_indes_vaild, self.infer_attn_dists_vaild = self.gru_att_cov(
                self.cnn_out_vaild_infer, self.iterator_vaild, False,
                self.tgt_sos_id)
            self.infer_lookUpTgt_vaild = self.reverse_tgt_vocab_table.lookup(
                tf.to_int64(self.infer_indes_vaild))
            self.accuracy_vaild = self._acc(
                self.infer_indes_vaild, self.iterator_vaild.target_output)
            self.train_lookUpTgt_vaild = self.reverse_tgt_vocab_table.lookup(
                tf.to_int64(self.iterator_vaild.target_output))
            self.train_summary, self.vaild_summary = self._summary()
        else:
            self.cnn_out = self.featureCNN(self.iterator.source, is_training)
            _, self.infer_indes, self.infer_attn_dists = self.gru_att_cov(
                self.cnn_out, self.iterator, False, self.tgt_sos_id)
            self.infer_lookUpTgt = self.reverse_tgt_vocab_table.lookup(
                tf.to_int64(self.infer_indes))

        self.init = [
            tf.global_variables_initializer(),
            tf.tables_initializer()
        ]
        self.saver = tf.train.Saver()
        self.sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=True))
        if not is_training:
            self.sess.run(self.init)
            self.saver.restore(self.sess, checkPoint_path)
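
# Usage sketch (assumption): decoding a single image with the class above in
# inference mode. `Model` is a stand-in for the enclosing class name (not
# shown in this file), `checkpoint_path` must point to a trained checkpoint,
# and `image` is a hypothetical 2-D float array (height x width) fed through
# the `source` placeholder defined in the is_training=False branch.
def recognize_image(image, checkpoint_path):
    model = Model(is_training=False, checkPoint_path=checkpoint_path)
    # Returns the target tokens looked up through reverse_tgt_vocab_table.
    tokens = model.sess.run(model.infer_lookUpTgt,
                            feed_dict={model.source: image})
    return tokens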
def create_train_model(model_creator, hparams, scope=None,
                       single_cell_fn=None, model_device_fn=None):
    """Create train graph, model, and iterator."""
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_file = hparams.src_vocab_file
    tgt_vocab_file = hparams.tgt_vocab_file

    graph = tf.Graph()

    with graph.as_default():
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, hparams.share_vocab)

        src_dataset = tf.contrib.data.TextLineDataset(src_file)
        tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        if hparams.curriculum == 'none':
            iterator = iterator_utils.get_iterator(
                src_dataset,
                tgt_dataset,
                src_vocab_table,
                tgt_vocab_table,
                batch_size=hparams.batch_size,
                sos=hparams.sos,
                eos=hparams.eos,
                source_reverse=hparams.source_reverse,
                random_seed=hparams.random_seed,
                num_buckets=hparams.num_buckets,
                src_max_len=hparams.src_max_len,
                tgt_max_len=hparams.tgt_max_len,
                skip_count=skip_count_placeholder)
        else:
            iterator = iterator_utils.get_feedable_iterator(
                hparams,
                src_dataset,
                tgt_dataset,
                src_vocab_table,
                tgt_vocab_table,
                batch_size=hparams.batch_size,
                sos=hparams.sos,
                eos=hparams.eos,
                source_reverse=hparams.source_reverse,
                random_seed=hparams.random_seed,
                num_buckets=hparams.num_buckets,
                src_max_len=hparams.src_max_len,
                tgt_max_len=hparams.tgt_max_len,
                skip_count=skip_count_placeholder)

        # Note: One can set model_device_fn to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        with tf.device(model_device_fn):
            model = model_creator(
                hparams,
                iterator=iterator,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                source_vocab_table=src_vocab_table,
                target_vocab_table=tgt_vocab_table,
                scope=scope,
                single_cell_fn=single_cell_fn)

    return TrainModel(
        graph=graph,
        model=model,
        iterator=iterator,
        skip_count_placeholder=skip_count_placeholder)
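
# Sketch (assumption): wiring model_device_fn for between-graph distributed
# training, as the note inside create_train_model suggests. The host:port
# strings are placeholders, and `model_creator` / `hparams` are assumed to be
# supplied by the surrounding training script.
def create_distributed_train_model(model_creator, hparams):
    cluster = tf.train.ClusterSpec({
        "ps": ["ps0.example.com:2222"],
        "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
    })
    # Place variables on the parameter servers; ops stay on the local worker.
    device_fn = tf.train.replica_device_setter(
        ps_tasks=cluster.num_tasks("ps"), cluster=cluster)
    return create_train_model(model_creator, hparams,
                              model_device_fn=device_fn)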