Example #1
# Assumed imports for this snippet (Keras; load_sonnets is a project-local
# helper that returns the sonnets as lists of lines of text).
from keras.models import Sequential
from keras.layers import Activation, Dense, LSTM


def construct_lstm_model(maxlen=40, initial_weights=None, layers=1):
    """Construct a stacked LSTM model on the sonnets dataset.

    layers controls how many LSTM(256) layers are stacked; every layer but the
    last returns full sequences. Optionally, load from saved weights by passing
    a filename to initial_weights.

    """
    sonnets = load_sonnets()
    chars = sorted(set([c for s in sonnets for l in s for c in l]))

    model = Sequential()
    model.add(
        LSTM(256,
             input_shape=(maxlen, len(chars)),
             return_sequences=layers > 1))
    for i in range(1, layers):
        model.add(LSTM(256, return_sequences=i < (layers - 1)))
    model.add(Dense(len(chars)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    if initial_weights:
        model.load_weights(initial_weights)

    return model
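
A minimal usage sketch for the constructor above (assumes the Keras stack is installed and load_sonnets is importable; layers=2 is just an illustrative choice):

model = construct_lstm_model(maxlen=40, layers=2)
model.summary()  # one LSTM(256) returning sequences, a second LSTM(256), then Dense + softmax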
Example #2
# Assumed imports for this snippet; load_sonnets and _sample are project-local
# helpers (_sample is sketched after this example).
import sys

import numpy as np


def generate_sonnet(model, maxlen=40, sonnets=None, chars=None):
    """Generate a sonnet with the given model.

    Uses the trained model to predict a new sonnet and outputs it to stdout.

    """
    sonnets = sonnets or load_sonnets()
    chars = chars or sorted(set([c for s in sonnets for l in s for c in l]))
    char_indices = dict((c, i) for i, c in enumerate(chars))
    indices_char = dict((i, c) for i, c in enumerate(chars))
    text = ''.join([c for s in sonnets for l in s for c in l])
    for diversity in [0.25, 0.75, 1.5]:
        print('----- temperature:', diversity)

        generated = ''
        sentence = "shall i compare thee to a summer's day?\n"
        generated += sentence
        print('----- Generating with seed: "' + sentence + '"')
        sys.stdout.write(generated)

        for i in range(670):  # roughly one full sonnet's worth of characters
            x_pred = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(sentence):
                x_pred[0, t, char_indices[char]] = 1.

            preds = model.predict(x_pred, verbose=0)[0]
            next_index = _sample(preds, diversity)
            next_char = indices_char[next_index]

            generated += next_char
            sentence = sentence[1:] + next_char

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()
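
_sample is called above but never defined in this listing. The helper below is the standard temperature-sampling routine from the Keras text-generation example, offered as an assumed stand-in (the 1e-8 guard against log(0) is an addition):

import numpy as np


def _sample(preds, temperature=1.0):
    # Reweight the model's softmax output by temperature, then draw one index.
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds + 1e-8) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)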
Example #3
# Assumed imports for this snippet; load_sonnets and load_syllables are
# project-local helpers.
import numpy as np


def get_word_mapped_lines(sonnets=None, words=None, line_model=False):
    """Return sonnets formatted as word indices rather than character data.

    Given sonnets presented as lists of lines of character data, parse the words out and map them
    to integer indices using the provided word list.

    If sonnets and words are not provided, they will be loaded from the default files.

    If line_model is True, each line becomes its own training sequence instead of each sonnet.

    Returns:
        X, a 2D numpy array of sequences, where each element is either a word index or an
            end-of-line marker.
        word_indices, a dictionary from the character data for a word to its index in words.
        eol_marker, the special integer that is used to indicate the end of a line.

    """
    sonnets = sonnets or load_sonnets()
    if not words:
        words, _, _, _, _ = load_syllables()
    # add a token for newlines so that the HMM can learn when to insert a line break
    words.append('\n')
    word_indices = {c: i for i, c in enumerate(words)}
    eol_marker = word_indices['\n']
    X = []
    for s in sonnets:
        smapped = []
        for l in s:
            if line_model:
                smapped = []
            ws = l.split(' ')
            for w in ws:
                # Strip punctuation until the word matches the dictionary; this assumes
                # every word in the corpus is found in `words` once punctuation is removed.
                if w not in words:
                    w = w.rstrip("!'(),.:;?\n")
                    if w not in words:
                        w = w.strip("!'(),.:;?\n")
                wi = word_indices[w]
                smapped.append(wi)
            if line_model:
                X.append(smapped)
            else:
                smapped.append(eol_marker)
        if not line_model:
            X.append(smapped)

    # hack attack: add some words that are never seen to the last line of input to make it a full
    # multinomial so that hmmlearn doesn't break
    all_words = set(range(len(words)))  # valid word indices are 0..len(words)-1
    found_words = set()
    for ln in X:
        found_words = found_words.union(ln)
    missing_words = all_words.difference(found_words)
    X[-1].extend(list(missing_words))

    # make the list not ragged, with EOL markers
    max_words = max([len(s) for s in X])
    for i in range(len(X)):
        X[i] += [eol_marker] * (max_words + 1 - len(X[i]))

    X = np.array(X)

    return X, word_indices, eol_marker
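
train_hmm is used in the next example but not shown. A minimal sketch under stated assumptions: recent hmmlearn releases expose discrete-observation HMMs as CategoricalHMM (older releases called this MultinomialHMM), and n_states is an illustrative parameter, not a value from the project:

import numpy as np
from hmmlearn import hmm


def train_hmm(X, n_states=16):
    # hmmlearn expects a single column of symbols plus per-sequence lengths.
    lengths = [X.shape[1]] * X.shape[0]
    observations = X.reshape(-1, 1)
    model = hmm.CategoricalHMM(n_components=n_states, n_iter=100)
    model.fit(observations, lengths)
    return model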
Example #4
# Assumed import for this snippet; load_sonnets, load_syllables, stress_repr,
# get_stress_mapped_lines, get_word_mapped_lines, train_hmm, and generate_sonnet
# are project-local helpers (the HMM generate_sonnet here is distinct from the
# LSTM one in Example #2).
import argparse

MODEL_TYPES = [
    'line',  # generate a line at a time, independently
    'sonnet',  # generate whole sonnets
    'stress',  # generate a line at a time, with stress encoded
]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        type=str,
                        choices=MODEL_TYPES,
                        required=True)
    args = parser.parse_args()

    sonnets = load_sonnets()
    words, _, word_syllables, _, _ = load_syllables()
    if args.model == 'stress':
        encoded_sonnets = stress_repr(sonnets, word_syllables)
        X, word_indices, eol_marker = get_stress_mapped_lines(encoded_sonnets,
                                                              words=words)
        word_lookup = lambda i: words[i >> 1]
    else:
        X, word_indices, eol_marker = get_word_mapped_lines(
            words=words, line_model=args.model == 'line')
        word_lookup = lambda i: words[i]
    model = train_hmm(X)
    sonnet = generate_sonnet(model,
                             eol_marker,
                             line_model=args.model != 'line')
    print(sonnet)
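
The stress branch above looks words up with words[i >> 1], which implies each symbol packs a stress bit into the low bit of a doubled word index. stress_repr and get_stress_mapped_lines are not shown, so the helpers below are an assumed sketch of that packing, not the project's actual code:

def pack_stress(word_index, stressed):
    # The word index is shifted up one bit; the low bit carries the stress flag.
    return (word_index << 1) | int(stressed)


def unpack_stress(symbol):
    return symbol >> 1, bool(symbol & 1)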
Example #5
# Assumed imports for this snippet; load_sonnets and sample are project-local
# helpers (sample matches the _sample routine sketched under Example #2).
import random
import sys

import numpy as np
from keras.callbacks import LambdaCallback, ModelCheckpoint


def train_lstm(model,
               initial_epoch=0,
               epochs=600,
               maxlen=40,
               file_suffix=None):
    def _on_epoch_end(epoch, logs):
        # Function invoked at end of each epoch. Prints generated text.
        print()
        print('----- Generating text after Epoch: %d' % epoch)

        start_index = random.randint(0, len(text) - maxlen - 1)
        for diversity in [0.2, 0.5, 1.0, 1.2]:
            print('----- diversity:', diversity)

            generated = ''
            sentence = text[start_index:start_index + maxlen]
            generated += sentence
            print('----- Generating with seed: "' + sentence + '"')
            sys.stdout.write(generated)

            for i in range(400):
                x_pred = np.zeros((1, maxlen, len(chars)))
                for t, char in enumerate(sentence):
                    x_pred[0, t, char_indices[char]] = 1.

                preds = model.predict(x_pred, verbose=0)[0]
                next_index = sample(preds, diversity)
                next_char = indices_char[next_index]

                generated += next_char
                sentence = sentence[1:] + next_char

                sys.stdout.write(next_char)
                sys.stdout.flush()
            print()

    sonnets = load_sonnets()
    chars = sorted(set([c for s in sonnets for l in s for c in l]))
    text = ''.join([c for s in sonnets for l in s for c in l])
    char_indices = dict((c, i) for i, c in enumerate(chars))
    indices_char = dict((i, c) for i, c in enumerate(chars))

    step = 1
    sentences = []
    next_chars = []
    for i in range(0, len(text) - maxlen, step):
        sentences.append(text[i:i + maxlen])
        next_chars.append(text[i + maxlen])
    print('nb sequences:', len(sentences))

    # NumPy removed the np.bool alias (1.24+); the builtin bool works everywhere.
    x = np.zeros((len(sentences), maxlen, len(chars)), dtype=bool)
    y = np.zeros((len(sentences), len(chars)), dtype=bool)
    for i, sentence in enumerate(sentences):
        for t, char in enumerate(sentence):
            x[i, t, char_indices[char]] = 1
        y[i, char_indices[next_chars[i]]] = 1

    print_callback = LambdaCallback(on_epoch_end=_on_epoch_end)
    file_suffix = file_suffix or ''
    filepath = 'weights-improvement-{epoch:02d}-{loss:.4f}%s.hdf5' % file_suffix
    checkpoint = ModelCheckpoint(filepath,
                                 monitor='loss',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='min')

    print(x.shape)  # report the shape of the training tensor before fitting
    model.fit(x,
              y,
              batch_size=128,
              epochs=epochs,
              callbacks=[checkpoint],
              initial_epoch=initial_epoch)
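
A hedged usage sketch tying Examples #1 and #5 together, resuming training from a checkpoint. The filename shown is illustrative; real names come from the ModelCheckpoint pattern above:

model = construct_lstm_model(maxlen=40, layers=2,
                             initial_weights='weights-improvement-42-1.2345.hdf5')
train_lstm(model, initial_epoch=42, epochs=600, maxlen=40)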