def feed_token(model: keras.Sequential, token: int):
    """Run a single token through the model and return its output distribution.

    The model is stateful (batch-wise RNN), so the token is placed in slot
    [0, 0] of a zeroed (batch_size, 1) input batch; only row 0 carries data.

    Returns the squeezed probability vector for batch row 0.
    """
    # Build a one-step input batch holding the token in the first row.
    batch = np.zeros((config.find('batch_size'), 1))
    batch[0, 0] = token
    # Forward pass; keep row 0 and drop singleton dimensions.
    return model.predict_on_batch(batch)[0].squeeze()
def ramble(model: keras.Sequential, vectorizer: Vectorizer, seed='she'):
    """Generate text from a stateful model, greedily and via beam search.

    Starting from `seed`, repeatedly feeds the last predicted token back into
    the model, taking the argmax at each step, then prints the greedy "rant"
    and a k=3 beam-search decoding of the collected step distributions.

    Args:
        model: stateful Keras model; its recurrent state is reset up front.
        vectorizer: provides word_to_index / vocab_size / sequences_to_docs.
        seed: seed word to start generation from (must be in the vocabulary).
    """
    current_token = vectorizer.word_to_index[seed]
    current_seq = [current_token]
    model.reset_states()
    model.summary()
    probs_arr = []
    # NOTE(review): generation length is hard-coded to 5 steps; the original
    # intent (per a removed comment) was config.max_seq_len — confirm.
    for _ in range(5):
        # Reuse feed_token instead of duplicating the format/forward logic.
        probs = feed_token(model, current_token)
        probs_arr.append(probs)
        assert probs.shape == (vectorizer.vocab_size, ), f'{probs.shape} != {vectorizer.vocab_size}'
        # Greedy sampling: take the most likely next token.
        current_token = probs.argmax()
        current_seq.append(current_token)
    text_rant = vectorizer.sequences_to_docs([[current_seq]])
    assert len(text_rant) == 1, len(text_rant)
    # Some post-processing prettification: split into sentences and strip.
    text_rant = util.lmap(str.strip, text_rant[0].split('.'))
    util.print_box('Rant', list(text_rant))
    # Beam-search over the collected per-step distributions; keep only the
    # token sequences (each decoder entry is (sequence, score)).
    sequences = beam_search_decoder(probs_arr, k=3)
    sequences = [thing[0] for thing in sequences]
    beam_rant = vectorizer.sequences_to_docs([sequences])
    print('beam rant:', beam_rant)