from random import shuffle

import tensorflow as tf

# DatasetSource comes from the project's dataset module; eval_files, hparams,
# and the other file lists referenced below are assumed to be bound in the
# enclosing scope of these input functions.


def eval_input_fn():
    # Shuffle a copy so the original evaluation file list stays untouched.
    eval_files_copy = list(eval_files)
    shuffle(eval_files_copy)
    records = tf.data.TFRecordDataset(eval_files_copy)
    dataset = DatasetSource(records, hparams)
    dataset = (dataset.make_source_and_target()
               .filter_by_max_output_length()
               .repeat()
               .group_by_batch(batch_size=1))
    return dataset.dataset
def eval_input_fn():
    # Shuffle the (source, target) pairs together so the files stay aligned.
    source_and_target_files = list(zip(eval_source_files, eval_target_files))
    shuffle(source_and_target_files)
    source = tf.data.TFRecordDataset([s for s, _ in source_and_target_files])
    target = tf.data.TFRecordDataset([t for _, t in source_and_target_files])
    dataset = DatasetSource(source, target, hparams)
    dataset = (dataset.prepare_and_zip()
               .filter_by_max_output_length()
               .repeat()
               .group_by_batch(batch_size=1))
    return dataset.dataset
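# A minimal sketch of what prepare_and_zip presumably does: parse each record
# stream and pair the two with tf.data.Dataset.zip. The parse functions are
# hypothetical stand-ins for the project's real decoding logic, which lives in
# DatasetSource.
def prepare_and_zip_sketch(source_records, target_records,
                           parse_source, parse_target):
    source = source_records.map(parse_source)
    target = target_records.map(parse_target)
    return tf.data.Dataset.zip((source, target))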
def train_input_fn():
    train_files_copy = list(train_files)
    shuffle(train_files_copy)
    # Read the shuffled TFRecord files with parallel interleaving to overlap
    # file I/O with training.
    dataset = DatasetSource.create_from_tfrecord_files(
        train_files_copy, hparams,
        cycle_length=interleave_parallelism,
        buffer_output_elements=hparams.interleave_buffer_output_elements,
        prefetch_input_elements=hparams.interleave_prefetch_input_elements)
    batched = (dataset.make_source_and_target()
               .filter_by_max_output_length()
               .shuffle_and_repeat(hparams.suffle_buffer_size)
               .group_by_batch()
               .prefetch(hparams.prefetch_buffer_size))
    return batched.dataset
def train_input_fn():
    source_and_target_files = list(zip(train_source_files, train_target_files))
    shuffle(source_and_target_files)
    # Use lists rather than one-shot generators so the file sequences can be
    # safely iterated more than once downstream.
    source = [s for s, _ in source_and_target_files]
    target = [t for _, t in source_and_target_files]
    dataset = DatasetSource.create_from_tfrecord_files(
        source, target, hparams,
        cycle_length=interleave_parallelism,
        buffer_output_elements=hparams.interleave_buffer_output_elements,
        prefetch_input_elements=hparams.interleave_prefetch_input_elements)
    batched = (dataset.prepare_and_zip()
               .filter_by_max_output_length()
               .repeat(count=None)  # count=None repeats indefinitely
               .shuffle(hparams.suffle_buffer_size)
               .group_by_batch()
               .prefetch(hparams.prefetch_buffer_size))
    return batched.dataset
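# A minimal sketch, assuming create_from_tfrecord_files builds its reader with
# TF 1.x's tf.contrib.data.parallel_interleave (whose parameters match the ones
# forwarded above); the actual wiring inside DatasetSource may differ.
def interleaved_tfrecord_dataset(files, cycle_length,
                                 buffer_output_elements,
                                 prefetch_input_elements):
    file_dataset = tf.data.Dataset.from_tensor_slices(list(files))
    # cycle_length files are read concurrently and their records interleaved,
    # hiding per-file I/O latency behind the training loop.
    return file_dataset.apply(tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset,
        cycle_length=cycle_length,
        buffer_output_elements=buffer_output_elements,
        prefetch_input_elements=prefetch_input_elements))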
def predict_input_fn():
    source = tf.data.TFRecordDataset(list(test_source_files))
    target = tf.data.TFRecordDataset(list(test_target_files))
    dataset = DatasetSource(source, target, hparams)
    batched = (dataset.prepare_and_zip()
               .filter_by_max_output_length()
               .group_by_batch(batch_size=1))
    return batched.dataset
def predict_input_fn():
    records = tf.data.TFRecordDataset(list(test_files))
    dataset = DatasetSource(records, hparams)
    batched = (dataset.make_source_and_target()
               .group_by_batch(batch_size=1)
               .arrange_for_prediction())
    return batched.dataset
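# A minimal usage sketch, assuming these input functions feed a TF 1.x
# Estimator; model_fn, model_dir, and the step counts are hypothetical
# placeholders, not part of the original script.
def run_training(model_fn, train_steps=100000, eval_steps=100):
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="model")
    estimator.train(train_input_fn, steps=train_steps)
    metrics = estimator.evaluate(eval_input_fn, steps=eval_steps)
    predictions = list(estimator.predict(predict_input_fn))
    return metrics, predictions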