def input_fn(file_path, vocab_table, batch_size, num_epochs=None,
             num_examples=None, seed=0, noiser=None, use_free_set=False,
             shuffle_input=True):
    """Build a batched ``(features, dummy_label)`` tf.data pipeline from a file.

    Each column of the examples read from ``file_path`` becomes one dataset
    split; string columns are padded with the literal PAD token, integer
    columns with its vocabulary id. The splits are zipped back together and
    paired with a constant dummy label so the result fits the Estimator
    ``(features, labels)`` contract.

    Args:
        file_path: Path of the examples file passed to
            ``read_examples_from_file``.
        vocab_table: Unused — see NOTE below; kept for interface
            compatibility.
        batch_size: Number of examples per padded batch.
        num_epochs: If set, repeat the dataset this many times.
        num_examples: Optional cap on the number of examples read.
        seed: Random seed forwarded to ``read_examples_from_file``.
        noiser: Optional noising callable forwarded to the reader.
        use_free_set: If True, pass ``util.get_free_words_set()`` to the
            reader.
        shuffle_input: If True (and ``num_epochs`` is set), shuffle while
            repeating.

    Returns:
        A ``tf.data.Dataset`` yielding ``(tuple_of_batched_splits, label)``.
    """
    # NOTE(review): the passed-in `vocab_table` is immediately replaced by the
    # global lookup table, so the parameter is effectively ignored. Kept as-is
    # for backward compatibility — confirm whether any caller relies on
    # passing a different table.
    vocab_table = vocab.get_vocab_lookup_tables()[vocab.STR_TO_INT]
    pad_token = tf.constant(bytes(PAD_TOKEN, encoding='utf8'), dtype=tf.string)
    pad_id = vocab_table.lookup(pad_token)

    base_dataset = read_examples_from_file(
        file_path, num_examples, seed, noiser,
        util.get_free_words_set() if use_free_set else None)

    dataset_splits = []
    for index in range(len(base_dataset[0])):
        # Infer the column's dtype from the first example.
        split_dtype = infer_dtype(base_dataset[0][index])

        split = tf.data.Dataset.from_generator(
            generator=get_generator(base_dataset, index),
            output_types=split_dtype,
            output_shapes=(None,))

        # Pad string columns with the token itself, id columns with its id.
        pad = pad_token if split_dtype == tf.string else pad_id
        split = split.padded_batch(batch_size,
                                   padded_shapes=[None],
                                   padding_values=pad)

        dataset_splits.append(split)

    dataset = tf.data.Dataset.zip(tuple(dataset_splits))

    # Shuffling happens after padded_batch, so whole batches (not individual
    # examples) are shuffled, with a buffer of 500 batches.
    if num_epochs and shuffle_input:
        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(500, num_epochs))
    elif num_epochs:
        dataset = dataset.repeat(num_epochs)

    # Attach a constant dummy label to satisfy the (features, labels) contract.
    # Fix: call the static `zip` through the class — the original
    # `dataset.zip((dataset, fake_label))` only worked because `zip` is a
    # static method reachable through the instance.
    fake_label = tf.data.Dataset.from_tensor_slices(tf.constant([0])).repeat()
    dataset = tf.data.Dataset.zip((dataset, fake_label)).prefetch(1)
    return dataset
