Example #1
def input_fn(file_path,
             vocab_table,
             batch_size,
             num_epochs=None,
             num_examples=None,
             seed=0,
             noiser=None,
             use_free_set=False,
             shuffle_input=True):
    # NOTE: the vocab_table argument is ignored; the project-wide lookup table is used instead.
    vocab_table = vocab.get_vocab_lookup_tables()[vocab.STR_TO_INT]

    pad_token = tf.constant(bytes(PAD_TOKEN, encoding='utf8'), dtype=tf.string)
    pad_id = vocab_table.lookup(pad_token)

    base_dataset = read_examples_from_file(
        file_path, num_examples, seed, noiser,
        util.get_free_words_set() if use_free_set else None)

    dataset_splits = []
    for index in range(len(base_dataset[0])):
        split_dtype = infer_dtype(base_dataset[0][index])

        split = tf.data.Dataset.from_generator(
            generator=get_generator(base_dataset, index),
            output_types=split_dtype,
            output_shapes=(None,))

        # Pad string columns with the literal token, integer columns with its id.
        if split_dtype == tf.string:
            pad = pad_token
        else:
            pad = pad_id

        split = split.padded_batch(batch_size,
                                   padded_shapes=[None],
                                   padding_values=pad)

        dataset_splits.append(split)

    dataset = tf.data.Dataset.zip(tuple(dataset_splits))
    # Shuffle and repeat for training; repeat without shuffling when disabled.
    if num_epochs and shuffle_input:
        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(500, num_epochs))
    elif num_epochs:
        dataset = dataset.repeat(num_epochs)

    # tf.estimator expects (features, labels) pairs, so pair each batch with a dummy label.
    fake_label = tf.data.Dataset.from_tensor_slices(tf.constant([0])).repeat()

    dataset = tf.data.Dataset.zip((dataset, fake_label)).prefetch(1)

    return dataset
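
The returned dataset already yields (features, labels) batches, so it can be fed directly to a TF 1.x Estimator. A minimal usage sketch, assuming an existing model_fn and a hypothetical training-file path:

import tensorflow as tf

estimator = tf.estimator.Estimator(model_fn=model_fn)  # model_fn defined elsewhere
estimator.train(input_fn=lambda: input_fn(
    'data/train.tsv',   # hypothetical path
    vocab_table=None,   # ignored: the function rebuilds the table internally
    batch_size=32,
    num_epochs=10))
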
Example #2
def input_fn_from_gen_multi(gen,
                            vocab_table,
                            batch_size,
                            shuffle_input=False,
                            num_epochs=None,
                            prefetch=False):
    # NOTE: as in Example #1, the vocab_table argument is replaced by the project-wide table.
    vocab_table = vocab.get_vocab_lookup_tables()[vocab.STR_TO_INT]
    base_dataset = list(gen())

    pad_id = tf.constant(vocab.SPECIAL_TOKENS.index(PAD_TOKEN), dtype=tf.int64)

    dataset_splits = []
    for index in range(len(base_dataset[0])):
        split = tf.data.Dataset.from_generator(
            generator=get_generator(base_dataset, index),
            output_types=tf.string,
            output_shapes=(None,))
        # Convert token strings to vocabulary ids before padding and batching.
        split = split.map(lambda x: vocab_table.lookup(x))
        split = split.padded_batch(batch_size,
                                   padded_shapes=[None],
                                   padding_values=pad_id)

        dataset_splits.append(split)

    dataset = tf.data.Dataset.zip(tuple(dataset_splits))
    if num_epochs and shuffle_input:
        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(500, num_epochs))
    elif num_epochs:
        dataset = dataset.repeat(num_epochs)

    # Dummy labels keep the (features, labels) contract expected by tf.estimator.
    fake_label = tf.data.Dataset.from_tensor_slices(tf.constant([0])).repeat()

    dataset = tf.data.Dataset.zip((dataset, fake_label))
    if prefetch:
        dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

    return dataset
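
Because this variant consumes an in-memory generator of pre-tokenized string columns, it is handy for predict calls. A minimal sketch, assuming each yielded example is a tuple of token lists (the column layout get_generator appears to expect):

def example_gen():
    # One example: (source tokens, target tokens); this layout is an assumption.
    yield (['the', 'cat'], ['le', 'chat'])

predictions = estimator.predict(input_fn=lambda: input_fn_from_gen_multi(
    example_gen, vocab_table=None, batch_size=1))
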
Example #3
def input_fn_from_gen_multi(gen, vocab_table, batch_size):
    # Accept either the full lookup-table dict or the string-to-int table itself.
    if isinstance(vocab_table, dict):
        vocab_table = vocab_table[vocab.STR_TO_INT]

    base_dataset = list(gen())

    pad_token = tf.constant(bytes(PAD_TOKEN, encoding='utf8'), dtype=tf.string)
    pad_id = vocab_table.lookup(pad_token)

    dataset_splits = []
    for index in range(len(base_dataset[0])):
        split_dtype = infer_dtype(base_dataset[0][index])

        split = tf.data.Dataset.from_generator(
            generator=get_generator(base_dataset, index),
            output_types=split_dtype,
            output_shapes=(None,))

        # Pad string columns with the literal token, integer columns with its id.
        if split_dtype == tf.string:
            pad = pad_token
        else:
            pad = pad_id

        split = split.padded_batch(batch_size,
                                   padded_shapes=[None],
                                   padding_values=pad)

        dataset_splits.append(split)

    dataset = tf.data.Dataset.zip(tuple(dataset_splits))

    # Dummy labels, matching the other input_fn variants.
    fake_label = tf.data.Dataset.from_tensor_slices(tf.constant([0])).repeat()

    dataset = tf.data.Dataset.zip((dataset, fake_label))

    return dataset
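
The helpers get_generator and infer_dtype are project-local and not shown in these examples. A plausible minimal sketch of their behavior, inferred from the call sites above (an assumption, not the project's actual code):

import tensorflow as tf

def get_generator(base_dataset, index):
    # Returns a closure that yields column `index` of every example.
    def gen():
        for example in base_dataset:
            yield example[index]
    return gen

def infer_dtype(column):
    # String token columns stay tf.string; anything else is treated as integer ids.
    return tf.string if isinstance(column[0], (str, bytes)) else tf.int64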