def _input_fn(params):
  """Input function."""
  del params
  if mode == tf.contrib.learn.ModeKeys.TRAIN:
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  else:
    if hparams.mode == "translate":
      src_file = hparams.translate_file + ".tok"
      tgt_file = hparams.translate_file + ".tok"
    else:
      src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
      tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
      src_vocab_file, tgt_vocab_file, hparams.share_vocab)
  src_dataset = tf.data.TextLineDataset(src_file)
  tgt_dataset = tf.data.TextLineDataset(tgt_file)

  if mode == tf.contrib.learn.ModeKeys.TRAIN:
    # Run one epoch and stop if running train_and_eval.
    if hparams.mode == "train_and_eval":
      # In this mode input pipeline is restarted every epoch, so choose a
      # different random_seed.
      num_repeat = 1
      random_seed = hparams.random_seed + int(time.time()) % 100
    else:
      num_repeat = 8
      random_seed = hparams.random_seed
    return iterator_utils.get_iterator(
        src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
        batch_size=hparams.batch_size, sos=hparams.sos, eos=hparams.eos,
        random_seed=random_seed, num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len, tgt_max_len=hparams.tgt_max_len,
        output_buffer_size=None, skip_count=None,
        num_shards=1,  # flags.num_workers
        shard_index=0,  # flags.jobid
        reshuffle_each_iteration=True,
        use_char_encode=hparams.use_char_encode,
        num_repeat=num_repeat,
        filter_oversized_sequences=True
    )  # need to update get_effective_train_epoch_size() if this flag flips.
  else:
    return iterator_utils.get_infer_iterator(
        src_dataset, src_vocab_table,
        batch_size=hparams.infer_batch_size, eos=hparams.eos,
        src_max_len=hparams.src_max_len,
        use_char_encode=hparams.use_char_encode)
def _createModel(mode, hparam, modelFunc=None):
  """Create the dataset from hparam (train_src, train_tgt) and return a
  TrainModel. TrainModel is a wrapper around everything related to the
  network's input.
  :param hparam:
  :return:
  """
  def _get_config_proto():
    conf = tf.ConfigProto(allow_soft_placement=True,
                          log_device_placement=False)
    return conf

  graph = tf.Graph()
  with graph.as_default():
    skipCount = None
    if mode == 'train':
      src_dataset = tf.data.TextLineDataset(hparam.train_src)
      tgt_dataset = tf.data.TextLineDataset(hparam.train_tgt)
      skipCount = tf.placeholder(tf.int64, shape=())
    elif mode == 'eval':
      src_dataset = tf.data.TextLineDataset(hparam.dev_src)
      tgt_dataset = tf.data.TextLineDataset(hparam.dev_tgt)
    else:
      raise ValueError('_createModel.mode must be train or eval')
    batch_input = get_iterator(src_dataset, tgt_dataset, skipCount, hparam)
    sess = tf.Session(config=_get_config_proto())
    model = modelFunc(batch_input, mode, hparam)
    if mode == 'train':
      return TrainModel(model, sess, graph, batch_input)
    elif mode == 'eval':
      return EvalModel(model, sess, graph, batch_input)
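# A minimal usage sketch (not from the original source). It assumes an
# `hparam` object carrying the train_src/train_tgt paths used above and a
# `modelFunc` that builds the network from the batched input; the TrainModel
# field names (graph, sess) mirror the constructor arguments above but are
# otherwise assumptions.
train_model = _createModel('train', hparam, modelFunc=modelFunc)
with train_model.graph.as_default():
  train_model.sess.run(tf.tables_initializer())
  train_model.sess.run(tf.global_variables_initializer())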
def testGraphVariablesDevMode(self):
  with self.graph.as_default():
    iterator, total_num = get_iterator('DEV', filesobj=files,
                                       buffer_size=hparams.buffer_size,
                                       num_epochs=hparams.num_epochs,
                                       batch_size=hparams.batch_size,
                                       debug_mode=True)
    _ = BuildEvalModel(hparams, iterator, tf.get_default_graph())
    var_names = [var.name for var in tf.trainable_variables()]
    self.assertAllEqual(sorted(var_names),
                        sorted(list(expected_variables.keys())),
                        'variables are not compatible dev mode')
  with self.graph.as_default():
    for var in tf.trainable_variables():
      self.assertAllEqual(tuple(var.get_shape().as_list()),
                          expected_variables[var.name],
                          'missed shapes at {} dev mode'.format(var.name))
def create_train_model(model_creator, hparams, scope=None, num_workers=1,
                       jobid=0, extra_args=None):
  """Create train graph, model, and iterator."""
  src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
  tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    src_dataset = tf.data.TextLineDataset(src_file)
    tgt_dataset = tf.data.TextLineDataset(tgt_file)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    iterator = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
        batch_size=hparams.batch_size, sos=hparams.sos, eos=hparams.eos,
        source_reverse=hparams.source_reverse, random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets, src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len, skip_count=skip_count_placeholder,
        num_shards=num_workers, shard_index=jobid)

    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(hparams, iterator=iterator,
                            mode=tf.contrib.learn.ModeKeys.TRAIN,
                            source_vocab_table=src_vocab_table,
                            target_vocab_table=tgt_vocab_table,
                            scope=scope, extra_args=extra_args)

  return TrainModel(graph=graph, model=model, iterator=iterator,
                    skip_count_placeholder=skip_count_placeholder)
def _input_fn(params):
  """Input function."""
  if mode == tf.contrib.learn.ModeKeys.TRAIN:
    src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  else:
    src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
    tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
      src_vocab_file, tgt_vocab_file, hparams.share_vocab)
  src_dataset = tf.data.TextLineDataset(src_file)
  tgt_dataset = tf.data.TextLineDataset(tgt_file)

  if mode == tf.contrib.learn.ModeKeys.TRAIN:
    if "context" in params:
      batch_size = params["batch_size"]
      num_hosts = params["context"].num_hosts
      # TODO(dehao): update to use current_host once available in API.
      current_host = params["context"].current_input_fn_deployment()[1]
    else:
      num_hosts = 1
      current_host = 0
      batch_size = hparams.batch_size
    mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE, value=batch_size)
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN,
                          value=hparams.src_max_len)
    return iterator_utils.get_iterator(
        src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
        batch_size=batch_size, sos=hparams.sos, eos=hparams.eos,
        random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len, tgt_max_len=hparams.tgt_max_len,
        output_buffer_size=None, skip_count=None, num_shards=num_hosts,
        shard_index=current_host, reshuffle_each_iteration=True,
        use_char_encode=hparams.use_char_encode,
        filter_oversized_sequences=True)
  else:
    return iterator_utils.get_infer_iterator(
        src_dataset, src_vocab_table, batch_size=hparams.infer_batch_size,
        eos=hparams.eos, src_max_len=hparams.src_max_len,
        use_char_encode=hparams.use_char_encode)
def create_train_model(model_creator, hparams, scope=None, num_workers=1,
                       jobid=0, extra_args=None):
  """Create graph, model and iterator for training."""
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "train"):
    vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
    data_dataset = tf.data.TextLineDataset(hparams.train_data)
    kb_dataset = tf.data.TextLineDataset(hparams.train_kb)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    reverse_vocab_table = lookup_ops.index_to_string_table_from_file(
        hparams.vocab_file, default_value=vocab_utils.UNK)
    # This is the actual train iterator.
    train_iterator = iterator_utils.get_iterator(
        data_dataset, kb_dataset, vocab_table, batch_size=hparams.batch_size,
        t1=hparams.t1, t2=hparams.t2, eod=hparams.eod,
        len_action=hparams.len_action, random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        max_dialogue_len=hparams.max_dialogue_len,
        skip_count=skip_count_placeholder, num_shards=num_workers,
        shard_index=jobid)
    # This is the placeholder iterator. One can use this placeholder iterator
    # to switch between training and evaluation.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types, train_iterator.output_shapes)
    batched_iterator = iterator_utils.get_batched_iterator(iterator)

    model_device_fn = None
    if extra_args:
      model_device_fn = extra_args.model_device_fn
    with tf.device(model_device_fn):
      model = model_creator(hparams, iterator=batched_iterator, handle=handle,
                            mode=tf.contrib.learn.ModeKeys.TRAIN,
                            vocab_table=vocab_table, scope=scope,
                            extra_args=extra_args,
                            reverse_vocab_table=reverse_vocab_table)

  return TrainModel(graph=graph, model=model, placeholder_iterator=iterator,
                    train_iterator=train_iterator, placeholder_handle=handle,
                    skip_count_placeholder=skip_count_placeholder)
def testGetIteratorWithShard(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c a", "f e a g", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["a b", "c c", "", "b c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3, num_buckets=5, eos="eos", sos="sos")
  batch_size = 2
  src_max_len = 3
  dataset = iterator_utils.get_iterator(
      src_dataset=src_dataset, tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table, tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size, sos=hparams.sos, eos=hparams.eos,
      random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
      src_max_len=src_max_len, num_shards=2, shard_index=1,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  iterator = dataset.make_initializable_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)
    features = sess.run(get_next)
    self.assertAllEqual(
        [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
         [2, 0, 3]],   # c a eos -- eos is padding
        features["source"])
    self.assertAllEqual([3, 2], features["source_sequence_length"])
    self.assertAllEqual(
        [[4, 2, 2],  # sos c c
         [4, 1, 2]], # sos b c
        features["target_input"])
    self.assertAllEqual(
        [[2, 2, 3],  # c c eos
         [1, 2, 3]], # b c eos
        features["target_output"])
    self.assertAllEqual([3, 3], features["target_sequence_length"])
    with self.assertRaisesOpError("End of sequence"):
      sess.run(get_next)
def testTrainInput(self):
  try1 = BatchedInput(
      source=np.array([[3, 4, 5, 10, 5, 5, 5, 5, 5, 5],
                       [6, 7, 8, 6, 9, 11, 2, 2, 2, 2]]),
      target_in=np.array([[1, 3, 4, 5, 10, 2, 2], [1, 6, 11, 7, 8, 9, 12]]),
      target_out=np.array([[3, 4, 5, 10, 2, 2, 2], [6, 11, 7, 8, 9, 12, 2]]),
      source_size=np.array([10, 6]),
      target_size=np.array([5, 7]),
      initializer=None, sos_token_id=1, eos_token_id=2)
  try2 = BatchedInput(
      source=np.array([[11, 10, 4, 8]]),
      target_in=np.array([[1, 10, 11, 12, 7, 8]]),
      target_out=np.array([[10, 11, 12, 7, 8, 2]]),
      source_size=np.array([4]),
      target_size=np.array([6]),
      initializer=None, sos_token_id=1, eos_token_id=2)
  with self.graph.as_default():
    iterator, total_num = iterator_utils.get_iterator(
        'TRAIN', filesobj=self.files, buffer_size=None, num_epochs=1,
        batch_size=2, debug_mode=True)
    table_init_op = tf.tables_initializer()
    self.sess.run(table_init_op)
    self.sess.run(iterator.initializer)
    res1 = self.sess.run([
        iterator.source, iterator.target_in, iterator.target_out,
        iterator.source_size, iterator.target_size, iterator.sos_token_id,
        iterator.eos_token_id
    ])
    self._check_all_equal(res1, try1, 1)
    res2 = self.sess.run([
        iterator.source, iterator.target_in, iterator.target_out,
        iterator.source_size, iterator.target_size, iterator.sos_token_id,
        iterator.eos_token_id
    ])
    self._check_all_equal(res2, try2, 2)
    self.assertAllEqual(total_num, 3, 'mismatch in total num somehow')
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
  """Create eval graph, model, src/tgt file holders, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  speaker_file = hparams.speaker_file
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "eval"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
    spkr_table, _ = vocab_utils.create_vocab_tables(speaker_file, "", True)
    src_spk_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_spk_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    src_spkr_dataset = tf.data.TextLineDataset(src_spk_file_placeholder)
    tgt_spkr_dataset = tf.data.TextLineDataset(tgt_spk_file_placeholder)
    iterator = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
        src_spkr_dataset, tgt_spkr_dataset, spkr_table, hparams.batch_size,
        sos=hparams.sos, eos=hparams.eos, random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer)
    model = model_creator(hparams, iterator=iterator,
                          mode=tf.contrib.learn.ModeKeys.EVAL,
                          source_vocab_table=src_vocab_table,
                          target_vocab_table=tgt_vocab_table,
                          speaker_table=spkr_table, scope=scope,
                          extra_args=extra_args)
  return EvalModel(graph=graph, model=model,
                   src_file_placeholder=src_file_placeholder,
                   tgt_file_placeholder=tgt_file_placeholder,
                   src_spk_file_placeholder=src_spk_file_placeholder,
                   tgt_spk_file_placeholder=tgt_spk_file_placeholder,
                   iterator=iterator)
def testGetIterator(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["f e a g", "c c a", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c", "a b", "", "b c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3, num_buckets=1, eos="eos", sos="sos")
  batch_size = 2
  src_max_len = 5
  dataset = iterator_utils.get_iterator(
      src_dataset=src_dataset, tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table, tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size, global_batch_size=batch_size, sos=hparams.sos,
      eos=hparams.eos, random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets, src_max_len=src_max_len,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  iterator = dataset.make_initializable_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)
    features = sess.run(get_next)
    self.assertAllEqual(
        [[4, 2, 0, 3, 3],   # c a eos -- eos is padding
         [4, 2, 2, 0, 3]],  # c c a
        features["source"])
    self.assertAllEqual([4, 5], features["source_sequence_length"])
    self.assertAllEqual(
        [[4, 1, 2],  # sos b c
         [4, 0, 1]], # sos a b
        features["target_input"])
    self.assertAllEqual(
        [[1, 2, 3],  # b c eos
         [0, 1, 3]], # a b eos
        features["target_output"])
    self.assertAllEqual([3, 3], features["target_sequence_length"])
def create_train_model(model_creator, hparams, scope=None,
                       single_cell_fn=None):
  """Create train graph, model, and iterator."""
  src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
  tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  tgt_vocab_file = hparams.tgt_vocab_file
  graph = tf.Graph()
  with graph.as_default():
    tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
    src_dataset = tf.contrib.data.TextLineDataset(src_file)
    tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    iterator = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, tgt_vocab_table, sos=hparams.sos,
        eos=hparams.eos, source_reverse=hparams.source_reverse,
        random_seed=hparams.random_seed, src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len, skip_count=skip_count_placeholder)
    # Note: One can set model_device_fn to
    # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
    with tf.device(model_helper.get_device_str(hparams.base_gpu)):
      # model_creator builds the model.
      model = model_creator(hparams, iterator=iterator,
                            mode=tf.contrib.learn.ModeKeys.TRAIN,
                            target_vocab_table=tgt_vocab_table, scope=scope,
                            single_cell_fn=single_cell_fn)
  return TrainModel(graph=graph, model=model, iterator=iterator,
                    skip_count_placeholder=skip_count_placeholder)
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
  """Create eval graph, model, data/kb file holders, and iterator."""
  vocab_file = hparams.vocab_file
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "eval"):
    vocab_table = vocab_utils.create_vocab_tables(vocab_file)
    data_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    kb_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    data_dataset = tf.data.TextLineDataset(data_file_placeholder)
    kb_dataset = tf.data.TextLineDataset(kb_file_placeholder)
    # This is the actual eval iterator.
    eval_iterator = iterator_utils.get_iterator(
        data_dataset, kb_dataset, vocab_table, batch_size=hparams.batch_size,
        t1=hparams.t1, t2=hparams.t2, eod=hparams.eod,
        len_action=hparams.len_action, random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        max_dialogue_len=hparams.max_dialogue_len)
    # This is the placeholder iterator.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, eval_iterator.output_types, eval_iterator.output_shapes)
    batched_iterator = iterator_utils.get_batched_iterator(iterator)
    model = model_creator(hparams, iterator=batched_iterator, handle=handle,
                          mode=tf.contrib.learn.ModeKeys.EVAL,
                          vocab_table=vocab_table, scope=scope,
                          extra_args=extra_args)
  return EvalModel(graph=graph, model=model, placeholder_iterator=iterator,
                   placeholder_handle=handle, eval_iterator=eval_iterator,
                   data_file_placeholder=data_file_placeholder,
                   kb_file_placeholder=kb_file_placeholder)
def create_train_model(model_creator, hparams):
  """Create train graph, model, and iterator."""
  src_file = "%s.%s" % (hparams.train_prefix, hparams.src)
  tgt_file = "%s.%s" % (hparams.train_prefix, hparams.tgt)
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file

  graph = tf.Graph()
  with graph.as_default(), tf.container("train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    src_dataset = tf.data.TextLineDataset(src_file)
    tgt_dataset = tf.data.TextLineDataset(tgt_file)
    iterator = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
        batch_size=hparams.batch_size, sos=hparams.sos, eos=hparams.eos,
        source_reverse=hparams.source_reverse, random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets, src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len)
    model = model_creator(hparams, iterator=iterator,
                          mode=tf.contrib.learn.ModeKeys.TRAIN,
                          source_vocab_table=src_vocab_table,
                          target_vocab_table=tgt_vocab_table)
  return TrainModel(graph=graph, model=model, iterator=iterator)
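# A hedged usage sketch for the TrainModel returned above (not part of the
# original source). It follows the session/initializer pattern used in other
# snippets in this collection; `num_train_steps` is a placeholder value and
# the `model.train(sess)` call is an assumed API mirroring the train()
# example further below.
train_model = create_train_model(model_creator, hparams)
with tf.Session(graph=train_model.graph) as sess:
  sess.run(tf.tables_initializer())
  sess.run(tf.global_variables_initializer())
  sess.run(train_model.iterator.initializer)
  for _ in range(num_train_steps):
    step_result = train_model.model.train(sess)  # assumed model API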
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
  """Create eval graph, model, src/tgt file holders, and iterator."""
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  graph = tf.Graph()
  with graph.as_default(), tf.container(scope or "eval"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=vocab_utils.UNK)
    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
    iterator = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
        hparams.batch_size, sos=hparams.sos, eos=hparams.eos,
        random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len_infer,
        tgt_max_len=hparams.tgt_max_len_infer,
        use_char_encode=hparams.use_char_encode)
    model = model_creator(hparams, iterator=iterator,
                          mode=tf.contrib.learn.ModeKeys.EVAL,
                          source_vocab_table=src_vocab_table,
                          target_vocab_table=tgt_vocab_table,
                          reverse_target_vocab_table=reverse_tgt_vocab_table,
                          scope=scope, extra_args=extra_args)
  return EvalModel(graph=graph, model=model,
                   src_file_placeholder=src_file_placeholder,
                   tgt_file_placeholder=tgt_file_placeholder,
                   iterator=iterator)
def create_train_model(model_creator, hparams):
  """Create train graph, model, and iterator."""
  src_file = hparams.src_file
  tgt_file = hparams.tgt_file
  src_vocab_file = hparams.src_vocab_file
  tgt_vocab_file = hparams.tgt_vocab_file
  graph = tf.Graph()
  with graph.as_default(), tf.container("train"):
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, hparams.share_vocab)
    src_dataset = tf.data.TextLineDataset(src_file)
    tgt_dataset = tf.data.TextLineDataset(tgt_file)
    skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
    iterator = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
        batch_size=hparams.batch_size, sos=hparams.sos, eos=hparams.eos,
        random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len, tgt_max_len=hparams.tgt_max_len,
        skip_count=skip_count_placeholder)
    with tf.device("/cpu:0"):
      model = model_creator(hparams, iterator=iterator,
                            mode=tf.estimator.ModeKeys.TRAIN,
                            source_vocab_table=src_vocab_table,
                            target_vocab_table=tgt_vocab_table)
  return TrainModel(graph=graph, model=model, iterator=iterator,
                    skip_count_placeholder=skip_count_placeholder)
def prepare_dataset(flags):
  """Generate the preprocessed dataset."""
  src_file = "%s.%s" % (flags.data_dir + flags.train_prefix, flags.src)
  tgt_file = "%s.%s" % (flags.data_dir + flags.train_prefix, flags.tgt)
  vocab_file = flags.data_dir + flags.vocab_prefix
  _, vocab_file = vocab_utils.check_vocab(vocab_file, flags.out_dir)
  out_file = flags.out_dir + "preprocessed_dataset"
  src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
      vocab_file)
  src_dataset = tf.data.TextLineDataset(src_file)
  tgt_dataset = tf.data.TextLineDataset(tgt_file)
  iterator = iterator_utils.get_iterator(
      src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
      batch_size=1, global_batch_size=1, sos=vocab_utils.SOS,
      eos=vocab_utils.EOS, random_seed=1, num_buckets=flags.num_buckets,
      src_max_len=flags.src_max_len, tgt_max_len=flags.tgt_max_len,
      filter_oversized_sequences=True,
      return_raw=True).make_initializable_iterator()

  with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(iterator.initializer)
    try:
      i = 0
      while True:
        with open(out_file + "_%d" % i, "wb") as f:
          i += 1
          for _ in range(100):
            for j in sess.run(iterator.get_next()):
              tf.logging.info(j)
              f.write(bytearray(j))
    except tf.errors.OutOfRangeError:
      pass
def test_get_iterator(self):
  tf.set_random_seed(1)
  src_vocab_table = tgt_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["f e a g", "c c a", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c", "a b", "", "b c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3, num_buckets=1, eos="eos", sos="sos")
  batch_size = 2
  src_max_len = 3
  iterator = iterator_utils.get_iterator(
      src_dataset=src_dataset, tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table, tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size, sos=hparams.sos, eos=hparams.eos,
      random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
      src_max_len=src_max_len, reshuffle_each_iteration=False, delimiter=" ")
  table_initializer = tf.tables_initializer()
  source = iterator.source
  src_seq_len = iterator.source_sequence_length
  self.assertEqual([None, None], source.shape.as_list())
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)
    (source_v, src_seq_len_v) = sess.run((source, src_seq_len))
    print(source_v)
    print(src_seq_len_v)
    (source_v, src_seq_len_v) = sess.run((source, src_seq_len))
    print(source_v)
    print(src_seq_len_v)
def create_test_iterator(hparams, mode):
  """Create test iterator."""
  src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant([hparams.eos, "a", "b", "c", "d"]))
  tgt_vocab_mapping = tf.constant([hparams.sos, hparams.eos, "a", "b", "c"])
  tgt_vocab_table = lookup_ops.index_table_from_tensor(tgt_vocab_mapping)
  if mode == tf.contrib.learn.ModeKeys.INFER:
    reverse_tgt_vocab_table = lookup_ops.index_to_string_table_from_tensor(
        tgt_vocab_mapping)

  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["a a b b c", "a b b"]))

  if mode != tf.contrib.learn.ModeKeys.INFER:
    tgt_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(["a b c b c", "a b c b"]))
    return (
        # TODO: change to accommodate new inputs.
        iterator_utils.get_iterator(
            src_dataset=src_dataset, tgt_dataset=tgt_dataset,
            src_vocab_table=src_vocab_table, tgt_vocab_table=tgt_vocab_table,
            batch_size=hparams.batch_size, sos=hparams.sos, eos=hparams.eos,
            random_seed=hparams.random_seed, num_buckets=hparams.num_buckets),
        src_vocab_table, tgt_vocab_table)
  else:
    return (
        # TODO: change to accommodate new inputs.
        iterator_utils.get_infer_iterator(
            src_dataset=src_dataset, src_vocab_table=src_vocab_table,
            eos=hparams.eos, batch_size=hparams.batch_size),
        src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table)
def mytest_iterator():
  src_dataset = tf.data.TextLineDataset(hparam.train_src)
  tgt_dataset = tf.data.TextLineDataset(hparam.train_tgt)
  src, tgt_in, tgt_out, src_seq_len, tgt_seq_len, initializer, _ = get_iterator(
      src_dataset, tgt_dataset, hparam)
  with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    sess.run(initializer)
    for i in range(1):
      try:
        _src, _tgt_in, _tgt_out, _src_seq_len, _tgt_seq_len = sess.run(
            [src, tgt_in, tgt_out, src_seq_len, tgt_seq_len])
        print('src', _src)
        print('tgt_in', _tgt_in)
        print('tgt_out', _tgt_out)
        print('src_seq_len', _src_seq_len)
        print('tgt_seq_len', _tgt_seq_len)
      except tf.errors.OutOfRangeError:
        print('xxxxxxxxxxxxxxx')
        sess.run(initializer)
def setUp(self):
  super(ModelTest, self).setUp()
  self.graph = tf.Graph()
  self.session = tf.Session(graph=self.graph)
  with self.graph.as_default():
    self.iterator, _ = iterator_utils.get_iterator(
        'TRAIN', filesobj=TRAIN_FILES, buffer_size=TRAIN_HPARAMS.buffer_size,
        num_epochs=TRAIN_HPARAMS.num_epochs,
        batch_size=TRAIN_HPARAMS.batch_size, debug_mode=True)
    self.model = AttentionModel(TRAIN_HPARAMS, self.iterator, 'TRAIN')
    self.table_init_op = tf.tables_initializer()
    self.vars_init_op = tf.global_variables_initializer()
def create_eval_model(model_creator, hparams, scope=None, single_cell_fn=None):
  """Create eval graph, model, src/tgt file holders, and iterator."""
  tgt_vocab_file = hparams.tgt_vocab_file
  graph = tf.Graph()
  with graph.as_default():
    tgt_vocab_table = vocab_utils.create_tgt_vocab_table(tgt_vocab_file)
    src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
    src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
    tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
    iterator = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, tgt_vocab_table, sos=hparams.sos,
        eos=hparams.eos, source_reverse=hparams.source_reverse,
        random_seed=hparams.random_seed, src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len)
    model = model_creator(hparams, iterator=iterator,
                          mode=tf.contrib.learn.ModeKeys.EVAL,
                          target_vocab_table=tgt_vocab_table, scope=scope,
                          single_cell_fn=single_cell_fn)
  return EvalModel(graph=graph, model=model,
                   src_file_placeholder=src_file_placeholder,
                   tgt_file_placeholder=tgt_file_placeholder,
                   iterator=iterator)
if __name__ == "__main__":
  src_dataset_name = "tst2012.en"
  tgt_dataset_name = "tst2012.vi"
  src_vocab_name = "vocab.en"
  tgt_vocab_name = "vocab.vi"
  src_dataset_path = os.path.join(BASE, src_dataset_name)
  tgt_dataset_path = os.path.join(BASE, tgt_dataset_name)
  src_vocab_path = os.path.join(BASE, src_vocab_name)
  tgt_vocab_path = os.path.join(BASE, tgt_vocab_name)
  batch_size = 128
  num_buckets = 5
  iterator = get_iterator(src_dataset_path=src_dataset_path,
                          tgt_dataset_path=tgt_dataset_path,
                          src_vocab_path=src_vocab_path,
                          tgt_vocab_path=tgt_vocab_path,
                          batch_size=batch_size, num_buckets=num_buckets,
                          is_shuffle=False, src_max_len=None,
                          tgt_max_len=None)
  model = EvalModel(iterator=iterator)
  session = tf.Session()
  session.run(tf.tables_initializer())
  session.run(tf.global_variables_initializer())
  session.run(iterator.initializer)
  for i in range(1000):
    print(i)
    eval_loss, predict_count, batch_size = model.eval(eval_sess=session)
    # print(eval_loss)
    # print(predict_count)
    # print(batch_size)
def testGetIteratorWithSkipCount(self):
  vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
      tf.constant(["c c a", "c a", "d", "f e a g"]))
  tgt_dataset = tf.contrib.data.Dataset.from_tensor_slices(
      tf.constant(["a b", "b c", "", "c c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3, num_buckets=5, source_reverse=False, eos="eos",
      sos="sos")
  batch_size = 2
  src_max_len = 3
  skip_count = tf.placeholder(shape=(), dtype=tf.int64)
  iterator = iterator_utils.get_iterator(
      src_dataset=src_dataset, tgt_dataset=tgt_dataset,
      vocab_table=vocab_table, batch_size=batch_size, sos=hparams.sos,
      eos=hparams.eos, src_reverse=hparams.source_reverse,
      random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
      src_max_len=src_max_len, skip_count=skip_count)
  table_initializer = tf.tables_initializer()
  source = iterator.source
  target_input = iterator.target_input
  target_output = iterator.target_output
  src_seq_len = iterator.source_sequence_length
  tgt_seq_len = iterator.target_sequence_length
  self.assertEqual([None, None], source.shape.as_list())
  self.assertEqual([None, None], target_input.shape.as_list())
  self.assertEqual([None, None], target_output.shape.as_list())
  self.assertEqual([None], src_seq_len.shape.as_list())
  self.assertEqual([None], tgt_seq_len.shape.as_list())
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer, feed_dict={skip_count: 3})
    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [[-1, -1, 0]],  # "f" == unknown, "e" == unknown, a
        source_v)
    self.assertAllEqual([3], src_len_v)
    self.assertAllEqual([[4, 2, 2]], target_input_v)   # sos c c
    self.assertAllEqual([[2, 2, 3]], target_output_v)  # c c eos
    self.assertAllEqual([3], tgt_len_v)
    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)

    # Re-init iterator with skip_count=0.
    sess.run(iterator.initializer, feed_dict={skip_count: 0})
    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
         [2, 0, 3]],   # c a eos -- eos is padding
        source_v)
    self.assertAllEqual([3, 2], src_len_v)
    self.assertAllEqual(
        [[4, 2, 2],  # sos c c
         [4, 1, 2]], # sos b c
        target_input_v)
    self.assertAllEqual(
        [[2, 2, 3],  # c c eos
         [1, 2, 3]], # b c eos
        target_output_v)
    self.assertAllEqual([3, 3], tgt_len_v)
    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual([[2, 2, 0]], source_v)          # c c a
    self.assertAllEqual([3], src_len_v)
    self.assertAllEqual([[4, 0, 1]], target_input_v)    # sos a b
    self.assertAllEqual([[0, 1, 3]], target_output_v)   # a b eos
    self.assertAllEqual([3], tgt_len_v)
    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)
def testTrainInput(self):
  try1 = BatchedInput(
      src_chars=np.array(
          [[[41, 53, 52, 42, 59, 41, 47, 43, 52, 42, 53],
            [52, 53, 57, 53, 58, 56, 53, 57, 2, 2, 2],
            [51, 47, 57, 51, 53, 57, 2, 2, 2, 2, 2],
            [15, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],
           [[52, 53, 57, 58, 56, 53, 57, 2, 2, 2, 2],
            [57, 53, 51, 53, 57, 2, 2, 2, 2, 2, 2],
            [59, 52, 53, 2, 2, 2, 2, 2, 2, 2, 2],
            [15, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]]),
      trg_chars_in=np.array(
          [[[1, 40, 54, 45, 58, 45, 50, 43, 2, 2],
            [1, 51, 57, 54, 55, 41, 48, 58, 41, 55],
            [1, 15, 2, 2, 2, 2, 2, 2, 2, 2],
            [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],
           [[1, 59, 41, 2, 2, 2, 2, 2, 2, 2],
            [1, 37, 54, 41, 2, 2, 2, 2, 2, 2],
            [1, 51, 50, 41, 2, 2, 2, 2, 2, 2],
            [1, 15, 2, 2, 2, 2, 2, 2, 2, 2]]]),
      trg_chars_out=np.array(
          [[[40, 54, 45, 58, 45, 50, 43, 2, 2, 2],
            [51, 57, 54, 55, 41, 48, 58, 41, 55, 2],
            [15, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],
           [[59, 41, 2, 2, 2, 2, 2, 2, 2, 2],
            [37, 54, 41, 2, 2, 2, 2, 2, 2, 2],
            [51, 50, 41, 2, 2, 2, 2, 2, 2, 2],
            [15, 2, 2, 2, 2, 2, 2, 2, 2, 2]]]),
      trg_chars_lens=np.array([[8, 10, 2, 0], [3, 4, 4, 2]]),
      trg_words_in=np.array([[1, 100, 110, 9, 2], [1, 98, 348, 81, 9]]),
      trg_words_out=np.array([[100, 110, 9, 2, 2], [98, 348, 81, 9, 2]]),
      src_size=np.array([4, 4]),
      trg_size=np.array([4, 5]),
      initializer=None, sos_token_id=1, eos_token_id=2)
  try2 = BatchedInput(
      src_chars=np.array([[[46, 59, 40, 53, 2, 2],
                           [51, 59, 41, 46, 53, 2],
                           [42, 47, 50, 53, 45, 53],
                           [15, 2, 2, 2, 2, 2]]]),
      trg_chars_in=np.array([[[1, 56, 44, 41, 54, 41, 2],
                              [1, 59, 37, 55, 2, 2, 2],
                              [1, 37, 2, 2, 2, 2, 2],
                              [1, 59, 44, 51, 48, 41, 2],
                              [1, 48, 51, 56, 2, 2, 2],
                              [1, 51, 42, 2, 2, 2, 2],
                              [1, 40, 45, 37, 48, 51, 43],
                              [1, 15, 2, 2, 2, 2, 2]]]),
      trg_chars_out=np.array([[[56, 44, 41, 54, 41, 2, 2],
                               [59, 37, 55, 2, 2, 2, 2],
                               [37, 2, 2, 2, 2, 2, 2],
                               [59, 44, 51, 48, 41, 2, 2],
                               [48, 51, 56, 2, 2, 2, 2],
                               [51, 42, 2, 2, 2, 2, 2],
                               [40, 45, 37, 48, 51, 43, 2],
                               [15, 2, 2, 2, 2, 2, 2]]]),
      trg_chars_lens=np.array([[6, 4, 2, 6, 4, 3, 7, 2]]),
      trg_words_in=np.array([[1, 121, 122, 14, 836, 416, 37, 1471, 9]]),
      trg_words_out=np.array([[121, 122, 14, 836, 416, 37, 1471, 9, 2]]),
      src_size=np.array([4]),
      trg_size=np.array([9]),
      initializer=None, sos_token_id=1, eos_token_id=2)
  with self.graph.as_default():
    iterator, total_num = iterator_utils.get_iterator(
        'TRAIN', filesobj=self.files, buffer_size=None, num_epochs=1,
        batch_size=2, debug_mode=True)
    table_init_op = tf.tables_initializer()
    self.sess.run(table_init_op)
    self.sess.run(iterator.initializer)
    res1 = self.sess.run([
        iterator.src_chars, iterator.trg_chars_in, iterator.trg_chars_out,
        iterator.trg_chars_lens, iterator.trg_words_in,
        iterator.trg_words_out, iterator.src_size, iterator.trg_size,
        iterator.sos_token_id, iterator.eos_token_id
    ])
    self._check_all_equal(res1, try1, 1)
    res2 = self.sess.run([
        iterator.src_chars, iterator.trg_chars_in, iterator.trg_chars_out,
        iterator.trg_chars_lens, iterator.trg_words_in,
        iterator.trg_words_out, iterator.src_size, iterator.trg_size,
        iterator.sos_token_id, iterator.eos_token_id
    ])
    self._check_all_equal(res2, try2, 2)
    self.assertAllEqual(total_num, 3, 'mismatch in total num somehow')
# graph = tf.Graph()
scope = "train"
# with graph.as_default(), tf.container(scope):
src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
    src_vocab_file, tgt_vocab_file, args.share_vocab)
src_dataset = tf.data.TextLineDataset(src_file)
tgt_dataset = tf.data.TextLineDataset(tgt_file)
skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
iterator = iterator_utils.get_iterator(
    src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
    batch_size=args.batch_size, sos=args.sos, eos=args.eos,
    random_seed=args.random_seed, num_buckets=args.num_buckets,
    src_max_len=args.src_max_len, tgt_max_len=args.tgt_max_len,
    skip_count=skip_count_placeholder, num_shards=args.num_workers,
    shard_index=args.jobid)

# Get the source; optionally transpose to time-major.
source = iterator.source
if args.time_major:
  source = tf.transpose(source)

embedding_encoder, embedding_decoder = create_emb_for_encoder_and_decoder(
    share_vocab=args.share_vocab, src_vocab_size=src_vocab_size,
def run():
  train_hparams = misc_utils.get_train_hparams()
  dev_hparams = misc_utils.get_dev_hparams()
  test_hparams = misc_utils.get_test_hparams()
  train_graph = tf.Graph()
  dev_graph = tf.Graph()
  test_graph = tf.Graph()

  with train_graph.as_default():
    train_iterator, train_total_num = iterator_utils.get_iterator(
        regime=train_hparams.regime, filesobj=train_hparams.filesobj,
        buffer_size=train_hparams.buffer_size,
        num_epochs=train_hparams.num_epochs,
        batch_size=train_hparams.batch_size)
    train_model = BuildTrainModel(hparams=train_hparams,
                                  iterator=train_iterator, graph=train_graph)
    train_vars_init_op = tf.global_variables_initializer()
    train_tables_init_op = tf.tables_initializer()
    tf.get_default_graph().finalize()

  with dev_graph.as_default():
    dev_iterator, dev_total_num = iterator_utils.get_iterator(
        dev_hparams.regime, filesobj=dev_hparams.filesobj,
        buffer_size=dev_hparams.buffer_size, num_epochs=dev_hparams.num_epochs,
        batch_size=dev_hparams.batch_size)
    dev_model = BuildEvalModel(hparams=dev_hparams, iterator=dev_iterator,
                               graph=dev_graph)
    dev_tables_init_op = tf.tables_initializer()
    tf.get_default_graph().finalize()

  with test_graph.as_default():
    test_iterator, test_total_num = iterator_utils.get_iterator(
        test_hparams.regime, filesobj=test_hparams.filesobj,
        buffer_size=test_hparams.buffer_size,
        num_epochs=test_hparams.num_epochs,
        batch_size=test_hparams.batch_size)
    test_model = BuildInferModel(
        hparams=test_hparams, iterator=test_iterator, graph=test_graph,
        infer_file_path=test_hparams.translation_file_path,
        infer_chars_file_path=test_hparams.char_translation_file_path)
    test_tables_init_op = tf.tables_initializer()
    tf.get_default_graph().finalize()

  train_session = tf.Session(graph=train_graph)
  dev_session = tf.Session(graph=dev_graph)
  test_session = tf.Session(graph=test_graph)

  train_steps = misc_utils.count_num_steps(train_hparams.num_epochs,
                                           train_total_num,
                                           train_hparams.batch_size)
  eval_steps = misc_utils.count_num_steps(1, dev_total_num,
                                          dev_hparams.batch_size)
  num_test_steps = misc_utils.count_num_steps(1, test_total_num,
                                              test_hparams.batch_size)
  eval_count = dev_hparams.num_steps_to_eval

  train_session.run(train_tables_init_op)
  train_session.run(train_iterator.initializer)
  train_session.run(train_vars_init_op)

  with tqdm(total=train_steps) as prog:
    for step in range(train_steps):
      train_model.train(train_session)
      if step % eval_count == 0:
        dev_loss = misc_utils.eval_once(train_model, dev_model, train_session,
                                        dev_session, step,
                                        train_hparams.chkpts_dir, dev_iterator,
                                        eval_steps, dev_tables_init_op)
        print('dev loss at step {} = {}'.format(step, dev_loss))
      prog.update(1)

  misc_utils.write_translations(train_model, train_session,
                                train_hparams.chkpts_dir, step, test_model,
                                test_session, test_tables_init_op,
                                test_iterator, num_test_steps)
def train():
  """Train a translation model."""
  create_new_model = params['create_new_model']
  out_dir = params['out_dir']
  model_creator = nmt_model.Model  # Create model graph
  summary_name = "train_log"

  # Set up session and initialize input data iterators.
  src_file = params['src_data_file']
  tgt_file = params['tgt_data_file']
  dev_src_file = params['dev_src_file']
  dev_tgt_file = params['dev_tgt_file']
  test_src_file = params['test_src_file']
  test_tgt_file = params['test_tgt_file']
  char_vocab_file = params['enc_char_map_path']
  src_vocab_file = params['src_vocab_file']
  tgt_vocab_file = params['tgt_vocab_file']
  if src_vocab_file == '' or tgt_vocab_file == '':
    raise ValueError("src_vocab_file or tgt_vocab_file not given in params.")
  graph = tf.Graph()

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # Model run params
  num_epochs = params['num_epochs']
  batch_size = params['batch_size']
  steps_per_stats = params['steps_per_stats']
  utils.print_out("# Epochs=%s, Batch Size=%s, Steps_per_Stats=%s" %
                  (num_epochs, batch_size, steps_per_stats), None)

  with graph.as_default():
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, params['share_vocab'])
    char_vocab_table = vocab_utils.get_char_table(char_vocab_file)
    reverse_target_table = lookup_ops.index_to_string_table_from_file(
        tgt_vocab_file, default_value=params['unk'])
    src_dataset = tf.data.TextLineDataset(src_file)
    tgt_dataset = tf.data.TextLineDataset(tgt_file)
    batched_iter = iterator_utils.get_iterator(
        src_dataset, tgt_dataset, char_vocab_table, src_vocab_table,
        tgt_vocab_table, batch_size=batch_size, sos=params['sos'],
        eos=params['eos'], char_pad=params['char_pad'],
        num_buckets=params['num_buckets'], num_epochs=params['num_epochs'],
        src_max_len=params['src_max_len'], tgt_max_len=params['tgt_max_len'],
        src_char_max_len=params['char_max_len'])

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), graph)

    # Preload validation data for decoding.
    dev_src_dataset = tf.data.TextLineDataset(dev_src_file)
    dev_tgt_dataset = tf.data.TextLineDataset(dev_tgt_file)
    dev_batched_iterator = iterator_utils.get_iterator(
        dev_src_dataset, dev_tgt_dataset, char_vocab_table, src_vocab_table,
        tgt_vocab_table, batch_size=batch_size, sos=params['sos'],
        eos=params['eos'], char_pad=params['char_pad'],
        num_buckets=params['num_buckets'], num_epochs=params['num_epochs'],
        src_max_len=params['src_max_len'], tgt_max_len=params['tgt_max_len'],
        src_char_max_len=params['char_max_len'])

    # Preload test data for decoding.
    test_src_dataset = tf.data.TextLineDataset(test_src_file)
    test_tgt_dataset = tf.data.TextLineDataset(test_tgt_file)
    test_batched_iterator = iterator_utils.get_iterator(
        test_src_dataset, test_tgt_dataset, char_vocab_table, src_vocab_table,
        tgt_vocab_table, batch_size=batch_size, sos=params['sos'],
        eos=params['eos'], char_pad=params['char_pad'],
        num_buckets=params['num_buckets'], num_epochs=params['num_epochs'],
        src_max_len=params['src_max_len'], tgt_max_len=params['tgt_max_len'],
        src_char_max_len=params['char_max_len'])

    config_proto = utils.get_config_proto(
        log_device_placement=params['log_device_placement'])
    sess = tf.Session(config=config_proto)
    with sess.as_default():
      train_model = model_creator(mode=params['mode'],
                                  train_iterator=batched_iter,
                                  val_iterator=dev_batched_iterator,
                                  char_vocab_table=char_vocab_table,
                                  source_vocab_table=src_vocab_table,
                                  target_vocab_table=tgt_vocab_table,
                                  reverse_target_table=reverse_target_table)
      loaded_train_model, global_step = create_or_load_model(
          train_model, params['out_dir'], session=sess, name="train",
          log_f=log_f, create=create_new_model)

    sess.run([batched_iter.initializer, dev_batched_iterator.initializer,
              test_batched_iterator.initializer])
    start_train_time = time.time()
    utils.print_out("# Start step %d, lr %g, %s" %
                    (global_step,
                     loaded_train_model.learning_rate.eval(session=sess),
                     time.ctime()), log_f)

    # Reset statistics
    stats = init_stats()
    steps_per_epoch = int(np.ceil(utils.get_file_row_size(src_file) /
                                  batch_size))
    utils.print_out("Total steps per epoch: %d" % steps_per_epoch)

    def train_step(model, sess):
      return model.train(sess)

    def dev_step(model, sess):
      total_steps = int(np.ceil(utils.get_file_row_size(dev_src_file) /
                                batch_size))
      total_dev_loss = 0.0
      total_accuracy = 0.0
      for _ in range(total_steps):
        dev_result_step = model.dev(sess)
        dev_softmax_scores, dev_loss, tgt_output_ids, _, _, _, _ = (
            dev_result_step)
        total_dev_loss += dev_loss * params['batch_size']
        total_accuracy += evaluation_utils._accuracy(
            dev_softmax_scores, tgt_output_ids, None, None)
      return (total_dev_loss / total_steps, total_accuracy / total_steps)

    for epoch_step in range(num_epochs):
      for curr_step in range(int(np.ceil(steps_per_epoch))):
        start_time = time.time()
        step_result = train_step(loaded_train_model, sess)
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)
        # Logging step
        if curr_step % params['steps_per_stats'] == 0:
          check_stats(stats, global_step, steps_per_stats, log_f)
        # Evaluation
        if curr_step % params['steps_per_devRun'] == 0:
          dev_step_loss, dev_step_acc = dev_step(loaded_train_model, sess)
          utils.print_out("Dev Step total loss, Accuracy: %f, %f" %
                          (dev_step_loss, dev_step_acc), log_f)
      utils.print_out("# Finished an epoch, epoch completed %d" % epoch_step)
      loaded_train_model.saver.save(sess,
                                    os.path.join(out_dir, "translate.ckpt"),
                                    global_step=global_step)
      dev_step_loss = dev_step(loaded_train_model, sess)

    utils.print_time("# Done training!", start_train_time)
    summary_writer.close()
def self_play_iterator_creator(hparams, num_workers, jobid):
  """Create self-play iterators.

  Several iterators are created here: a supervised training iterator used for
  supervised learning, plus a full-text iterator and a structured iterator
  used for reinforcement-learning self play. The full-text iterator feeds
  data from text files, while the structured iterator is initialized directly
  from in-memory objects. The former is used for training; the latter is used
  for self-play dialogue generation, to eliminate the need to serialize the
  generated dialogues into actual text files.
  """
  vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
  data_dataset = tf.data.TextLineDataset(hparams.train_data)
  kb_dataset = tf.data.TextLineDataset(hparams.train_kb)
  skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
  # This is the actual iterator for supervised training.
  train_iterator = iterator_utils.get_iterator(
      data_dataset, kb_dataset, vocab_table, batch_size=hparams.batch_size,
      t1=hparams.t1, t2=hparams.t2, eod=hparams.eod,
      len_action=hparams.len_action, random_seed=hparams.random_seed,
      num_buckets=hparams.num_buckets,
      max_dialogue_len=hparams.max_dialogue_len,
      skip_count=skip_count_placeholder, num_shards=num_workers,
      shard_index=jobid)
  # This is the actual iterator for self_play_fulltext_iterator.
  data_placeholder = tf.placeholder(shape=[None], dtype=tf.string,
                                    name="src_ph")
  kb_placeholder = tf.placeholder(shape=[None], dtype=tf.string, name="kb_ph")
  batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64,
                                          name="bs_ph")
  dataset_data = tf.data.Dataset.from_tensor_slices(data_placeholder)
  kb_dataset = tf.data.Dataset.from_tensor_slices(kb_placeholder)
  self_play_fulltext_iterator = iterator_utils.get_infer_iterator(
      dataset_data, kb_dataset, vocab_table,
      batch_size=batch_size_placeholder, eod=hparams.eod,
      len_action=hparams.len_action, self_play=True)
  # This is the actual iterator for self_play_structured_iterator.
  self_play_structured_iterator = tf.data.Iterator.from_structure(
      self_play_fulltext_iterator.output_types,
      self_play_fulltext_iterator.output_shapes)
  iterators = [
      train_iterator, self_play_fulltext_iterator,
      self_play_structured_iterator
  ]
  # This is the list of placeholders.
  placeholders = [
      data_placeholder, kb_placeholder, batch_size_placeholder,
      skip_count_placeholder
  ]
  return iterators, placeholders
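# A hedged sketch (not from the original source) of how the structured
# self-play iterator above might be bound to an in-memory dataset of
# generated dialogues. `dialogue_dataset` is a hypothetical tf.data.Dataset
# whose element structure matches self_play_fulltext_iterator; the
# make_initializer call is the standard TF 1.x API for iterators created with
# tf.data.Iterator.from_structure.
iterators, placeholders = self_play_iterator_creator(hparams, num_workers=1,
                                                     jobid=0)
_, _, self_play_structured_iterator = iterators
init_op = self_play_structured_iterator.make_initializer(dialogue_dataset)
with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  sess.run(init_op)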
def testGetIteratorWithSkipCount(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c a", "c c a", "d", "f e a g"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["b c", "a b", "", "c c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3, num_buckets=5, eos="eos", sos="sos")
  batch_size = 2
  src_max_len = 3
  skip_count = tf.placeholder(shape=(), dtype=tf.int64)
  dataset = iterator_utils.get_iterator(
      src_dataset=src_dataset, tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table, tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size, sos=hparams.sos, eos=hparams.eos,
      random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
      src_max_len=src_max_len, skip_count=skip_count,
      reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  iterator = dataset.make_initializable_iterator()
  get_next = iterator.get_next()
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer, feed_dict={skip_count: 1})
    features = sess.run(get_next)
    self.assertAllEqual(
        [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
         [2, 2, 0]],   # c c a
        features["source"])
    self.assertAllEqual([3, 3], features["source_sequence_length"])
    self.assertAllEqual(
        [[4, 2, 2],  # sos c c
         [4, 0, 1]], # sos a b
        features["target_input"])
    self.assertAllEqual(
        [[2, 2, 3],  # c c eos
         [0, 1, 3]], # a b eos
        features["target_output"])
    self.assertAllEqual([3, 3], features["target_sequence_length"])

    # Re-init iterator with skip_count=0.
    sess.run(iterator.initializer, feed_dict={skip_count: 0})
    features = sess.run(get_next)
    self.assertAllEqual(
        [[-1, -1, 0],  # "f" == unknown, "e" == unknown, a
         [2, 2, 0]],   # c c a
        features["source"])
    self.assertAllEqual([3, 3], features["source_sequence_length"])
    self.assertAllEqual(
        [[4, 2, 2],  # sos c c
         [4, 0, 1]], # sos a b
        features["target_input"])
    self.assertAllEqual(
        [[2, 2, 3],  # c c eos
         [0, 1, 3]], # a b eos
        features["target_output"])
    self.assertAllEqual([3, 3], features["target_sequence_length"])
def testGetIterator(self):
  tf.set_random_seed(1)
  tgt_vocab_table = src_vocab_table = lookup_ops.index_table_from_tensor(
      tf.constant(["a", "b", "c", "eos", "sos"]))
  src_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["f e a g", "c c a", "d", "c a"]))
  tgt_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant(["c c", "a b", "", "b c"]))
  hparams = tf.contrib.training.HParams(
      random_seed=3, num_buckets=5, eos="eos", sos="sos")
  batch_size = 2
  src_max_len = 3
  iterator = iterator_utils.get_iterator(
      src_dataset=src_dataset, tgt_dataset=tgt_dataset,
      src_vocab_table=src_vocab_table, tgt_vocab_table=tgt_vocab_table,
      batch_size=batch_size, sos=hparams.sos, eos=hparams.eos,
      random_seed=hparams.random_seed, num_buckets=hparams.num_buckets,
      src_max_len=src_max_len, reshuffle_each_iteration=False)
  table_initializer = tf.tables_initializer()
  source = iterator.source
  target_input = iterator.target_input
  target_output = iterator.target_output
  src_seq_len = iterator.source_sequence_length
  tgt_seq_len = iterator.target_sequence_length
  self.assertEqual([None, None], source.shape.as_list())
  self.assertEqual([None, None], target_input.shape.as_list())
  self.assertEqual([None, None], target_output.shape.as_list())
  self.assertEqual([None], src_seq_len.shape.as_list())
  self.assertEqual([None], tgt_seq_len.shape.as_list())
  with self.test_session() as sess:
    sess.run(table_initializer)
    sess.run(iterator.initializer)
    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual(
        [[2, 0, 3],    # c a eos -- eos is padding
         [-1, -1, 0]], # "f" == unknown, "e" == unknown, a
        source_v)
    self.assertAllEqual([2, 3], src_len_v)
    self.assertAllEqual(
        [[4, 1, 2],  # sos b c
         [4, 2, 2]], # sos c c
        target_input_v)
    self.assertAllEqual(
        [[1, 2, 3],  # b c eos
         [2, 2, 3]], # c c eos
        target_output_v)
    self.assertAllEqual([3, 3], tgt_len_v)
    (source_v, src_len_v, target_input_v, target_output_v, tgt_len_v) = (
        sess.run((source, src_seq_len, target_input, target_output,
                  tgt_seq_len)))
    self.assertAllEqual([[2, 2, 0]], source_v)         # c c a
    self.assertAllEqual([3], src_len_v)
    self.assertAllEqual([[4, 0, 1]], target_input_v)   # sos a b
    self.assertAllEqual([[0, 1, 3]], target_output_v)  # a b eos
    self.assertAllEqual([3], tgt_len_v)
    with self.assertRaisesOpError("End of sequence"):
      sess.run(source)