def input_fn_from_gen_multi(gen, vocab_table, batch_size, shuffle_input=False,
                            num_epochs=None, prefetch=False):
    """Build a batched ``(features, dummy_label)`` pipeline from a generator.

    NOTE(review): this definition is shadowed by a later function of the same
    name in this module (with a narrower signature), so at import time this
    version is unreachable. One of the two should be renamed — left unchanged
    here to preserve the interface.

    Every column yielded by ``gen`` is treated as strings, looked up in the
    vocabulary table, padded with the PAD id, and batched; the splits are then
    zipped together and paired with a constant dummy label.

    Args:
        gen: Zero-argument callable returning an iterable of example tuples.
        vocab_table: Unused — immediately replaced by the global lookup table;
            kept for interface compatibility (confirm no caller relies on it).
        batch_size: Number of examples per padded batch.
        shuffle_input: If True (and ``num_epochs`` is set), shuffle batches
            while repeating.
        num_epochs: If set, repeat the dataset this many times.
        prefetch: If True, append an autotuned prefetch stage.

    Returns:
        A ``tf.data.Dataset`` yielding ``(tuple_of_batched_splits, label)``.
    """
    vocab_table = vocab.get_vocab_lookup_tables()[vocab.STR_TO_INT]
    # Materialize the generator once so every column split can re-iterate it.
    base_dataset = list(gen())
    pad_id = tf.constant(vocab.SPECIAL_TOKENS.index(PAD_TOKEN), dtype=tf.int64)

    dataset_splits = []
    for index in range(len(base_dataset[0])):
        split = tf.data.Dataset.from_generator(
            generator=get_generator(base_dataset, index),
            output_types=tf.string,
            output_shapes=(None,))

        # Map tokens to ids; the bound method replaces the original
        # `lambda x: vocab_table.lookup(x)` wrapper (identical behavior).
        split = split.map(vocab_table.lookup)
        split = split.padded_batch(batch_size,
                                   padded_shapes=[None],
                                   padding_values=pad_id)

        dataset_splits.append(split)

    dataset = tf.data.Dataset.zip(tuple(dataset_splits))

    # Shuffling happens after padded_batch, so whole batches are shuffled.
    if num_epochs and shuffle_input:
        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(500, num_epochs))
    elif num_epochs:
        dataset = dataset.repeat(num_epochs)

    # Attach a constant dummy label to satisfy the (features, labels) contract.
    # Fix: call the static `zip` through the class rather than through the
    # instance it also zips.
    fake_label = tf.data.Dataset.from_tensor_slices(tf.constant([0])).repeat()
    dataset = tf.data.Dataset.zip((dataset, fake_label))

    if prefetch:
        dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    return dataset
def input_fn_from_gen_multi(gen, vocab_table, batch_size):
    """Build a batched ``(features, dummy_label)`` pipeline from a generator.

    NOTE(review): this redefines (and therefore silently shadows) the earlier
    ``input_fn_from_gen_multi`` in this module, dropping its shuffle/epoch/
    prefetch options. One of the two should be renamed — left unchanged here
    to preserve the interface.

    Unlike the shadowed version, this one infers each column's dtype and pads
    string columns with the PAD token and other columns with its id, without
    mapping tokens through the vocabulary table.

    Args:
        gen: Zero-argument callable returning an iterable of example tuples.
        vocab_table: Lookup table used to resolve the PAD id; when a dict is
            passed, the ``vocab.STR_TO_INT`` entry is used.
        batch_size: Number of examples per padded batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(tuple_of_batched_splits, label)``.
    """
    if isinstance(vocab_table, dict):
        vocab_table = vocab_table[vocab.STR_TO_INT]

    # Materialize the generator once so every column split can re-iterate it.
    base_dataset = list(gen())
    pad_token = tf.constant(bytes(PAD_TOKEN, encoding='utf8'), dtype=tf.string)
    pad_id = vocab_table.lookup(pad_token)

    dataset_splits = []
    for index in range(len(base_dataset[0])):
        # Infer the column's dtype from the first example.
        split_dtype = infer_dtype(base_dataset[0][index])

        split = tf.data.Dataset.from_generator(
            generator=get_generator(base_dataset, index),
            output_types=split_dtype,
            output_shapes=(None,))

        # Pad string columns with the token itself, id columns with its id.
        pad = pad_token if split_dtype == tf.string else pad_id
        split = split.padded_batch(batch_size,
                                   padded_shapes=[None],
                                   padding_values=pad)

        dataset_splits.append(split)

    dataset = tf.data.Dataset.zip(tuple(dataset_splits))

    # Attach a constant dummy label to satisfy the (features, labels) contract.
    # Fix: call the static `zip` through the class rather than through the
    # instance it also zips.
    fake_label = tf.data.Dataset.from_tensor_slices(tf.constant([0])).repeat()
    dataset = tf.data.Dataset.zip((dataset, fake_label))
    return dataset