"""Interactive chat REPL for a trained seq2seq (attention) chatbot.

Reads user queries from stdin, tokenizes them, runs inference against a
saved checkpoint, and prints the decoded response.  Enter "q" to quit.
"""
from models.seq2seq_model import Seq2SeqModel
from models.seq2seq_model import Seq2SeqModelAttention
from configuration import get_configuration
from utils import Preprocessing

# Initialize the model and the preprocessing pipeline.
config = get_configuration()
preprocess = Preprocessing(config=config)
model = Seq2SeqModelAttention(config)
checkpoint_file = 'runs/baseline-cornell-twitter-attn-dropout/model-18000'

# Load the vocabulary once, up front.  It is loop-invariant, so there is no
# reason to re-read it on every user query (the original called this inside
# the chat loop, re-initializing the vocabulary per message).
preprocess.initialize_vocabulary()

# Launch chat interface.
print(
    "*** Hi there. Ask me a question. I will try my best to reply to you with something intelligible. "
    "If you think that is not happening, enter \"q\" and quit ***")
query = input(">")
while query != "q":
    # Tokenize the query into vocabulary ids.
    token_ids = preprocess.sentence_to_token_ids(query)

    # Reverse the token ids and feed into the RNN.
    # NOTE(review): presumably the encoder was trained on reversed input
    # sequences (a common seq2seq trick), so inference must match — confirm
    # against the training pipeline.
    reverse_token_ids = [list(reversed(token_ids))]
    output_tokens = model.infer(checkpoint_file, reverse_token_ids, verbose=False)

    # Convert token ids back to words and print the single decoded response.
    output = preprocess.token_ids_to_sentence(output_tokens)
    print(output[0])

    query = input(">")
# NOTE(review): this line is a TRUNCATED continuation — it begins with the
# tail of a `pickle.load(open(...))` expression whose head (the assignment
# target and, presumably, the loading of the matching training pickle) lies
# outside this chunk.  Do not edit in isolation; recover the full statement
# from the original file first.
# NOTE(review): the commented-out region in the middle preserves earlier
# train/evaluate invocations (model dir 'runs/1496648182'); the active code
# below it uses 'runs/149664818' — one digit shorter.  This looks like a
# typo'd copy of the same run directory — verify which path actually exists.
# NOTE(review): `pickle.load` on these files is only safe if the pickles are
# trusted local artifacts produced by this project's own preprocessing.
# NOTE(review): the `open(...)` handles passed to `pickle.load` are never
# closed; a `with` block would be the conventional fix once the statement is
# reconstructed.
open((config.data_dir) + '/input_test_triples.pkl', 'rb')) train_batches = generate_batches(train_data, batch_size=config.batch_size, num_epochs=config.n_epochs) eval_batches = generate_batches(eval_data, batch_size=config.batch_size, num_epochs=1) # model.train(train_batches, eval_batches, verbose=True) # # Evaluate perplexity of trained model on validation data # model_dir = 'runs/1496648182' # model_file = '/model-12000' # # print(model.evaluate(eval_batches, model_dir=model_dir, model_file=model_file)) # # Infer outputs on a subset of test data model_dir = 'runs/149664818' model_file = '/model-12000' checkpoint_file = model_dir + model_file test_data = pickle.load(open((config.data_dir) + '/input_test.pkl', 'rb'))[0][:10] predicted_outputs = model.infer(checkpoint_file, test_data) preprocess.initialize_vocabulary() # Reverse back the test sentences test_data = [list(reversed(sentence)) for sentence in test_data] messages = preprocess.token_ids_to_sentence(test_data) responses = preprocess.token_ids_to_sentence(predicted_outputs) # Print to file with open(model_dir + "/model-12000-conversations", 'w') as f: for idx, message in enumerate(messages): f.write(message + " ====> " + responses[idx] + "\n")