def feed_token(model: keras.Sequential, token: int):
    """Run a single token through the stateful model.

    Returns the model's output for batch row 0, squeezed to 1-D — the
    probability distribution over the next token. Assumes the model was
    built with batch size ``config.find('batch_size')`` — TODO confirm.
    """
    # Build a full zero batch; only slot (0, 0) carries the real token.
    batch = np.zeros((config.find('batch_size'), 1))
    batch[0, 0] = token

    # One forward step; keep row 0 (the only slot we filled) and drop
    # singleton dimensions.
    out = model.predict_on_batch(batch)[0]
    return out.squeeze()
def ramble(model: keras.Sequential, vectorizer: Vectorizer, seed='she',
           num_steps=5):
    """Generate text from `model`, starting from the word `seed`.

    Greedily decodes `num_steps` tokens (printing the result in a box),
    then also runs a beam-search decode over the collected per-step
    distributions and prints that variant.

    Args:
        model: stateful keras model; its recurrent state is reset here.
        vectorizer: maps words <-> token indices and sequences -> docs.
        seed: starting word; must be in `vectorizer.word_to_index`.
        num_steps: number of tokens to generate (was hard-coded to 5;
            consider config.max_seq_len for full-length output).
    """
    current_token = vectorizer.word_to_index[seed]
    current_seq = [current_token]
    model.reset_states()
    model.summary()

    probs_arr = []

    for _ in range(num_steps):
        # Reuse the single-step helper instead of duplicating its
        # batch-formatting and forward-pass logic inline. (This also
        # drops the old `inputs.shape[0] == 32` assert, which hard-coded
        # a batch size that actually comes from config.)
        probs = feed_token(model, current_token)
        probs_arr.append(probs)
        assert probs.shape == (vectorizer.vocab_size,
                               ), f'{probs.shape} != {vectorizer.vocab_size}'

        # Greedy sampling: always take the most likely next token.
        current_token = probs.argmax()
        current_seq.append(current_token)

    text_rant = vectorizer.sequences_to_docs([[current_seq]])
    assert len(text_rant) == 1, len(text_rant)

    # Some post-processing prettification.
    text_rant = util.lmap(str.strip, text_rant[0].split('.'))
    util.print_box('Rant', list(text_rant))

    # Beam-search decode over the recorded distributions; each result is
    # a (sequence, score) pair — keep only the sequences.
    sequences = beam_search_decoder(probs_arr, k=3)
    sequences = [thing[0] for thing in sequences]
    beam_rant = vectorizer.sequences_to_docs([sequences])
    print('beam rant:', beam_rant)