Code Example #1
import os

import numpy as np
import tensorflow as tf


def build_dataset_verse(text, vocab, idx2syl, syl2idx, seq_length):
    # Note: only syl2idx is used below; vocab and idx2syl are accepted but unused.

    # Non-overlapping windows: each window spans seq_length + 1 syllables.
    step_length = seq_length + 1

    # Cache file for the syllabified verse text, stored next to this module.
    text_in_syls_verse_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'text_in_syls_verse.json')

    # Reuse the cached syllable list if it exists; otherwise syllabify the text
    # and cache the result. load_syls_list, save_syls_list and text_in_rev_syls
    # are helpers defined elsewhere in the project.
    if os.path.isfile(text_in_syls_verse_file):
        syls_verse_list = load_syls_list(text_in_syls_verse_file)
    else:
        syls_verse_list = text_in_rev_syls(text)
        save_syls_list(syls_verse_list, text_in_syls_verse_file)

    # Map each syllable to its integer index in the vocabulary.
    text_as_int = np.array([syl2idx[s] for s in syls_verse_list])

    # Slice the index sequence into fixed-length windows of seq_length + 1
    # tokens, then flatten each window into a single tensor.
    dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
    dataset = dataset.window(seq_length + 1,
                             shift=step_length,
                             stride=1,
                             drop_remainder=True)
    dataset = dataset.flat_map(lambda window: window.batch(seq_length + 1))

    def split_input_target(chunk):
        # Input is the window without its last token; the target is the same
        # window shifted one position to the right.
        input_text = chunk[:-1]
        target_text = chunk[1:]
        return input_text, target_text

    dataset = dataset.map(split_input_target)

    dataset = dataset.shuffle(1000)

    return dataset
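
To make the windowing step above concrete, here is a small self-contained sketch of the same tf.data pattern applied to a toy integer sequence; toy_ids and seq_length below are illustrative values, not data from the project.

import numpy as np
import tensorflow as tf

# Toy stand-in for the syllable-index sequence built inside build_dataset_verse.
toy_ids = np.arange(20)
seq_length = 4
step_length = seq_length + 1  # non-overlapping windows, as in the function above

ds = tf.data.Dataset.from_tensor_slices(toy_ids)
ds = ds.window(seq_length + 1, shift=step_length, stride=1, drop_remainder=True)
ds = ds.flat_map(lambda w: w.batch(seq_length + 1))
ds = ds.map(lambda chunk: (chunk[:-1], chunk[1:]))

for x, y in ds:
    print(x.numpy(), '->', y.numpy())
# [0 1 2 3] -> [1 2 3 4]
# [5 6 7 8] -> [6 7 8 9]
# ...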
Code Example #2
def build_vocab_rhyme(text):

    # Cache file for the syllabified rhyme text, stored next to this module.
    text_in_syls_rhyme_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'text_in_syls_rhyme.json')

    # Unlike build_dataset_verse, the rhyme syllable list is always recomputed
    # here and the cache file is overwritten with the fresh result.
    syls_rhyme_list = text_in_syls_rhyme(text)
    save_syls_list(syls_rhyme_list, text_in_syls_rhyme_file)

    # The vocabulary is the sorted set of unique rhyme syllables.
    vocab = sorted(set(syls_rhyme_list))

    # Forward and inverse lookup tables between indices and syllables.
    idx2syl = {i: s for (i, s) in enumerate(vocab)}
    syl2idx = {s: i for (i, s) in enumerate(vocab)}

    return vocab, idx2syl, syl2idx
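
As a quick illustration of the mappings returned above, the following self-contained snippet builds the same structures from a handful of invented syllables; the strings in toy_syls are placeholders, not output of text_in_syls_rhyme.

# Invented syllables, used only to show the vocab / idx2syl / syl2idx round trip.
toy_syls = ['nel', 'mez', 'zo', 'del', 'nel']
vocab = sorted(set(toy_syls))
idx2syl = {i: s for i, s in enumerate(vocab)}
syl2idx = {s: i for i, s in enumerate(vocab)}

print(vocab)                    # ['del', 'mez', 'nel', 'zo']
print(syl2idx['nel'])           # 2
print(idx2syl[syl2idx['nel']])  # nel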
Code Example #3
# Per-model log directory and output files for the generated text.
os.makedirs(os.path.join(logs_dir, model_filename), exist_ok=True)

output_file = os.path.join(logs_dir, model_filename, "output.txt")
raw_output_file = os.path.join(logs_dir, model_filename, "raw_output.txt")

output_toned_file = os.path.join(logs_dir, model_filename, "output_toned.txt")
raw_output_toned_file = os.path.join(logs_dir, model_filename,
                                     "raw_output_toned.txt")

text_in_syls_rhyme_file = os.path.join(working_dir, 'text_in_syls_rhyme.json')

# Reuse the cached rhyme syllable list if available; otherwise syllabify the
# whole Divine Comedy text and cache the result.
if os.path.isfile(text_in_syls_rhyme_file):
    syls_rhyme_list = load_syls_list(text_in_syls_rhyme_file)
else:
    syls_rhyme_list = text_in_syls_rhyme(divine_comedy)
    save_syls_list(syls_rhyme_list, text_in_syls_rhyme_file)

# Pick a random END_OF_CANTO position (far enough in to allow a full-length
# seed) and take the SEQ_LENGTH_RHYME tokens ending with that END_OF_CANTO
# token as the starting sequence for rhyme generation.
indexes = [
    i for i, x in enumerate(syls_rhyme_list)
    if x == special_tokens['END_OF_CANTO'] and i > SEQ_LENGTH_RHYME
]
index_eoc = np.random.choice(indexes) + 1
start_idx = max(0, index_eoc - SEQ_LENGTH_RHYME)
start_seq_rhyme = syls_rhyme_list[start_idx:index_eoc]

text_in_syls_verse_file = os.path.join(working_dir, 'text_in_syls_verse.json')

if os.path.isfile(text_in_syls_verse_file):