def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Assemble a standalone inference graph with speaker-conditioned inputs.

    Args:
        model_creator: Callable that constructs the model inside the new
            graph. It is called with `hparams` plus the iterator, the lookup
            tables, `scope` and `extra_args` as keyword arguments.
        hparams: Hyperparameters. This function reads `src_vocab_file`,
            `tgt_vocab_file`, `speaker_file`, `share_vocab`, `eos` and
            `src_max_len_infer`.
        scope: Optional name used for the resource container (falls back to
            "infer") and forwarded to `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `InferModel` bundling the graph, the model, the feed placeholders
        and the iterator.
    """
    infer_graph = tf.Graph()
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file
    speaker_path = hparams.speaker_file

    def _string_vector_placeholder():
        # 1-D string tensor of unknown length, fed at run time.
        return tf.placeholder(shape=[None], dtype=tf.string)

    with infer_graph.as_default(), tf.container(scope or "infer"):
        # Token tables for the text, plus the id -> token table for targets.
        src_table, tgt_table = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab)
        ids_to_tgt_tokens = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_path, default_value=vocab_utils.UNK)

        # The speaker file is given as both inputs with sharing enabled;
        # only the first returned table is kept.
        spkr_table, _ = vocab_utils.create_vocab_tables(
            speaker_path, speaker_path, True)
        ids_to_speakers = lookup_ops.index_to_string_table_from_file(
            speaker_path, default_value=vocab_utils.UNK)

        # Run-time feeds.
        src_ph = _string_vector_placeholder()
        src_spkr_ph = _string_vector_placeholder()
        tgt_spkr_ph = _string_vector_placeholder()
        batch_size_ph = tf.placeholder(shape=[], dtype=tf.int64)

        src_data = tf.data.Dataset.from_tensor_slices(src_ph)
        src_spkr_data = tf.data.Dataset.from_tensor_slices(src_spkr_ph)
        tgt_spkr_data = tf.data.Dataset.from_tensor_slices(tgt_spkr_ph)

        infer_iterator = iterator_utils.get_infer_iterator(
            src_data,
            src_table,
            src_spkr_data,
            tgt_spkr_data,
            spkr_table,
            batch_size=batch_size_ph,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)

        infer_model = model_creator(
            hparams,
            iterator=infer_iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            source_vocab_table=src_table,
            target_vocab_table=tgt_table,
            reverse_target_vocab_table=ids_to_tgt_tokens,
            speaker_table=spkr_table,
            reverse_speaker_table=ids_to_speakers,
            scope=scope,
            extra_args=extra_args)

    return InferModel(
        graph=infer_graph,
        model=infer_model,
        src_placeholder=src_ph,
        batch_size_placeholder=batch_size_ph,
        src_speaker_placeholder=src_spkr_ph,
        tgt_speaker_placeholder=tgt_spkr_ph,
        iterator=infer_iterator)
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create the eval graph, model, file placeholders and iterator.

    The source/target text files and the source/target speaker files are
    supplied at run time through scalar string placeholders.

    Args:
        model_creator: Callable that constructs the model; it is called with
            `hparams` plus the iterator, lookup tables, `scope` and
            `extra_args` as keyword arguments.
        hparams: Hyperparameters. This function reads the vocab and speaker
            file paths, `share_vocab`, `batch_size`, `sos`, `eos`,
            `random_seed`, `num_buckets`, `src_max_len_infer` and
            `tgt_max_len_infer`.
        scope: Optional container name (defaults to "eval"); also forwarded
            to `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `EvalModel` holding the graph, model, the four file placeholders
        and the iterator.
    """
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file
    speaker_path = hparams.speaker_file
    eval_graph = tf.Graph()

    def _filename_placeholder():
        # Scalar string fed with a file path at run time.
        return tf.placeholder(shape=(), dtype=tf.string)

    with eval_graph.as_default(), tf.container(scope or "eval"):
        src_table, tgt_table = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab)

        # Parallel text corpora.
        src_file_ph = _filename_placeholder()
        tgt_file_ph = _filename_placeholder()
        src_lines = tf.data.TextLineDataset(src_file_ph)
        tgt_lines = tf.data.TextLineDataset(tgt_file_ph)

        # Speaker table: with sharing enabled the second path is given as "".
        speaker_table, _ = vocab_utils.create_vocab_tables(
            speaker_path, "", True)
        src_spk_file_ph = _filename_placeholder()
        tgt_spk_file_ph = _filename_placeholder()
        src_spkr_lines = tf.data.TextLineDataset(src_spk_file_ph)
        tgt_spkr_lines = tf.data.TextLineDataset(tgt_spk_file_ph)

        eval_iterator = iterator_utils.get_iterator(
            src_lines,
            tgt_lines,
            src_table,
            tgt_table,
            src_spkr_lines,
            tgt_spkr_lines,
            speaker_table,
            hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer)

        eval_model = model_creator(
            hparams,
            iterator=eval_iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_table,
            target_vocab_table=tgt_table,
            speaker_table=speaker_table,
            scope=scope,
            extra_args=extra_args)

    return EvalModel(
        graph=eval_graph,
        model=eval_model,
        src_file_placeholder=src_file_ph,
        tgt_file_placeholder=tgt_file_ph,
        src_spk_file_placeholder=src_spk_file_ph,
        tgt_spk_file_placeholder=tgt_spk_file_ph,
        iterator=eval_iterator)
def create_graph(cfg, graph, mode):
    """Populate `graph` with the vocab table, input pipeline and model.

    Args:
        cfg: Configuration object. `data_name` must be "audiocaps" and
            `model_name` must be "PyramidLSTM"; anything else raises.
        graph: The `tf.Graph` to build into.
        mode: Mode string; used as the container name, passed to the data
            loader, and compared against "train" for the model's last
            argument.

    Returns:
        A `(model, iters_in_data, iterator_init)` tuple.

    Raises:
        NotImplementedError: For an unsupported dataset or model name.
    """
    colorlog.info("Build %s graph" % mode)

    with graph.as_default(), tf.container(mode):
        # Only one dataset is wired up.
        if cfg.data_name != "audiocaps":
            raise NotImplementedError()
        load_data = input_helper.prefetch_dataset
        vocab_path = os.path.join(
            'data/audiocaps/features/auxiliary',
            '{}.vocab'.format(cfg.vocab_size))

        # Vocabulary: only the second returned table is used.
        _, index_to_string = vocab_utils.create_vocab_tables(
            vocab_path, cfg.vocab_size)

        # Input pipeline.
        num_data, iterator_init, iterators = load_data(
            cfg.batch_size,
            cfg.bucket_width,
            cfg.buffer_size,
            cfg.random_seed,
            cfg.num_gpus,
            cfg.num_epochs,
            mode,
            cfg.feature_name)
        iters_in_data = int(num_data / cfg.batch_size / cfg.num_gpus)

        # Model.
        is_train = mode == "train"
        if cfg.model_name != "PyramidLSTM":
            raise NotImplementedError()
        model = PyramidLSTM(
            cfg, index_to_string, iterators, iters_in_data, is_train)

    return model, iters_in_data, iterator_init
def build_inference_graph(model_creator, config):
    """Build a fresh graph holding the test-time model and its input feed.

    Args:
        model_creator: Callable invoked positionally as
            `model_creator(config, iterator, "test", tgt_vocab_table,
            reverse_tgt_vocab_table)`.
        config: Configuration object. This function reads `data_dir`,
            `test_prefix`, `vocab_prefix`, `src` and `tgt`, and also hands
            `config` itself to the vocab and iterator helpers.

    Returns:
        A `model_base.Model` carrying the graph, the model, the iterator, the
        test file paths and the two feed placeholders.
    """

    def _data_path(prefix, suffix):
        # "<data_dir>/<prefix>.<suffix>"
        return os.path.join(config.data_dir, "%s.%s" % (prefix, suffix))

    infer_graph = tf.Graph()
    with infer_graph.as_default():
        test_src_path = _data_path(config.test_prefix, config.src)
        test_tgt_path = _data_path(config.test_prefix, config.tgt)
        src_vocab_path = _data_path(config.vocab_prefix, config.src)
        tgt_vocab_path = _data_path(config.vocab_prefix, config.tgt)

        tables = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, config)
        src_table, tgt_table, ids_to_tgt_tokens = tables

        # Sentences and batch size are fed at run time.
        src_ph = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_ph = tf.placeholder(shape=[], dtype=tf.int64)
        src_data = tf.contrib.data.Dataset.from_tensor_slices(src_ph)

        test_iterator = input_pipeline.get_test_iterator(
            src_data, src_table, batch_size_ph, config)
        test_model = model_creator(
            config, test_iterator, "test", tgt_table, ids_to_tgt_tokens)

    return model_base.Model(
        graph=infer_graph,
        model=test_model,
        iterator=test_iterator,
        src_file=test_src_path,
        tgt_file=test_tgt_path,
        src_placeholder=src_ph,
        batch_size_placeholder=batch_size_ph)
def _input_fn(params):
    """Build the input pipeline for the enclosing `mode` and `hparams`.

    `mode` and `hparams` are not parameters; they are read from the
    surrounding scope. In TRAIN mode a training iterator over the train files
    is returned; otherwise an inference iterator over either the translate
    file or the test files.

    Args:
        params: Ignored (deleted immediately).

    Returns:
        The result of `iterator_utils.get_iterator` (TRAIN) or
        `iterator_utils.get_infer_iterator` (any other mode).
    """
    del params

    is_training = mode == tf.contrib.learn.ModeKeys.TRAIN

    # Pick the text files for this run.
    if is_training:
        src_path = "%s.%s" % (hparams.train_prefix, hparams.src)
        tgt_path = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    elif hparams.mode == "translate":
        src_path = hparams.translate_file + ".tok"
        tgt_path = hparams.translate_file + ".tok"
    else:
        src_path = "%s.%s" % (hparams.test_prefix, hparams.src)
        tgt_path = "%s.%s" % (hparams.test_prefix, hparams.tgt)

    src_table, tgt_table = vocab_utils.create_vocab_tables(
        hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab)
    src_lines = tf.data.TextLineDataset(src_path)
    tgt_lines = tf.data.TextLineDataset(tgt_path)

    if not is_training:
        return iterator_utils.get_infer_iterator(
            src_lines,
            src_table,
            batch_size=hparams.infer_batch_size,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len,
            use_char_encode=hparams.use_char_encode)

    if hparams.mode == "train_and_eval":
        # One epoch per call. The pipeline is restarted every epoch in this
        # mode, so the seed is perturbed with the wall clock.
        repeat_count = 1
        seed = hparams.random_seed + int(time.time()) % 100
    else:
        repeat_count = 8
        seed = hparams.random_seed

    return iterator_utils.get_iterator(
        src_lines,
        tgt_lines,
        src_table,
        tgt_table,
        batch_size=hparams.batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        output_buffer_size=None,
        skip_count=None,
        num_shards=1,  # flags.num_workers
        shard_index=0,  # flags.jobid
        reshuffle_each_iteration=True,
        use_char_encode=hparams.use_char_encode,
        num_repeat=repeat_count,
        # get_effective_train_epoch_size() needs updating if this flips.
        filter_oversized_sequences=True)
def train_input_fn(params, num_workers=1, jobid=0):
    """Return the next-batch tensors of the training input pipeline.

    Args:
        params: Hyperparameter object. This function reads `train_prefix`,
            `src`, `tgt`, the vocab file paths, `share_vocab`, `batch_size`,
            `sos`, `eos`, `random_seed`, `num_buckets`, `src_max_len` and
            `tgt_max_len`.
        num_workers: Passed to `get_dataset` as `num_shards`.
        jobid: Passed to `get_dataset` as `shard_index`.

    Returns:
        The result of `get_next()` on a one-shot iterator over the dataset.
    """
    train_src_path = "%s.%s" % (params.train_prefix, params.src)
    train_tgt_path = "%s.%s" % (params.train_prefix, params.tgt)

    src_table, tgt_table = vocab_utils.create_vocab_tables(
        params.src_vocab_file, params.tgt_vocab_file, params.share_vocab)

    src_lines = tf.data.TextLineDataset(train_src_path)
    tgt_lines = tf.data.TextLineDataset(train_tgt_path)

    dataset_options = dict(
        batch_size=params.batch_size,
        sos=params.sos,
        eos=params.eos,
        random_seed=params.random_seed,
        num_buckets=params.num_buckets,
        src_max_len=params.src_max_len,
        tgt_max_len=params.tgt_max_len,
        num_shards=num_workers,
        shard_index=jobid)
    batched = get_dataset(
        src_lines, tgt_lines, src_table, tgt_table, **dataset_options)

    return batched.make_one_shot_iterator().get_next()
def create_train_model(model_creator, hparams, scope=None, num_workers=1,
                       jobid=0, extra_args=None):
    """Create the training graph, model and input iterator.

    Args:
        model_creator: Callable that constructs the model; called with
            `hparams` and keyword arguments for the iterator, mode, vocab
            tables, `scope` and `extra_args`.
        hparams: Hyperparameters (train file prefix/suffixes, vocab files and
            iterator settings are read from it).
        scope: Optional container name (defaults to "train"); also forwarded
            to `model_creator`.
        num_workers: Passed to the iterator as `num_shards`.
        jobid: Passed to the iterator as `shard_index`.
        extra_args: Optional object; when truthy its `model_device_fn` is
            used as the device for model construction. Also forwarded to
            `model_creator`.

    Returns:
        A `TrainModel` with the graph, model, iterator and the skip-count
        placeholder.
    """
    train_src_path = "%s.%s" % (hparams.train_prefix, hparams.src)
    train_tgt_path = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file

    train_graph = tf.Graph()
    with train_graph.as_default(), tf.container(scope or "train"):
        src_table, tgt_table = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab)

        src_lines = tf.data.TextLineDataset(train_src_path)
        tgt_lines = tf.data.TextLineDataset(train_tgt_path)
        skip_count_ph = tf.placeholder(shape=(), dtype=tf.int64)

        iterator_options = dict(
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_ph,
            num_shards=num_workers,
            shard_index=jobid)
        train_iterator = iterator_utils.get_iterator(
            src_lines, tgt_lines, src_table, tgt_table, **iterator_options)

        # Note: model_device_fn can be set to
        # `tf.train.replica_device_setter(ps_tasks)` for distributed training.
        device_fn = extra_args.model_device_fn if extra_args else None

        with tf.device(device_fn):
            train_model = model_creator(
                hparams,
                iterator=train_iterator,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                source_vocab_table=src_table,
                target_vocab_table=tgt_table,
                scope=scope,
                extra_args=extra_args)

    return TrainModel(
        graph=train_graph,
        model=train_model,
        iterator=train_iterator,
        skip_count_placeholder=skip_count_ph)
def create_train_model(model_creator, hparams, scope=None, num_workers=1,
                       jobid=0, extra_args=None):
    """Create the training graph, model and iterators (dialogue + KB data).

    Two iterators are set up: the concrete training iterator, and a
    handle-driven placeholder iterator with the same types and shapes. The
    model consumes the placeholder iterator.

    Args:
        model_creator: Callable that constructs the model.
        hparams: Hyperparameters; `vocab_file`, `train_data`, `train_kb` and
            the iterator settings are read from it.
        scope: Optional container name (defaults to "train"); forwarded to
            `model_creator`.
        num_workers: Passed to the iterator as `num_shards`.
        jobid: Passed to the iterator as `shard_index`.
        extra_args: Optional; when truthy its `model_device_fn` selects the
            device for model construction. Forwarded to `model_creator`.

    Returns:
        A `TrainModel` with the graph, model, both iterators, the string
        handle placeholder and the skip-count placeholder.
    """
    train_graph = tf.Graph()
    with train_graph.as_default(), tf.container(scope or "train"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        dialogue_lines = tf.data.TextLineDataset(hparams.train_data)
        kb_lines = tf.data.TextLineDataset(hparams.train_kb)
        skip_count_ph = tf.placeholder(shape=(), dtype=tf.int64)
        ids_to_tokens = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # The concrete iterator over the training files.
        concrete_iterator = iterator_utils.get_iterator(
            dialogue_lines,
            kb_lines,
            vocab_table,
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len,
            skip_count=skip_count_ph,
            num_shards=num_workers,
            shard_index=jobid)

        # Placeholder iterator: feeding a different string handle lets the
        # caller switch between training and evaluation.
        handle_ph = tf.placeholder(tf.string, shape=[])
        handle_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            concrete_iterator.output_types,
            concrete_iterator.output_shapes)
        batched = iterator_utils.get_batched_iterator(handle_iterator)

        device_fn = extra_args.model_device_fn if extra_args else None
        with tf.device(device_fn):
            train_model = model_creator(
                hparams,
                iterator=batched,
                handle=handle_ph,
                mode=tf.contrib.learn.ModeKeys.TRAIN,
                vocab_table=vocab_table,
                scope=scope,
                extra_args=extra_args,
                reverse_vocab_table=ids_to_tokens)

    return TrainModel(
        graph=train_graph,
        model=train_model,
        placeholder_iterator=handle_iterator,
        train_iterator=concrete_iterator,
        placeholder_handle=handle_ph,
        skip_count_placeholder=skip_count_ph)
def create_selfplay_model(model_creator, is_mutable, num_workers, jobid,
                          hparams, scope=None, extra_args=None):
    """Create a self-play graph, model and its iterators.

    Args:
        model_creator: Callable that constructs the model.
        is_mutable: Truthy selects `dialogue_utils.mode_self_play_mutable`,
            falsy selects `dialogue_utils.mode_self_play_immutable`.
        num_workers: Forwarded to `self_play_iterator_creator`.
        jobid: Forwarded to `self_play_iterator_creator`.
        hparams: Hyperparameters; `vocab_file` is read here.
        scope: Optional container name (defaults to "selfplay"); forwarded
            to `model_creator`.
        extra_args: Optional; when truthy its `model_device_fn` selects the
            device for model construction. Forwarded to `model_creator`.

    Returns:
        A `SelfplayModel` with the graph, model, iterators and placeholders.
    """
    selfplay_graph = tf.Graph()
    with selfplay_graph.as_default(), tf.container(scope or "selfplay"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        ids_to_tokens = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # Index into the (mutable, immutable) mode pair below.
        mode_index = 0 if is_mutable else 1

        # Iterators and the placeholders that feed them.
        iterators, placeholders = self_play_iterator_creator(
            hparams, num_workers, jobid)
        train_iterator, fulltext_iterator, structured_iterator = iterators
        (data_ph, kb_ph, batch_size_ph, skip_count_ph) = placeholders

        # Handle-driven iterator shaped like the training iterator.
        handle_ph = tf.placeholder(tf.string, shape=[])
        handle_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            train_iterator.output_types,
            train_iterator.output_shapes)
        batched = iterator_utils.get_batched_iterator(handle_iterator)

        device_fn = extra_args.model_device_fn if extra_args else None
        with tf.device(device_fn):
            play_modes = (
                dialogue_utils.mode_self_play_mutable,
                dialogue_utils.mode_self_play_immutable,
            )
            selfplay_model = model_creator(
                hparams,
                iterator=batched,
                handle=handle_ph,
                mode=play_modes[mode_index],
                vocab_table=vocab_table,
                reverse_vocab_table=ids_to_tokens,
                scope=scope,
                extra_args=extra_args)

    return SelfplayModel(
        graph=selfplay_graph,
        model=selfplay_model,
        placeholder_iterator=handle_iterator,
        placeholder_handle=handle_ph,
        train_iterator=train_iterator,
        self_play_ft_iterator=fulltext_iterator,
        self_play_st_iterator=structured_iterator,
        data_placeholder=data_ph,
        kb_placeholder=kb_ph,
        skip_count_placeholder=skip_count_ph,
        batch_size_placeholder=batch_size_ph)
def _input_fn(params):
    """Build the input pipeline for the enclosing `mode` and `hparams`.

    `mode` and `hparams` come from the surrounding scope. In TRAIN mode the
    batch size and sharding are taken from `params` when it contains a
    "context" entry, and from `hparams` (single shard) otherwise.

    Args:
        params: Mapping checked for "context"; when present, "batch_size"
            and the context's host information are read from it.

    Returns:
        The result of `iterator_utils.get_iterator` (TRAIN) or
        `iterator_utils.get_infer_iterator` (any other mode).
    """
    is_training = mode == tf.contrib.learn.ModeKeys.TRAIN

    file_prefix = hparams.train_prefix if is_training else hparams.test_prefix
    src_path = "%s.%s" % (file_prefix, hparams.src)
    tgt_path = "%s.%s" % (file_prefix, hparams.tgt)

    src_table, tgt_table = vocab_utils.create_vocab_tables(
        hparams.src_vocab_file, hparams.tgt_vocab_file, hparams.share_vocab)
    src_lines = tf.data.TextLineDataset(src_path)
    tgt_lines = tf.data.TextLineDataset(tgt_path)

    if not is_training:
        return iterator_utils.get_infer_iterator(
            src_lines,
            src_table,
            batch_size=hparams.infer_batch_size,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len,
            use_char_encode=hparams.use_char_encode)

    if "context" in params:
        batch_size = params["batch_size"]
        host_count = params["context"].num_hosts
        # TODO(dehao): update to use current_host once available in API.
        host_index = params["context"].current_input_fn_deployment()[1]
    else:
        host_count = 1
        host_index = 0
        batch_size = hparams.batch_size

    mlperf_log.gnmt_print(key=mlperf_log.INPUT_BATCH_SIZE, value=batch_size)
    mlperf_log.gnmt_print(
        key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN, value=hparams.src_max_len)

    return iterator_utils.get_iterator(
        src_lines,
        tgt_lines,
        src_table,
        tgt_table,
        batch_size=batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        output_buffer_size=None,
        skip_count=None,
        num_shards=host_count,
        shard_index=host_index,
        reshuffle_each_iteration=True,
        use_char_encode=hparams.use_char_encode,
        filter_oversized_sequences=True)
def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
    """Create the inference graph, model and iterators (dialogue + KB data).

    Args:
        model_creator: Callable that constructs the model.
        hparams: Hyperparameters; `vocab_file`, `eod` and `len_action` are
            read here.
        scope: Optional container name (defaults to "infer"); forwarded to
            `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `InferModel` with the graph, model, the handle-driven iterator and
        its handle placeholder, the concrete inference iterator, and the
        three feed placeholders.
    """
    infer_graph = tf.Graph()
    with infer_graph.as_default(), tf.container(scope or "infer"):
        vocab_table = vocab_utils.create_vocab_tables(hparams.vocab_file)
        ids_to_tokens = lookup_ops.index_to_string_table_from_file(
            hparams.vocab_file, default_value=vocab_utils.UNK)

        # Named feeds for the dialogue source, the KB and the batch size.
        src_ph = tf.placeholder(
            shape=[None], dtype=tf.string, name="src_ph")
        kb_ph = tf.placeholder(
            shape=[None], dtype=tf.string, name="kb_ph")
        batch_size_ph = tf.placeholder(
            shape=[], dtype=tf.int64, name="bs_ph")

        src_data = tf.data.Dataset.from_tensor_slices(src_ph)
        kb_data = tf.data.Dataset.from_tensor_slices(kb_ph)

        # The concrete inference iterator.
        concrete_iterator = iterator_utils.get_infer_iterator(
            src_data,
            kb_data,
            vocab_table,
            batch_size=batch_size_ph,
            eod=hparams.eod,
            len_action=hparams.len_action)

        # Handle-driven iterator with the same types and shapes.
        handle_ph = tf.placeholder(tf.string, shape=[])
        handle_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            concrete_iterator.output_types,
            concrete_iterator.output_shapes)
        batched = iterator_utils.get_batched_iterator(handle_iterator)

        infer_model = model_creator(
            hparams,
            iterator=batched,
            handle=handle_ph,
            mode=tf.contrib.learn.ModeKeys.INFER,
            vocab_table=vocab_table,
            reverse_vocab_table=ids_to_tokens,
            scope=scope,
            extra_args=extra_args)

    return InferModel(
        graph=infer_graph,
        model=infer_model,
        placeholder_iterator=handle_iterator,
        placeholder_handle=handle_ph,
        infer_iterator=concrete_iterator,
        data_src_placeholder=src_ph,
        kb_placeholder=kb_ph,
        batch_size_placeholder=batch_size_ph)
def _get_tgt_sos_eos_id(hparams):
    """Look up the integer ids of the target-side SOS and EOS tokens.

    The lookup is executed eagerly in a short-lived session. It runs inside
    its own throwaway graph so that the temporary vocab tables are not left
    behind in the caller's default graph, and so that
    `tf.tables_initializer()` only touches the tables created here.

    Args:
        hparams: Hyperparameters; reads `src_vocab_file`, `tgt_vocab_file`,
            `share_vocab`, `sos` and `eos`.

    Returns:
        A `(tgt_sos_id, tgt_eos_id)` tuple of evaluated int32 values.
    """
    # The session is created while the scratch graph is the default, so it
    # is bound to that graph.
    with tf.Graph().as_default(), tf.Session() as sess:
        _, tgt_vocab_table = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file, hparams.tgt_vocab_file,
            hparams.share_vocab)
        tgt_sos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
        tgt_eos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)

        sess.run(tf.tables_initializer())
        # Fetch both ids in a single run call.
        sos_value, eos_value = sess.run([tgt_sos_id, tgt_eos_id])

    return sos_value, eos_value
def create_train_model(model_creator, get_iterator, hparams, scope=None):
    """Create the training graph, model and iterator (single shared vocab).

    Args:
        model_creator: Callable that constructs the model; called with
            keyword arguments only.
        get_iterator: Callable that builds the input iterator; called with
            keyword arguments only.
        hparams: Hyperparameters; the train file prefix/suffixes,
            `vocab_file` and the iterator settings are read from it.
        scope: Forwarded to `model_creator`.

    Returns:
        A `TrainModel` with the graph, model, iterator and the skip-count
        placeholder.
    """
    # File names are "<train_prefix>.<suffix>".
    train_src_path = "%s.%s" % (hparams.train_prefix, hparams.src)
    train_tgt_path = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    vocab_path = hparams.vocab_file

    train_graph = tf.Graph()
    with train_graph.as_default():
        vocab_table = vocab_utils.create_vocab_tables(vocab_path)

        src_lines = tf.contrib.data.TextLineDataset(train_src_path)
        tgt_lines = tf.contrib.data.TextLineDataset(train_tgt_path)

        # Fed at run time and handed to the iterator as `skip_count`.
        skip_count_ph = tf.placeholder(shape=(), dtype=tf.int64)

        iterator_options = dict(
            src_dataset=src_lines,
            tgt_dataset=tgt_lines,
            vocab_table=vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_reverse=hparams.src_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_ph,
        )
        train_iterator = get_iterator(**iterator_options)

        # No ids_to_words argument is passed for training.
        train_model = model_creator(
            hparams=hparams,
            mode=tf.contrib.learn.ModeKeys.TRAIN,
            iterator=train_iterator,
            vocab_table=vocab_table,
            scope=scope,
        )

    return TrainModel(
        graph=train_graph,
        model=train_model,
        iterator=train_iterator,
        skip_count_placeholder=skip_count_ph,
    )
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create the eval graph, model, file placeholders and iterators.

    The dialogue data file and the KB file are supplied at run time through
    scalar string placeholders.

    Args:
        model_creator: Callable that constructs the model.
        hparams: Hyperparameters; `vocab_file` and the iterator settings are
            read from it.
        scope: Optional container name (defaults to "eval"); forwarded to
            `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `EvalModel` with the graph, model, the handle-driven iterator and
        its handle placeholder, the concrete eval iterator and the two file
        placeholders.
    """
    vocab_path = hparams.vocab_file
    eval_graph = tf.Graph()

    with eval_graph.as_default(), tf.container(scope or "eval"):
        vocab_table = vocab_utils.create_vocab_tables(vocab_path)

        data_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        kb_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        dialogue_lines = tf.data.TextLineDataset(data_file_ph)
        kb_lines = tf.data.TextLineDataset(kb_file_ph)

        # The concrete eval iterator.
        iterator_options = dict(
            batch_size=hparams.batch_size,
            t1=hparams.t1,
            t2=hparams.t2,
            eod=hparams.eod,
            len_action=hparams.len_action,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            max_dialogue_len=hparams.max_dialogue_len)
        concrete_iterator = iterator_utils.get_iterator(
            dialogue_lines, kb_lines, vocab_table, **iterator_options)

        # Handle-driven iterator with the same types and shapes.
        handle_ph = tf.placeholder(tf.string, shape=[])
        handle_iterator = tf.data.Iterator.from_string_handle(
            handle_ph,
            concrete_iterator.output_types,
            concrete_iterator.output_shapes)
        batched = iterator_utils.get_batched_iterator(handle_iterator)

        eval_model = model_creator(
            hparams,
            iterator=batched,
            handle=handle_ph,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            vocab_table=vocab_table,
            scope=scope,
            extra_args=extra_args)

    return EvalModel(
        graph=eval_graph,
        model=eval_model,
        placeholder_iterator=handle_iterator,
        placeholder_handle=handle_ph,
        eval_iterator=concrete_iterator,
        data_file_placeholder=data_file_ph,
        kb_file_placeholder=kb_file_ph)
def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
    """Create the eval graph, model, src/tgt file placeholders and iterator.

    Args:
        model_creator: Callable that constructs the model.
        hparams: Hyperparameters; vocab file paths, `share_vocab` and the
            iterator settings (including the `*_max_len_infer` limits and
            `use_char_encode`) are read from it.
        scope: Optional container name (defaults to "eval"); forwarded to
            `model_creator`.
        extra_args: Forwarded unchanged to `model_creator`.

    Returns:
        An `EvalModel` with the graph, model, the two file placeholders and
        the iterator.
    """
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file
    eval_graph = tf.Graph()

    with eval_graph.as_default(), tf.container(scope or "eval"):
        src_table, tgt_table = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab)
        ids_to_tgt_tokens = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_path, default_value=vocab_utils.UNK)

        # File names are fed at run time.
        src_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        src_lines = tf.data.TextLineDataset(src_file_ph)
        tgt_lines = tf.data.TextLineDataset(tgt_file_ph)

        iterator_options = dict(
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer,
            use_char_encode=hparams.use_char_encode)
        eval_iterator = iterator_utils.get_iterator(
            src_lines,
            tgt_lines,
            src_table,
            tgt_table,
            hparams.batch_size,
            **iterator_options)

        eval_model = model_creator(
            hparams,
            iterator=eval_iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            source_vocab_table=src_table,
            target_vocab_table=tgt_table,
            reverse_target_vocab_table=ids_to_tgt_tokens,
            scope=scope,
            extra_args=extra_args)

    return EvalModel(
        graph=eval_graph,
        model=eval_model,
        src_file_placeholder=src_file_ph,
        tgt_file_placeholder=tgt_file_ph,
        iterator=eval_iterator)
def create_eval_model(model_creator, get_iterator, hparams, scope=None):
    """Create the eval graph, model, src/tgt file placeholders and iterator.

    Args:
        model_creator: Callable that constructs the model; called with
            keyword arguments only.
        get_iterator: Callable that builds the input iterator; called with
            keyword arguments only.
        hparams: Hyperparameters; `vocab_file` and the iterator settings
            (with the `*_max_len_infer` limits) are read from it.
        scope: Forwarded to `model_creator`.

    Returns:
        An `EvalModel` with the graph, model, the two file placeholders and
        the iterator.
    """
    vocab_path = hparams.vocab_file

    eval_graph = tf.Graph()
    with eval_graph.as_default():
        vocab_table = vocab_utils.create_vocab_tables(vocab_path)

        # File locations are fed at run time.
        src_file_ph = tf.placeholder(shape=(), dtype=tf.string)
        tgt_file_ph = tf.placeholder(shape=(), dtype=tf.string)

        src_lines = tf.contrib.data.TextLineDataset(src_file_ph)
        tgt_lines = tf.contrib.data.TextLineDataset(tgt_file_ph)

        # No skip_count here: the whole file is evaluated.
        iterator_options = dict(
            src_dataset=src_lines,
            tgt_dataset=tgt_lines,
            vocab_table=vocab_table,
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            src_reverse=hparams.src_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len_infer,
            tgt_max_len=hparams.tgt_max_len_infer,
        )
        eval_iterator = get_iterator(**iterator_options)

        eval_model = model_creator(
            hparams=hparams,
            iterator=eval_iterator,
            mode=tf.contrib.learn.ModeKeys.EVAL,
            vocab_table=vocab_table,
            scope=scope,
        )

    return EvalModel(
        graph=eval_graph,
        model=eval_model,
        src_file_placeholder=src_file_ph,
        tgt_file_placeholder=tgt_file_ph,
        iterator=eval_iterator,
    )
def create_train_model(model_creator, hparams):
    """Create the training graph, model and iterator.

    Args:
        model_creator: Callable that constructs the model; called with
            `hparams` and keyword arguments for the iterator, mode and the
            two vocab tables.
        hparams: Hyperparameters; the train file prefix/suffixes, vocab
            files, `share_vocab` and the iterator settings are read from it.

    Returns:
        A `TrainModel` with the graph, model and iterator.
    """
    train_src_path = "%s.%s" % (hparams.train_prefix, hparams.src)
    train_tgt_path = "%s.%s" % (hparams.train_prefix, hparams.tgt)
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file

    train_graph = tf.Graph()
    with train_graph.as_default(), tf.container("train"):
        src_table, tgt_table = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab)

        src_lines = tf.data.TextLineDataset(train_src_path)
        tgt_lines = tf.data.TextLineDataset(train_tgt_path)

        iterator_options = dict(
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            source_reverse=hparams.source_reverse,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len)
        train_iterator = iterator_utils.get_iterator(
            src_lines, tgt_lines, src_table, tgt_table, **iterator_options)

        train_model = model_creator(
            hparams,
            iterator=train_iterator,
            mode=tf.contrib.learn.ModeKeys.TRAIN,
            source_vocab_table=src_table,
            target_vocab_table=tgt_table)

    return TrainModel(
        graph=train_graph,
        model=train_model,
        iterator=train_iterator)
def create_infer_model(model_creator, get_infer_iterator, hparams,
                       verbose=True, scope=None):
    """Create the inference graph, model and iterator (single shared vocab).

    Args:
        model_creator: Callable that constructs the model; called with
            keyword arguments only.
        get_infer_iterator: Callable that builds the inference iterator;
            called with keyword arguments only.
        hparams: Hyperparameters; `vocab_file`, `src_reverse`, `eos` and
            `src_max_len_infer` are read from it.
        verbose: Forwarded to `model_creator`.
        scope: Forwarded to `model_creator`.

    Returns:
        An `InferModel` with the graph, model, the source and batch-size
        placeholders, and the iterator.
    """
    infer_graph = tf.Graph()
    vocab_path = hparams.vocab_file

    with infer_graph.as_default():
        # Token -> id table, and the id -> token table for decoding output.
        vocab_table = vocab_utils.create_vocab_tables(vocab_path)
        ids_to_words = lookup_ops.index_to_string_table_from_file(
            vocabulary_file=vocab_path,
            default_value=vocab_utils.UNK,
        )

        # Run-time feeds.
        src_ph = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_ph = tf.placeholder(shape=[], dtype=tf.int64)

        src_data = tf.contrib.data.Dataset.from_tensor_slices(src_ph)
        iterator_options = dict(
            dataset=src_data,
            vocab_table=vocab_table,
            batch_size=batch_size_ph,
            src_reverse=hparams.src_reverse,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer,
        )
        infer_iterator = get_infer_iterator(**iterator_options)

        infer_model = model_creator(
            hparams=hparams,
            iterator=infer_iterator,
            mode=tf.contrib.learn.ModeKeys.INFER,
            vocab_table=vocab_table,
            verbose=verbose,
            ids_to_words=ids_to_words,
            scope=scope)

    return InferModel(
        graph=infer_graph,
        model=infer_model,
        src_placeholder=src_ph,
        batch_size_placeholder=batch_size_ph,
        iterator=infer_iterator)
def create_train_model(model_creator, hparams):
    """Create the training graph, model and iterator.

    The model is constructed under `tf.device("/cpu:0")`.

    Args:
        model_creator: Callable that constructs the model; called with
            `hparams` and keyword arguments for the iterator, mode and the
            two vocab tables.
        hparams: Hyperparameters; `src_file`, `tgt_file`, the vocab files,
            `share_vocab` and the iterator settings are read from it.

    Returns:
        A `TrainModel` with the graph, model, iterator and the skip-count
        placeholder.
    """
    train_src_path = hparams.src_file
    train_tgt_path = hparams.tgt_file
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file

    train_graph = tf.Graph()
    with train_graph.as_default(), tf.container("train"):
        src_table, tgt_table = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab)

        src_lines = tf.data.TextLineDataset(train_src_path)
        tgt_lines = tf.data.TextLineDataset(train_tgt_path)
        skip_count_ph = tf.placeholder(shape=(), dtype=tf.int64)

        iterator_options = dict(
            batch_size=hparams.batch_size,
            sos=hparams.sos,
            eos=hparams.eos,
            random_seed=hparams.random_seed,
            num_buckets=hparams.num_buckets,
            src_max_len=hparams.src_max_len,
            tgt_max_len=hparams.tgt_max_len,
            skip_count=skip_count_ph)
        train_iterator = iterator_utils.get_iterator(
            src_lines, tgt_lines, src_table, tgt_table, **iterator_options)

        with tf.device("/cpu:0"):
            train_model = model_creator(
                hparams,
                iterator=train_iterator,
                mode=tf.estimator.ModeKeys.TRAIN,
                source_vocab_table=src_table,
                target_vocab_table=tgt_table)

    return TrainModel(
        graph=train_graph,
        model=train_model,
        iterator=train_iterator,
        skip_count_placeholder=skip_count_ph)
def prepare_dataset(flags):
    """Generate the preprocessed dataset.

    Runs the training source/target text through
    `iterator_utils.get_iterator` (batch size 1, `return_raw=True`) and
    writes the bytes of every fetched value into numbered output files
    named `<out_dir>preprocessed_dataset_<i>`. Each output file receives up
    to 100 fetched elements.

    Paths are formed by plain string concatenation, so `flags.data_dir` and
    `flags.out_dir` are used exactly as given (no separator is inserted).

    Args:
        flags: Settings object; reads `data_dir`, `train_prefix`, `src`,
            `tgt`, `vocab_prefix`, `out_dir`, `num_buckets`, `src_max_len`
            and `tgt_max_len`.
    """
    src_file = "%s.%s" % (flags.data_dir + flags.train_prefix, flags.src)
    tgt_file = "%s.%s" % (flags.data_dir + flags.train_prefix, flags.tgt)
    vocab_file = flags.data_dir + flags.vocab_prefix
    _, vocab_file = vocab_utils.check_vocab(vocab_file, flags.out_dir)
    out_file = flags.out_dir + "preprocessed_dataset"

    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        vocab_file)
    src_dataset = tf.data.TextLineDataset(src_file)
    tgt_dataset = tf.data.TextLineDataset(tgt_file)

    iterator = iterator_utils.get_iterator(
        src_dataset,
        tgt_dataset,
        src_vocab_table,
        tgt_vocab_table,
        batch_size=1,
        global_batch_size=1,
        sos=vocab_utils.SOS,
        eos=vocab_utils.EOS,
        random_seed=1,
        num_buckets=flags.num_buckets,
        src_max_len=flags.src_max_len,
        tgt_max_len=flags.tgt_max_len,
        filter_oversized_sequences=True,
        return_raw=True).make_initializable_iterator()

    # Build the fetch op exactly once. Calling `get_next()` inside the loop
    # would add a new op to the graph on every iteration.
    next_element = iterator.get_next()

    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(iterator.initializer)
        try:
            i = 0
            while True:
                with open(out_file + "_%d" % i, "wb") as f:
                    i += 1
                    for _ in range(100):
                        for j in sess.run(next_element):
                            tf.logging.info(j)
                            f.write(bytearray(j))
        except tf.errors.OutOfRangeError:
            # Expected termination: the input dataset is exhausted.
            pass
def create_infer_model(model_creator, hparams):
    """Create the inference graph, model and iterator.

    Args:
        model_creator: Callable that constructs the model; called with
            `hparams` and keyword arguments for the iterator, mode and the
            three lookup tables.
        hparams: Hyperparameters; vocab file paths, `share_vocab`, `eos` and
            `src_max_len_infer` are read from it.

    Returns:
        An `InferModel` with the graph, model, the source and batch-size
        placeholders, and the iterator.
    """
    infer_graph = tf.Graph()
    src_vocab_path = hparams.src_vocab_file
    tgt_vocab_path = hparams.tgt_vocab_file

    with infer_graph.as_default(), tf.container("infer"):
        src_table, tgt_table = vocab_utils.create_vocab_tables(
            src_vocab_path, tgt_vocab_path, hparams.share_vocab)
        ids_to_tgt_tokens = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_path, default_value=vocab_utils.UNK)

        # Run-time feeds: the sentences to translate and the batch size.
        src_ph = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_ph = tf.placeholder(shape=[], dtype=tf.int64)

        src_data = tf.data.Dataset.from_tensor_slices(src_ph)
        infer_iterator = iterator_utils.get_infer_iterator(
            src_data,
            src_table,
            batch_size=batch_size_ph,
            eos=hparams.eos,
            src_max_len=hparams.src_max_len_infer)

        infer_model = model_creator(
            hparams,
            iterator=infer_iterator,
            mode=tf.estimator.ModeKeys.PREDICT,
            source_vocab_table=src_table,
            target_vocab_table=tgt_table,
            reverse_target_vocab_table=ids_to_tgt_tokens)

    return InferModel(
        graph=infer_graph,
        model=infer_model,
        src_placeholder=src_ph,
        batch_size_placeholder=batch_size_ph,
        iterator=infer_iterator)
def eval_input_fn(params):
    """Return the next-batch tensors of the evaluation input pipeline.

    The source and target file names are scalar string placeholders created
    here; they are not returned to the caller.

    NOTE(review): the dataset depends on placeholders but is consumed through
    a one-shot iterator, and the placeholders are not exposed. Confirm how
    callers are expected to supply the file names.

    Args:
        params: Hyperparameter object; vocab file paths, `share_vocab`,
            `batch_size`, `sos`, `eos`, `random_seed`, `num_buckets` and the
            `*_max_len_infer` limits are read from it.

    Returns:
        The result of `get_next()` on a one-shot iterator over the dataset.
    """
    src_table, tgt_table = vocab_utils.create_vocab_tables(
        params.src_vocab_file, params.tgt_vocab_file, params.share_vocab)

    src_file_ph = tf.placeholder(shape=(), dtype=tf.string)
    tgt_file_ph = tf.placeholder(shape=(), dtype=tf.string)
    src_lines = tf.data.TextLineDataset(src_file_ph)
    tgt_lines = tf.data.TextLineDataset(tgt_file_ph)

    dataset_options = dict(
        sos=params.sos,
        eos=params.eos,
        random_seed=params.random_seed,
        num_buckets=params.num_buckets,
        src_max_len=params.src_max_len_infer,
        tgt_max_len=params.tgt_max_len_infer)
    batched = get_dataset(
        src_lines,
        tgt_lines,
        src_table,
        tgt_table,
        params.batch_size,
        **dataset_options)

    return batched.make_one_shot_iterator().get_next()
def _build_decoder(self, encoder_outputs, encoder_state, hparams):
    """Build the decoder sub-graph and return its logits and sampled ids.

    In train/eval mode the target inputs are embedded, run through either a
    cell-based RNN or the fused-LSTM training path, and projected with
    `self.output_layer`. In inference mode a beam-search decoder is run
    instead; no other inference mode is accepted.

    Args:
        encoder_outputs: Encoder outputs; forwarded to the decoder cell
            builder or the fused decoder.
        encoder_state: Final encoder state; used to derive the decoder's
            initial state.
        hparams: Hyperparameter configuration.

    Returns:
        A `(logits, sample_id)` tuple. In train/eval mode `logits` is the
        projected decoder output and `sample_id` is None. In inference mode
        `logits` is `tf.no_op()` and `sample_id` is the beam search's
        `predicted_ids`.
    """
    with tf.variable_scope("decoder") as decoder_scope:
        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            # ---- Train / eval ----
            tgt_ids = self.features["target_input"]
            if self.time_major:
                # Transpose the ids so the embedded input follows the
                # time-major layout.
                tgt_ids = tf.transpose(tgt_ids)
            embedded = tf.nn.embedding_lookup(self.embedding_decoder, tgt_ids)
            decoder_emb_inp = tf.cast(embedded, self.dtype)

            if hparams.use_fused_lstm_dec:
                if hparams.pass_hidden_state:
                    init_state = encoder_state
                else:
                    # Start from zeros shaped like each encoder state entry.
                    init_state = tuple(
                        tf.nn.rnn_cell.LSTMStateTuple(
                            tf.zeros_like(layer_state[0]),
                            tf.zeros_like(layer_state[1]))
                        for layer_state in encoder_state)
                rnn_outputs = self._build_decoder_fused_for_training(
                    encoder_outputs, init_state, decoder_emb_inp,
                    self.hparams)
            else:
                cell, init_state = self._build_decoder_cell(
                    hparams, encoder_outputs, encoder_state,
                    self.features["source_sequence_length"])
                tgt_lengths = self.features["target_sequence_length"]

                if hparams.use_dynamic_rnn:
                    rnn_outputs, _ = tf.nn.dynamic_rnn(
                        cell,
                        decoder_emb_inp,
                        sequence_length=tgt_lengths,
                        initial_state=init_state,
                        dtype=self.dtype,
                        scope=decoder_scope,
                        parallel_iterations=hparams.parallel_iterations,
                        time_major=self.time_major)
                else:
                    rnn_outputs, _ = tf.contrib.recurrent.functional_rnn(
                        cell,
                        decoder_emb_inp,
                        sequence_length=tf.to_int32(tgt_lengths),
                        initial_state=init_state,
                        dtype=self.dtype,
                        scope=decoder_scope,
                        time_major=self.time_major,
                        use_tpu=False)

            # The projection is applied to all timesteps at once for speed
            # (reported as 10% faster for small models, 20% for larger
            # ones). Apply it per timestep if memory is a concern.
            logits = self.output_layer(rnn_outputs)
            sample_id = None
        else:
            # ---- Inference ----
            cell, init_state = self._build_decoder_cell(
                hparams, encoder_outputs, encoder_state,
                self.features["source_sequence_length"])
            assert hparams.infer_mode == "beam_search"

            _, tgt_vocab_table = vocab_utils.create_vocab_tables(
                hparams.src_vocab_file, hparams.tgt_vocab_file,
                hparams.share_vocab)
            sos_id = tf.cast(
                tgt_vocab_table.lookup(tf.constant(hparams.sos)), tf.int32)
            eos_id = tf.cast(
                tgt_vocab_table.lookup(tf.constant(hparams.eos)), tf.int32)
            start_tokens = tf.fill([self.batch_size], sos_id)

            beam_decoder = beam_search_decoder.BeamSearchDecoder(
                cell=cell,
                embedding=self.embedding_decoder,
                start_tokens=start_tokens,
                end_token=eos_id,
                initial_state=init_state,
                beam_width=hparams.beam_width,
                output_layer=self.output_layer,
                length_penalty_weight=hparams.length_penalty_weight,
                coverage_penalty_weight=hparams.coverage_penalty_weight)

            # Upper bound on the number of decoding steps.
            max_steps = self._get_infer_maximum_iterations(
                hparams, self.features["source_sequence_length"])

            decoded, _, _ = tf.contrib.seq2seq.dynamic_decode(
                beam_decoder,
                maximum_iterations=max_steps,
                output_time_major=self.time_major,
                swap_memory=True,
                scope=decoder_scope)

            logits = tf.no_op()
            sample_id = decoded.predicted_ids

    return logits, sample_id
def train():
    """Train a translation model.

    All configuration is read from ``params`` (not defined in this function;
    presumably a module-level mapping -- confirm against the rest of the
    file). The function builds the train / dev / test input iterators and the
    model inside a fresh graph, then runs ``params['num_epochs']`` epochs.
    Statistics are checked every ``params['steps_per_stats']`` steps, a dev
    evaluation runs every ``params['steps_per_devRun']`` steps, and a
    checkpoint is saved to ``<out_dir>/translate.ckpt`` after every epoch.

    Raises:
        ValueError: If the source or the target vocabulary file is empty.
    """
    create_new_model = params['create_new_model']
    out_dir = params['out_dir']
    model_creator = nmt_model.Model  # Create model graph
    summary_name = "train_log"

    # Setting up session and initialize input data iterators
    src_file = params['src_data_file']
    tgt_file = params['tgt_data_file']
    dev_src_file = params['dev_src_file']
    dev_tgt_file = params['dev_tgt_file']
    test_src_file = params['test_src_file']
    test_tgt_file = params['test_tgt_file']

    char_vocab_file = params['enc_char_map_path']
    src_vocab_file = params['src_vocab_file']
    tgt_vocab_file = params['tgt_vocab_file']
    # Both vocabularies are needed below; the previous check tested the
    # source file twice and never looked at the target file.
    if src_vocab_file == '' or tgt_vocab_file == '':
        raise ValueError(
            "vocab_file not given in params: src_vocab_file='%s', "
            "tgt_vocab_file='%s'" % (src_vocab_file, tgt_vocab_file))

    graph = tf.Graph()

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # Model run params
    num_epochs = params['num_epochs']
    batch_size = params['batch_size']
    steps_per_stats = params['steps_per_stats']

    utils.print_out("# Epochs=%s, Batch Size=%s, Steps_per_Stats=%s"
                    % (num_epochs, batch_size, steps_per_stats), None)

    with graph.as_default():
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            src_vocab_file, tgt_vocab_file, params['share_vocab'])
        char_vocab_table = vocab_utils.get_char_table(char_vocab_file)
        reverse_target_table = lookup_ops.index_to_string_table_from_file(
            tgt_vocab_file, default_value=params['unk'])

        def build_iterator(src_path, tgt_path):
            """Create the batched iterator for one (source, target) file pair."""
            src_dataset = tf.data.TextLineDataset(src_path)
            tgt_dataset = tf.data.TextLineDataset(tgt_path)
            return iterator_utils.get_iterator(
                src_dataset,
                tgt_dataset,
                char_vocab_table,
                src_vocab_table,
                tgt_vocab_table,
                batch_size=batch_size,
                sos=params['sos'],
                eos=params['eos'],
                char_pad=params['char_pad'],
                num_buckets=params['num_buckets'],
                num_epochs=params['num_epochs'],
                src_max_len=params['src_max_len'],
                tgt_max_len=params['tgt_max_len'],
                src_char_max_len=params['char_max_len'])

        batched_iter = build_iterator(src_file, tgt_file)

        # Summary writer. It is created right after the training iterator,
        # exactly where the original code created it.
        summary_writer = tf.summary.FileWriter(
            os.path.join(out_dir, summary_name), graph)

        # Preload validation data for decoding.
        dev_batched_iterator = build_iterator(dev_src_file, dev_tgt_file)

        # Preload test data for decoding.
        test_batched_iterator = build_iterator(test_src_file, test_tgt_file)

        config_proto = utils.get_config_proto(
            log_device_placement=params['log_device_placement'])
        sess = tf.Session(config=config_proto)

        with sess.as_default():
            train_model = model_creator(
                mode=params['mode'],
                train_iterator=batched_iter,
                val_iterator=dev_batched_iterator,
                char_vocab_table=char_vocab_table,
                source_vocab_table=src_vocab_table,
                target_vocab_table=tgt_vocab_table,
                reverse_target_table=reverse_target_table)
            loaded_train_model, global_step = create_or_load_model(
                train_model, params['out_dir'], session=sess, name="train",
                log_f=log_f, create=create_new_model)

            sess.run([batched_iter.initializer,
                      dev_batched_iterator.initializer,
                      test_batched_iterator.initializer])

            start_train_time = time.time()
            utils.print_out(
                "# Start step %d, lr %g, %s"
                % (global_step,
                   loaded_train_model.learning_rate.eval(session=sess),
                   time.ctime()),
                log_f)

            # Reset statistics
            stats = init_stats()

            steps_per_epoch = int(
                np.ceil(utils.get_file_row_size(src_file) / batch_size))
            utils.print_out("Total steps per epoch: %d" % steps_per_epoch)

            def train_step(model, sess):
                """Run a single training step and return its result."""
                return model.train(sess)

            def dev_step(model, sess):
                """Run one pass over the dev file.

                Returns a pair: the summed per-step loss (each scaled by
                ``params['batch_size']``) divided by the number of steps, and
                the mean of the per-step accuracy values.
                """
                total_steps = int(
                    np.ceil(utils.get_file_row_size(dev_src_file) / batch_size))
                total_dev_loss = 0.0
                total_accuracy = 0.0
                for _ in range(total_steps):
                    dev_result_step = model.dev(sess)
                    (dev_softmax_scores, dev_loss, tgt_output_ids,
                     _, _, _, _) = dev_result_step
                    total_dev_loss += dev_loss * params['batch_size']
                    total_accuracy += evaluation_utils._accuracy(
                        dev_softmax_scores, tgt_output_ids, None, None)
                return (total_dev_loss / total_steps,
                        total_accuracy / total_steps)

            for epoch_step in range(num_epochs):
                # steps_per_epoch is already an int, no need to ceil it again.
                for curr_step in range(steps_per_epoch):
                    start_time = time.time()
                    step_result = train_step(loaded_train_model, sess)
                    global_step = update_stats(
                        stats, summary_writer, start_time, step_result)

                    # Logging Step
                    if curr_step % params['steps_per_stats'] == 0:
                        check_stats(stats, global_step, steps_per_stats, log_f)

                    # Evaluation
                    if curr_step % params['steps_per_devRun'] == 0:
                        dev_step_loss, dev_step_acc = dev_step(
                            loaded_train_model, sess)
                        utils.print_out(
                            "Dev Step total loss, Accuracy: %f, %f"
                            % (dev_step_loss, dev_step_acc),
                            log_f)

                utils.print_out(
                    "# Finished an epoch, epoch completed %d" % epoch_step)
                loaded_train_model.saver.save(
                    sess,
                    os.path.join(out_dir, "translate.ckpt"),
                    global_step=global_step)
                # NOTE(review): the result of this end-of-epoch dev pass was
                # never used by the original code; the call itself is kept
                # because it still runs model.dev() -- confirm whether it
                # should be logged or dropped.
                dev_step(loaded_train_model, sess)

            utils.print_time("# Done training!", start_train_time)
            summary_writer.close()

    # Release the log file handle, which was previously left open.
    log_f.close()
def self_play_iterator_creator(hparams, num_workers, jobid):
    """Build the iterators and placeholders used for self play.

    Three iterators are produced:

    * a supervised-training iterator that reads the text files named by
      `hparams.train_data` and `hparams.train_kb`;
    * a "full text" self-play iterator that is fed through string
      placeholders;
    * a "structured" self-play iterator created with
      `tf.data.Iterator.from_structure`, sharing the output types and shapes
      of the full-text iterator.

    Args:
      hparams: Hyperparameters; the attributes read here are `vocab_file`,
        `train_data`, `train_kb`, `batch_size`, `t1`, `t2`, `eod`,
        `len_action`, `random_seed`, `num_buckets` and `max_dialogue_len`.
      num_workers: Passed to the training iterator as `num_shards`.
      jobid: Passed to the training iterator as `shard_index`.

    Returns:
      A pair `(iterators, placeholders)` where `iterators` is
      `[train, self_play_fulltext, self_play_structured]` and `placeholders`
      is `[data, kb, batch_size, skip_count]`.
    """
    vocab = vocab_utils.create_vocab_tables(hparams.vocab_file)
    train_data = tf.data.TextLineDataset(hparams.train_data)
    train_kb = tf.data.TextLineDataset(hparams.train_kb)
    skip_count_ph = tf.placeholder(shape=(), dtype=tf.int64)

    # Iterator used for supervised training.
    supervised_iter = iterator_utils.get_iterator(
        train_data,
        train_kb,
        vocab,
        batch_size=hparams.batch_size,
        t1=hparams.t1,
        t2=hparams.t2,
        eod=hparams.eod,
        len_action=hparams.len_action,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        max_dialogue_len=hparams.max_dialogue_len,
        skip_count=skip_count_ph,
        num_shards=num_workers,
        shard_index=jobid)

    # Full-text self-play iterator: its inputs arrive through placeholders.
    data_ph = tf.placeholder(shape=[None], dtype=tf.string, name="src_ph")
    kb_ph = tf.placeholder(shape=[None], dtype=tf.string, name="kb_ph")
    batch_size_ph = tf.placeholder(shape=[], dtype=tf.int64, name="bs_ph")

    fed_data = tf.data.Dataset.from_tensor_slices(data_ph)
    fed_kb = tf.data.Dataset.from_tensor_slices(kb_ph)

    fulltext_iter = iterator_utils.get_infer_iterator(
        fed_data,
        fed_kb,
        vocab,
        batch_size=batch_size_ph,
        eod=hparams.eod,
        len_action=hparams.len_action,
        self_play=True)

    # Structured self-play iterator: same element structure as the
    # full-text iterator, but not bound to any dataset here.
    structured_iter = tf.data.Iterator.from_structure(
        fulltext_iter.output_types, fulltext_iter.output_shapes)

    iterators = [supervised_iter, fulltext_iter, structured_iter]
    placeholders = [data_ph, kb_ph, batch_size_ph, skip_count_ph]
    return iterators, placeholders
def _make_distributed_pipeline(hparams, num_hosts):
    """Builds the training input pipeline behind a MultiDeviceIterator.

    The whole tf.data pipeline is placed on the CPU of worker task 0. It reads
    the parallel text files `<train_prefix>.<src>` / `<train_prefix>.<tgt>`,
    tokenizes on whitespace, filters and truncates by length, maps tokens to
    ids, caches, shuffles, repeats, batches (optionally bucketed by length),
    re-batches to the per-host batch size, and finally fans the result out to
    the CPU of each of the `num_hosts` worker tasks.

    The original notes state that this must be used in the PER_HOST_V1
    configuration and that MultiDeviceIterator is not compatible with
    Estimator / TPUEstimator.

    Args:
      hparams: The hyperparameters to use.
      num_hosts: The number of hosts we're running across.

    Returns:
      A MultiDeviceIterator.

    Raises:
      ValueError: If `hparams.batch_size` is not a multiple of `num_hosts`.
    """
    # TODO: Merge with the original copy in iterator_utils.py.
    # pylint: disable=g-long-lambda,line-too-long
    total_batch_size = hparams.batch_size
    if total_batch_size % num_hosts != 0:
        raise ValueError(
            "global_batch_size (%s) must be a multiple of num_hosts (%s)" %
            (total_batch_size, num_hosts))

    # Optionally choose from `choose_buckets` buckets simultaneously.
    window_batch_size = (
        int(total_batch_size / hparams.choose_buckets)
        if hparams.choose_buckets else total_batch_size)

    per_host_batch_size = total_batch_size / num_hosts
    buffer_size = total_batch_size * 100

    pipeline_device = "/job:worker/replica:0/task:0/device:CPU:0"

    with tf.device(pipeline_device):
        # From estimator.py
        src_path = "%s.%s" % (hparams.train_prefix, hparams.src)
        tgt_path = "%s.%s" % (hparams.train_prefix, hparams.tgt)
        src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
            hparams.src_vocab_file, hparams.tgt_vocab_file,
            hparams.share_vocab)
        src_lines = tf.data.TextLineDataset(src_path).prefetch(buffer_size)
        tgt_lines = tf.data.TextLineDataset(tgt_path).prefetch(buffer_size)
        mlperf_log.gnmt_print(
            key=mlperf_log.INPUT_BATCH_SIZE, value=total_batch_size)
        mlperf_log.gnmt_print(
            key=mlperf_log.TRAIN_HP_MAX_SEQ_LEN, value=hparams.src_max_len)

        # Local copies of what iterator_utils.make_input_fn takes as
        # parameters.
        sos_token = hparams.sos
        eos_token = hparams.eos
        seed = hparams.random_seed
        bucket_count = hparams.num_buckets
        max_src_len = hparams.src_max_len
        max_tgt_len = hparams.tgt_max_len
        parallel_calls = 100  # constant in iterator_utils.py
        skip_count = None  # constant in estimator.py
        reshuffle = True  # constant in estimator.py
        char_encode = hparams.use_char_encode
        drop_too_long = True  # constant in estimator.py

        # From iterator_utils.py
        if char_encode:
            src_eos_id = vocab_utils.EOS_CHAR_ID
        else:
            src_eos_id = tf.cast(
                src_vocab_table.lookup(tf.constant(eos_token)), tf.int32)

        tgt_sos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(sos_token)), tf.int32)
        tgt_eos_id = tf.cast(
            tgt_vocab_table.lookup(tf.constant(eos_token)), tf.int32)

        paired = tf.data.Dataset.zip((src_lines, tgt_lines))

        mlperf_log.gnmt_print(key=mlperf_log.INPUT_SHARD, value=1)
        if skip_count is not None:
            paired = paired.skip(skip_count)

        def map_fn_1(src, tgt):
            """Tokenize one pair, truncate it, and compute its keep flag."""
            src = tf.string_split([src]).values
            tgt = tf.string_split([tgt]).values
            n_src = tf.size(src)
            n_tgt = tf.size(tgt)
            keep = tf.logical_and(n_src > 0, n_tgt > 0)
            if drop_too_long:
                # True when both sides are strictly shorter than their limit.
                within_limits = tf.logical_and(
                    n_src < max_src_len, n_tgt < max_tgt_len)
                keep = tf.logical_and(keep, within_limits)
            if max_src_len:
                src = src[:max_src_len]
            if max_tgt_len:
                tgt = tgt[:max_tgt_len]
            return (src, tgt, keep)

        flagged = paired.map(map_fn_1, num_parallel_calls=parallel_calls)
        flagged = flagged.filter(lambda src, tgt, filter_bool: filter_bool)

        def map_fn_2(src, tgt, unused_filter_bool):
            """Convert tokens to ids and derive decoder input/output + sizes."""
            if char_encode:
                src = tf.reshape(vocab_utils.tokens_to_bytes(src), [-1])
            else:
                src = tf.cast(src_vocab_table.lookup(src), tf.int32)
            tgt = tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)

            # tgt_in is prefixed with <sos>, tgt_out is suffixed with <eos>.
            tgt_in = tf.concat(([tgt_sos_id], tgt), 0)
            tgt_out = tf.concat((tgt, [tgt_eos_id]), 0)

            # Add in sequence lengths.
            if char_encode:
                src_len = tf.to_int32(
                    tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN)
            else:
                src_len = tf.size(src)
            tgt_len = tf.size(tgt_in)
            return src, tgt_in, tgt_out, src_len, tgt_len

        # Convert the word strings to ids. Word strings that are not in the
        # vocab get the lookup table's default_value integer.
        mlperf_log.gnmt_print(key=mlperf_log.PREPROC_TOKENIZE_TRAINING)
        examples = flagged.map(map_fn_2, num_parallel_calls=parallel_calls)
        examples = examples.prefetch(buffer_size)
        examples = examples.cache()
        examples = examples.shuffle(buffer_size, seed, reshuffle).repeat()

        def batching_func(x):
            """Pad every element of `x` to the max lengths and batch it."""
            # Source and target rows are padded to fixed maximum lengths;
            # the two size entries are scalars.
            shapes = (
                tf.TensorShape([max_src_len]),  # src
                tf.TensorShape([max_tgt_len]),  # tgt_input
                tf.TensorShape([max_tgt_len]),  # tgt_output
                tf.TensorShape([]),  # src_len
                tf.TensorShape([]))  # tgt_len
            # Pad the source and target sequences with eos tokens.
            # (Though notice we don't generally need to do this since
            # later on we will be masking out calculations past the true
            # sequence.)
            pads = (
                src_eos_id,  # src
                tgt_eos_id,  # tgt_input
                tgt_eos_id,  # tgt_output
                0,  # src_len -- unused
                0)  # tgt_len -- unused
            # For TPU, must set drop_remainder to True or batch size will be
            # None.
            return x.padded_batch(
                window_batch_size,
                padded_shapes=shapes,
                padding_values=pads,
                drop_remainder=True)

        def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
            """Return the int64 bucket id for one example."""
            # Pairs with length [0, bucket_width) go to bucket 0, length
            # [bucket_width, 2 * bucket_width) go to bucket 1, etc. The id is
            # capped at `bucket_count`.
            bucket_width = (
                (max_src_len + bucket_count - 1) // bucket_count
                if max_src_len else 10)

            # Bucket sentence pairs by the larger of their source and target
            # bucket indices.
            bucket_id = tf.maximum(
                src_len // bucket_width, tgt_len // bucket_width)
            return tf.to_int64(tf.minimum(bucket_count, bucket_id))

        def reduce_func(unused_key, windowed_data):
            """Batch one window of same-bucket examples."""
            return batching_func(windowed_data)

        if bucket_count > 1:
            batched = examples.apply(
                tf.contrib.data.group_by_window(
                    key_func=key_func,
                    reduce_func=reduce_func,
                    window_size=window_batch_size))
        else:
            batched = batching_func(examples)

        batched = batched.map(
            lambda src, tgt_in, tgt_out, source_size, tgt_in_size: ({
                "source": src,
                "target_input": tgt_in,
                "target_output": tgt_out,
                "source_sequence_length": source_size,
                "target_sequence_length": tgt_in_size
            }))

        # Split the window-sized batches back into examples and regroup them
        # into per-host batches.
        unbatched = batched.apply(tf.contrib.data.unbatch())
        per_host = unbatched.batch(
            int(per_host_batch_size), drop_remainder=True)

        host_devices = [
            "/job:worker/replica:0/task:%d/device:CPU:0" % host
            for host in range(num_hosts)
        ]

        data_options = tf.data.Options()
        data_options.experimental_numa_aware = True
        data_options.experimental_filter_fusion = True
        data_options.experimental_map_and_filter_fusion = True
        per_host = per_host.with_options(data_options)

        multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
            dataset=per_host,
            devices=host_devices,
            max_buffer_size=10,
            prefetch_buffer_size=10,
            source_device=pipeline_device)

    return multi_device_iterator
sos=args.sos, eos=args.eos, unk=vocab_utils.UNK) tgt_vocab_size, tgt_vocab_file = vocab_utils.check_vocab( tgt_vocab_file, args.out_dir, check_special_token=args.check_special_token, sos=args.sos, eos=args.eos, unk=vocab_utils.UNK) #graph = tf.Graph() scope="train" #with graph.as_default(), tf.container(scope): src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(src_vocab_file, tgt_vocab_file, args.share_vocab) src_dataset = tf.data.TextLineDataset(src_file) tgt_dataset = tf.data.TextLineDataset(tgt_file) skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64) iterator = iterator_utils.get_iterator( src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table, batch_size=args.batch_size, sos=args.sos, eos=args.eos, random_seed=args.random_seed, num_buckets=args.num_buckets, src_max_len=args.src_max_len,
def get_iterator(src_file, tgt_file, src_vocab_file, tgt_vocab_file, config,
                 threads=4):
    """Build a batched, initializable iterator over a parallel text corpus.

    Lines of `src_file` and `tgt_file` are paired, shuffled, split on
    whitespace, filtered so that both sides are non-empty, truncated to
    `config.src_max_len` / `config.tgt_max_len`, optionally source-reversed,
    converted to vocabulary ids and padded into batches of
    `config.batch_size` (grouped by length bucket when
    `config.num_buckets > 1`).

    Args:
      src_file: Source-side text file.
      tgt_file: Target-side text file.
      src_vocab_file: Source vocabulary file.
      tgt_vocab_file: Target vocabulary file.
      config: Configuration object; the attributes read here are `batch_size`,
        `eos`, `sos`, `random_seed`, `src_max_len`, `tgt_max_len`,
        `reverse_src` and `num_buckets`.
      threads: Value passed as `num_threads` to every dataset `map` call.

    Returns:
      A tuple `(batched_input, tgt_vocab_table, reverse_tgt_vocab_table)`
      where `batched_input` is a `BatchedInput`.
    """
    buffer_size = config.batch_size * 1000

    src_lines = tf.contrib.data.TextLineDataset(src_file)
    tgt_lines = tf.contrib.data.TextLineDataset(tgt_file)

    vocab_tables = vocab_utils.create_vocab_tables(
        src_vocab_file, tgt_vocab_file, config)
    src_vocab_table, tgt_vocab_table, reverse_tgt_vocab_table = vocab_tables

    src_eos_id = tf.cast(
        src_vocab_table.lookup(tf.constant(config.eos)), tf.int32)
    tgt_sos_id = tf.cast(
        tgt_vocab_table.lookup(tf.constant(config.sos)), tf.int32)
    tgt_eos_id = tf.cast(
        tgt_vocab_table.lookup(tf.constant(config.eos)), tf.int32)

    def parallel_map(dataset, fn):
        """Apply `fn` with the thread count and buffer size used throughout."""
        return dataset.map(
            fn, num_threads=threads, output_buffer_size=buffer_size)

    def tokenize(src, tgt):
        return (tf.string_split([src]).values,
                tf.string_split([tgt]).values)

    def both_non_empty(src, tgt):
        return tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)

    def truncate(src, tgt):
        return (src[:config.src_max_len], tgt[:config.tgt_max_len])

    def reverse_source(src, tgt):
        return (tf.reverse(src, axis=[0]), tgt)

    def to_ids(src, tgt):
        return (tf.cast(src_vocab_table.lookup(src), tf.int32),
                tf.cast(tgt_vocab_table.lookup(tgt), tf.int32))

    def add_sos_eos(src, tgt):
        return (src,
                tf.concat(([tgt_sos_id], tgt), 0),
                tf.concat((tgt, [tgt_eos_id]), 0))

    def add_lengths(src, tgt_in, tgt_out):
        return (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_out))

    # Pair up source and target sentences, then shuffle the pairs.
    pairs = tf.contrib.data.Dataset.zip((src_lines, tgt_lines))
    pairs = pairs.shuffle(buffer_size, config.random_seed)

    # Break sentences into words.
    pairs = parallel_map(pairs, tokenize)

    # Drop pairs with an empty side, then cut both sides to their max length.
    pairs = pairs.filter(both_non_empty)
    pairs = parallel_map(pairs, truncate)

    # Reverse the source if the user asked for it.
    if config.reverse_src:
        pairs = parallel_map(pairs, reverse_source)

    # Convert word strings to ids.
    pairs = parallel_map(pairs, to_ids)

    # Target input gets a leading sos, target output a trailing eos.
    pairs = parallel_map(pairs, add_sos_eos)

    # Add the sizes of the source and of the target output.
    # TODO -- TROUBLESHOOT
    pairs = parallel_map(pairs, add_lengths)

    def batch_pad(dataset):
        """Pad each component to the longest element and batch."""
        shapes = (
            tf.TensorShape([None]),  # src
            tf.TensorShape([None]),  # tgt_input
            tf.TensorShape([None]),  # tgt_output
            tf.TensorShape([]),  # src_len
            tf.TensorShape([]))  # tgt_len
        pads = (
            src_eos_id,
            tgt_eos_id,
            tgt_eos_id,
            0,  # unused
            0)  # unused
        return dataset.padded_batch(
            config.batch_size, padded_shapes=shapes, padding_values=pads)

    if config.num_buckets > 1:

        def key_func(src, tgt_in, tgt_out, src_len, tgt_len):
            """Map an example to its int64 bucket id."""
            bucket_width = (
                (config.src_max_len + config.num_buckets - 1)
                // config.num_buckets)
            bucket_id = tf.maximum(
                src_len // bucket_width, tgt_len // bucket_width)
            return tf.to_int64(tf.minimum(config.num_buckets, bucket_id))

        def reduce_func(unused, windowed_data):
            """Batch and pad the examples collected for one bucket."""
            return batch_pad(windowed_data)

        batched_dataset = pairs.group_by_window(
            key_func=key_func,
            reduce_func=reduce_func,
            window_size=config.batch_size)
    else:
        batched_dataset = batch_pad(pairs)

    # Create an iterator from this dataset and pull out its output tensors.
    batched_iter = batched_dataset.make_initializable_iterator()
    next_batch = batched_iter.get_next()
    src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, tgt_seq_len = (
        next_batch)

    batched_input = BatchedInput(
        initializer=batched_iter.initializer,
        source=src_ids,
        target_input=tgt_input_ids,
        target_output=tgt_output_ids,
        source_sequence_length=src_seq_len,
        target_sequence_length=tgt_seq_len)
    return batched_input, tgt_vocab_table, reverse_tgt_vocab_table
def _input_fn(params):
    """Input function.

    Uses `mode` and `hparams`, which are not defined here (presumably taken
    from the enclosing scope -- confirm). In TRAIN mode it returns either the
    text-file training iterator or, when `hparams.use_preprocessed_data` is
    set, the preprocessed-data iterator. In any other mode it returns the
    inference iterator over the test source file.

    Args:
      params: Mapping of run-time parameters. Keys consulted here are
        "context", "batch_size", "dataset_index", "dataset_num_shards" and
        "infer_batch_size".

    Returns:
      Whatever the selected `iterator_utils` factory returns.
    """
    training = mode == tf.contrib.learn.ModeKeys.TRAIN

    file_prefix = hparams.train_prefix if training else hparams.test_prefix
    src_file = "%s.%s" % (file_prefix, hparams.src)
    tgt_file = "%s.%s" % (file_prefix, hparams.tgt)

    src_vocab_file = hparams.src_vocab_file
    src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
        src_vocab_file)

    if not training:
        # Inference path.
        if "dataset_index" in params:
            current_host = params["dataset_index"]
            num_hosts = params["dataset_num_shards"]
        else:
            num_hosts = 1
            current_host = 0

        batch_size = (
            params["infer_batch_size"]
            if "infer_batch_size" in params else hparams.infer_batch_size)

        # Repeat, batch, shard the batches across hosts, then flatten back to
        # single examples.
        src_dataset = tf.data.TextLineDataset(src_file)
        host_batches = src_dataset.repeat().batch(
            hparams.infer_batch_size // num_hosts)
        host_batches = host_batches.shard(num_hosts, current_host)
        src_dataset = host_batches.apply(tf.contrib.data.unbatch())

        return iterator_utils.get_infer_iterator(
            src_dataset,
            src_vocab_table,
            batch_size=batch_size,
            eos=hparams.eos,
            sos=hparams.sos,
            src_max_len=hparams.src_max_len_infer)

    # Training path: work out the sharding and batch sizes.
    if "context" in params:
        batch_size = params["batch_size"]
        global_batch_size = batch_size
        num_hosts = params["context"].num_hosts
        # TODO(dehao): update to use current_host once available in API.
        current_host = params["context"].current_input_fn_deployment()[1]
    elif "dataset_index" in params:
        current_host = params["dataset_index"]
        num_hosts = params["dataset_num_shards"]
        batch_size = params["batch_size"]
        global_batch_size = hparams.batch_size
    else:
        num_hosts = 1
        current_host = 0
        batch_size = hparams.batch_size
        global_batch_size = batch_size

    if hparams.use_preprocessed_data:
        return iterator_utils.get_preprocessed_iterator(
            hparams.train_prefix + "*",
            batch_size=batch_size,
            random_seed=hparams.random_seed,
            max_seq_len=hparams.src_max_len,
            num_buckets=hparams.num_buckets,
            shard_index=current_host,
            num_shards=num_hosts)

    src_dataset = tf.data.TextLineDataset(src_file)
    tgt_dataset = tf.data.TextLineDataset(tgt_file)
    return iterator_utils.get_iterator(
        src_dataset,
        tgt_dataset,
        src_vocab_table,
        tgt_vocab_table,
        batch_size=batch_size,
        global_batch_size=global_batch_size,
        sos=hparams.sos,
        eos=hparams.eos,
        random_seed=hparams.random_seed,
        num_buckets=hparams.num_buckets,
        src_max_len=hparams.src_max_len,
        tgt_max_len=hparams.tgt_max_len,
        output_buffer_size=None,
        skip_count=None,
        num_shards=num_hosts,
        shard_index=current_host,
        reshuffle_each_iteration=True,
        filter_oversized_sequences=True)