def _write_sentences(sentences, output_file_path):
    """Write one sentence per line to *output_file_path*, creating parent dirs as needed."""
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
    with open(output_file_path, 'w') as output_file:
        for sentence in sentences:
            output_file.write(sentence + "\n")


def execute_post_inference_operations(actual_word_lists, generated_sequences,
                                      final_sequence_lengths, inverse_word_index,
                                      timestamped_file_suffix, label):
    """Trim generated sequences, BLEU-score them against the originals, and dump both to disk.

    Args:
        actual_word_lists: reference sentences, one list of words per sentence.
        generated_sequences: decoder output as lists of word indices.
        final_sequence_lengths: per-sentence lengths reported by the decoder.
        inverse_word_index: index -> word mapping used to detokenize.
        timestamped_file_suffix: experiment timestamp used in output paths.
        label: style label the sequences were generated for (used in filenames/logs).
    """
    try:
        logger.debug("Minimum generated sentence length: {}".format(
            min(final_sequence_lengths)))
    except ValueError:
        # min() raises ValueError on an empty sequence: nothing to score or write.
        # Was a bare print(); route through the module logger like everything else.
        logger.warning("{} label has no sequences!".format(label))
        return

    # First trim each generated sentence down to the length the decoder
    # returned (dropping the final position), then strip any <eos> tokens.
    eos_index = global_config.predefined_word_index[global_config.eos_token]
    trimmed_generated_sequences = [
        [index for index in sequence[:(length - 1)] if index != eos_index]
        for sequence, length in zip(generated_sequences, final_sequence_lengths)]

    generated_word_lists = \
        [data_processor.generate_words_from_indices(x, inverse_word_index)
         for x in trimmed_generated_sequences]

    # Corpus-level BLEU; each actual sentence is wrapped as a single-reference list.
    bleu_scores = bleu_scorer.get_corpus_bleu_scores(
        [[x] for x in actual_word_lists], generated_word_lists)
    logger.info("bleu_scores: {}".format(bleu_scores))

    _write_sentences(
        [" ".join(x) for x in generated_word_lists],
        "output/{}-inference/generated_sentences_{}.txt".format(
            timestamped_file_suffix, label))
    _write_sentences(
        [" ".join(x) for x in actual_word_lists],
        "output/{}-inference/actual_sentences_{}.txt".format(
            timestamped_file_suffix, label))
def _load_saved_model_artifacts(saved_model_path):
    """Load artifacts persisted at training time from *saved_model_path*.

    Restores the model config into the module-level ``mconf`` as a side
    effect and returns ``(word_index, index_to_label_map,
    average_label_embeddings)``.
    """
    with open(
            os.path.join(saved_model_path,
                         global_config.model_config_file), 'r') as json_file:
        model_config_dict = json.load(json_file)
        mconf.init_from_dict(model_config_dict)
        logger.info("Restored model config from saved JSON")
    with open(
            os.path.join(saved_model_path,
                         global_config.vocab_save_file), 'r') as json_file:
        word_index = json.load(json_file)
    with open(
            os.path.join(saved_model_path,
                         global_config.index_to_label_dict_file), 'r') as json_file:
        index_to_label_map = json.load(json_file)
    # NOTE: pickle is only safe here because this file is produced by our own
    # training run, not untrusted input.
    with open(
            os.path.join(saved_model_path,
                         global_config.average_label_embeddings_file), 'rb') as pickle_file:
        average_label_embeddings = pickle.load(pickle_file)
    return word_index, index_to_label_map, average_label_embeddings


def _write_one_hot_predictions(one_hot_predictions, output_file_path):
    """Write the index of the hot entry of each one-hot row to a file, one per line."""
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
    with open(output_file_path, 'w') as output_file:
        for one_hot_label in one_hot_predictions:
            output_file.write("{}\n".format(one_hot_label.tolist().index(1)))


def _train_model(options):
    """Train the adversarial autoencoder from scratch and save its artifacts."""
    # Fails if the save directory already exists -- presumably deliberate, to
    # avoid clobbering a previous run's artifacts. TODO confirm.
    os.makedirs(global_config.save_directory)
    with open(global_config.model_config_file_path, 'w') as model_config_file:
        json.dump(obj=mconf.__dict__, fp=model_config_file, indent=4)
    logger.info("Saved model config to {}".format(
        global_config.model_config_file_path))

    # Retrieve all data
    logger.info("Reading data ...")
    [word_index, padded_sequences, text_sequence_lengths, one_hot_labels,
     num_labels, text_tokenizer, inverse_word_index] = get_data(options)
    data_size = padded_sequences.shape[0]

    encoder_embedding_matrix, decoder_embedding_matrix = \
        get_word_embeddings(options.training_embeddings_file_path, word_index)

    # Build model
    logger.info("Building model architecture ...")
    network = adversarial_autoencoder.AdversarialAutoencoder()
    network.build_model(word_index, encoder_embedding_matrix,
                        decoder_embedding_matrix, num_labels)

    logger.info("Training model ...")
    sess = tf_session_helper.get_tensorflow_session()

    [_, validation_actual_word_lists, validation_sequences,
     validation_sequence_lengths] = \
        data_processor.get_test_sequences(
            options.validation_text_file_path, text_tokenizer, word_index,
            inverse_word_index)
    [_, validation_labels] = \
        data_processor.get_test_labels(options.validation_label_file_path,
                                       global_config.save_directory)

    network.train(sess, data_size, padded_sequences, text_sequence_lengths,
                  one_hot_labels, num_labels, word_index,
                  encoder_embedding_matrix, decoder_embedding_matrix,
                  validation_sequences, validation_sequence_lengths,
                  validation_labels, inverse_word_index,
                  validation_actual_word_lists, options)
    sess.close()
    logger.info("Training complete!")


def _transform_text(options):
    """Re-style evaluation sentences to every other label and score the results."""
    word_index, index_to_label_map, average_label_embeddings = \
        _load_saved_model_artifacts(options.saved_model_path)
    global_config.vocab_size = len(word_index)
    num_labels = len(index_to_label_map)

    text_tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=global_config.vocab_size,
        filters=global_config.tokenizer_filters)
    text_tokenizer.word_index = word_index
    inverse_word_index = {v: k for k, v in word_index.items()}

    [actual_sequences, _, padded_sequences, text_sequence_lengths] = \
        data_processor.get_test_sequences(
            options.evaluation_text_file_path, text_tokenizer, word_index,
            inverse_word_index)
    [label_sequences, _] = \
        data_processor.get_test_labels(options.evaluation_label_file_path,
                                       options.saved_model_path)

    logger.info("Building model architecture ...")
    network = adversarial_autoencoder.AdversarialAutoencoder()
    encoder_embedding_matrix, decoder_embedding_matrix = get_word_embeddings(
        None, word_index)
    network.build_model(word_index, encoder_embedding_matrix,
                        decoder_embedding_matrix, num_labels)

    sess = tf_session_helper.get_tensorflow_session()
    model_save_path = os.path.join(options.saved_model_path,
                                   global_config.model_save_file)

    total_nll = 0
    for i in range(num_labels):
        logger.info("Style chosen: {}".format(i))

        # Only transfer sentences that are NOT already of style i.
        filtered_actual_sequences = list()
        filtered_padded_sequences = list()
        filtered_text_sequence_lengths = list()
        for k in range(len(actual_sequences)):
            if label_sequences[k] != i:
                filtered_actual_sequences.append(actual_sequences[k])
                filtered_padded_sequences.append(padded_sequences[k])
                filtered_text_sequence_lengths.append(
                    text_sequence_lengths[k])

        style_embedding = np.asarray(average_label_embeddings[i])
        [generated_sequences, final_sequence_lengths, _, _, _,
         cross_entropy_scores] = \
            network.transform_sentences(
                sess, filtered_padded_sequences,
                filtered_text_sequence_lengths, style_embedding, num_labels,
                model_save_path)
        nll = -np.mean(a=cross_entropy_scores, axis=0)
        total_nll += nll
        logger.info("NLL: {}".format(nll))

        actual_word_lists = \
            [data_processor.generate_words_from_indices(x, inverse_word_index)
             for x in filtered_actual_sequences]
        execute_post_inference_operations(
            actual_word_lists, generated_sequences, final_sequence_lengths,
            inverse_word_index, global_config.experiment_timestamp, i)
        logger.info("Generation complete for label {}".format(i))
    logger.info("Mean NLL: {}".format(total_nll / num_labels))

    logger.info("Predicting labels from latent spaces ...")
    _, _, overall_label_predictions, style_label_predictions, \
        adversarial_label_predictions, _ = \
        network.transform_sentences(
            sess, padded_sequences, text_sequence_lengths,
            average_label_embeddings[0], num_labels, model_save_path)

    # Write each flavour of label prediction to its own file.
    _write_one_hot_predictions(
        overall_label_predictions,
        "output/{}-inference/overall_labels_prediction.txt".format(
            global_config.experiment_timestamp))
    _write_one_hot_predictions(
        style_label_predictions,
        "output/{}-inference/style_labels_prediction.txt".format(
            global_config.experiment_timestamp))
    _write_one_hot_predictions(
        adversarial_label_predictions,
        "output/{}-inference/adversarial_labels_prediction.txt".format(
            global_config.experiment_timestamp))

    logger.info("Inference run complete")
    sess.close()


def _generate_novel_text(options):
    """Sample novel sentences conditioned on each (or one chosen) label's style."""
    word_index, index_to_label_map, average_label_embeddings = \
        _load_saved_model_artifacts(options.saved_model_path)
    global_config.vocab_size = len(word_index)
    inverse_word_index = {v: k for k, v in word_index.items()}
    num_labels = len(index_to_label_map)

    text_tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=global_config.vocab_size,
        filters=global_config.tokenizer_filters)
    text_tokenizer.word_index = word_index
    data_processor.populate_word_blacklist(word_index)

    logger.info("Building model architecture ...")
    network = adversarial_autoencoder.AdversarialAutoencoder()
    encoder_embedding_matrix, decoder_embedding_matrix = get_word_embeddings(
        None, word_index)
    network.build_model(word_index, encoder_embedding_matrix,
                        decoder_embedding_matrix, num_labels)

    sess = tf_session_helper.get_tensorflow_session()
    eos_index = global_config.predefined_word_index[global_config.eos_token]
    for label_index in index_to_label_map:
        # BUGFIX: index_to_label_map comes from json.load, so its keys are
        # strings; the old `label_index != options.label_index` compared
        # str to int and never matched, skipping every label. Compare as ints.
        if options.label_index is not None and \
                int(label_index) != options.label_index:
            continue
        style_embedding = np.asarray(
            average_label_embeddings[int(label_index)])
        generated_sequences, final_sequence_lengths = \
            network.generate_novel_sentences(
                sess, style_embedding, options.num_sentences_to_generate,
                num_labels,
                os.path.join(options.saved_model_path,
                             global_config.model_save_file))

        # First trim each sentence down to the length the decoder returned
        # (dropping the final position), then strip any <eos> tokens.
        trimmed_generated_sequences = [
            [index for index in sequence[:(length - 1)] if index != eos_index]
            for sequence, length in zip(generated_sequences,
                                        final_sequence_lengths)]
        generated_word_lists = \
            [data_processor.generate_words_from_indices(x, inverse_word_index)
             for x in trimmed_generated_sequences]
        generated_sentences = [" ".join(x) for x in generated_word_lists]

        output_file_path = \
            "output/{}-generation/generated_sentences_{}.txt".format(
                global_config.experiment_timestamp, label_index)
        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
        with open(output_file_path, 'w') as output_file:
            for sentence in generated_sentences:
                output_file.write(sentence + "\n")
        logger.info("Generated {} sentences of label {} at path {}".format(
            options.num_sentences_to_generate,
            index_to_label_map[label_index], output_file_path))
    sess.close()
    logger.info("Generation run complete")


def main(argv):
    """Parse command-line options and dispatch to train / transform / generate."""
    options = Options()

    parser = argparse.ArgumentParser()
    parser.add_argument("--logging-level", type=str, default="INFO")
    run_mode = parser.add_mutually_exclusive_group(required=True)
    run_mode.add_argument("--train-model", action="store_true", default=False)
    run_mode.add_argument("--transform-text", action="store_true", default=False)
    run_mode.add_argument("--generate-novel-text", action="store_true", default=False)
    # First pass only resolves the run mode; mode-specific arguments are
    # registered below and picked up by the second parse.
    parser.parse_known_args(args=argv, namespace=options)

    if options.train_model:
        parser.add_argument("--vocab-size", type=int, default=1000)
        parser.add_argument("--training-epochs", type=int, default=10)
        parser.add_argument("--text-file-path", type=str, required=True)
        parser.add_argument("--label-file-path", type=str, required=True)
        parser.add_argument("--validation-text-file-path", type=str, required=True)
        parser.add_argument("--validation-label-file-path", type=str, required=True)
        parser.add_argument("--training-embeddings-file-path", type=str)
        parser.add_argument("--validation-embeddings-file-path", type=str, required=True)
        parser.add_argument("--dump-embeddings", action="store_true", default=False)
        parser.add_argument("--classifier-saved-model-path", type=str, required=True)
    if options.transform_text:
        parser.add_argument("--saved-model-path", type=str, required=True)
        parser.add_argument("--evaluation-text-file-path", type=str, required=True)
        parser.add_argument("--evaluation-label-file-path", type=str, required=True)
    if options.generate_novel_text:
        parser.add_argument("--saved-model-path", type=str, required=True)
        parser.add_argument("--num-sentences-to-generate", type=int,
                            default=1000, required=True)
        # BUGFIX: default=None now means "generate for every label". The old
        # default of 1000 kept the filter below permanently active with a
        # label index no label ever has, so nothing was ever generated.
        parser.add_argument("--label-index", type=int, default=None,
                            required=False)
    parser.parse_known_args(args=argv, namespace=options)

    global logger
    logger = log_initializer.setup_custom_logger(global_config.logger_name,
                                                 options.logging_level)

    if not (options.train_model or options.transform_text
            or options.generate_novel_text):
        logger.info("Nothing to do. Exiting ...")
        sys.exit(0)

    # NOTE(review): --training-epochs is only registered in train mode;
    # presumably Options pre-declares the attribute for the other modes --
    # verify against the Options class.
    global_config.training_epochs = options.training_epochs
    logger.info("experiment_timestamp: {}".format(
        global_config.experiment_timestamp))

    if options.train_model:
        # Train and save model
        _train_model(options)
    elif options.transform_text:
        # Enforce a particular style embedding and regenerate text
        logger.info("Transforming text style ...")
        _transform_text(options)
    elif options.generate_novel_text:
        logger.info("Generating novel text")
        _generate_novel_text(options)
def run_validation(self, options, num_labels, validation_sequences,
                   validation_sequence_lengths, validation_labels,
                   validation_actual_word_lists, all_style_embeddings,
                   shuffled_one_hot_labels, inverse_word_index,
                   current_epoch, sess):
    """Run a full validation pass: transfer every validation sentence to each
    other style, then compute style-transfer, content-preservation and
    word-overlap scores and append the aggregates to the validation log.

    Args:
        options: parsed CLI options (reads validation_embeddings_file_path
            and classifier_saved_model_path).
        num_labels: number of style labels.
        validation_sequences / validation_sequence_lengths / validation_labels:
            padded validation inputs, their true lengths, and one-hot labels.
        validation_actual_word_lists: reference sentences as word lists.
        all_style_embeddings: per-training-example style embeddings collected
            during the epoch (parallel to shuffled_one_hot_labels).
        shuffled_one_hot_labels: one-hot labels aligned with all_style_embeddings.
        inverse_word_index: index -> word mapping for detokenization.
        current_epoch: epoch number (used for logging and batch bookkeeping).
        sess: live TensorFlow session.
    """
    logger.info("Running Validation {}:".format(
        current_epoch // global_config.validation_interval))

    glove_model = content_preservation.load_glove_model(
        options.validation_embeddings_file_path)

    validation_style_transfer_scores = list()
    validation_content_preservation_scores = list()
    validation_word_overlap_scores = list()
    for i in range(num_labels):
        logger.info("validating label {}".format(i))

        label_embeddings = list()
        validation_sequences_to_transfer = list()
        validation_labels_to_transfer = list()
        validation_sequence_lengths_to_transfer = list()
        # Collect the style embeddings of all training examples that carry
        # label i; their mean (below) becomes the target style embedding.
        for k in range(len(all_style_embeddings)):
            if shuffled_one_hot_labels[k].tolist().index(1) == i:
                label_embeddings.append(all_style_embeddings[k])
        # Only validation sentences NOT already of label i are transferred.
        for k in range(len(validation_sequences)):
            if validation_labels[k].tolist().index(1) != i:
                validation_sequences_to_transfer.append(
                    validation_sequences[k])
                validation_labels_to_transfer.append(validation_labels[k])
                validation_sequence_lengths_to_transfer.append(
                    validation_sequence_lengths[k])

        style_embedding = np.mean(np.asarray(label_embeddings), axis=0)

        # Ceiling division: one extra batch for any remainder.
        validation_batches = len(
            validation_sequences_to_transfer) // mconf.batch_size
        if len(validation_sequences_to_transfer) % mconf.batch_size:
            validation_batches += 1

        validation_generated_sequences = list()
        validation_generated_sequence_lengths = list()
        for val_batch_number in range(validation_batches):
            (start_index, end_index) = self.get_batch_indices(
                batch_number=val_batch_number,
                data_limit=len(validation_sequences_to_transfer))
            # Repeat the mean style embedding once per example in the batch.
            conditioning_embedding = np.tile(A=style_embedding,
                                             reps=(end_index - start_index, 1))

            # NOTE(review): the positional flags (True, False, 0, 0) presumably
            # select inference mode / disable training-only behavior in
            # run_batch -- confirm against its signature.
            [validation_generated_sequences_batch,
             validation_sequence_lengths_batch] = \
                self.run_batch(
                    sess, start_index, end_index,
                    [self.inference_output, self.final_sequence_lengths],
                    validation_sequences_to_transfer,
                    validation_labels_to_transfer,
                    validation_sequence_lengths_to_transfer,
                    conditioning_embedding, True, False, 0, 0, current_epoch)
            validation_generated_sequences.extend(
                validation_generated_sequences_batch)
            validation_generated_sequence_lengths.extend(
                validation_sequence_lengths_batch)

        # First trim each generated sentence down to the decoder-reported
        # length (dropping the final position), then strip any <eos> tokens.
        trimmed_generated_sequences = \
            [[index for index in sequence
              if index != global_config.predefined_word_index[global_config.eos_token]]
             for sequence in [x[:(y - 1)] for (x, y) in zip(
                 validation_generated_sequences,
                 validation_generated_sequence_lengths)]]

        generated_word_lists = \
            [data_processor.generate_words_from_indices(x, inverse_word_index)
             for x in trimmed_generated_sequences]

        generated_sentences = [" ".join(x) for x in generated_word_lists]
        output_file_path = "output/{}-training/validation_sentences_{}.txt".format(
            global_config.experiment_timestamp, i)
        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
        with open(output_file_path, 'w') as output_file:
            for sentence in generated_sentences:
                output_file.write(sentence + "\n")

        # Style-transfer accuracy is judged by an external pre-trained classifier.
        [style_transfer_score, confusion_matrix] = \
            style_transfer.get_style_transfer_score(
                options.classifier_saved_model_path, output_file_path,
                str(i), None)
        logger.debug(
            "style_transfer_score: {}".format(style_transfer_score))
        logger.debug("confusion_matrix:\n{}".format(confusion_matrix))

        content_preservation_score = \
            content_preservation.get_content_preservation_score(
                validation_actual_word_lists, generated_word_lists,
                glove_model)
        logger.debug("content_preservation_score: {}".format(
            content_preservation_score))

        word_overlap_score = content_preservation.get_word_overlap_score(
            validation_actual_word_lists, generated_word_lists)
        logger.debug("word_overlap_score: {}".format(word_overlap_score))

        validation_style_transfer_scores.append(style_transfer_score)
        validation_content_preservation_scores.append(
            content_preservation_score)
        validation_word_overlap_scores.append(word_overlap_score)

    # Aggregate the per-label scores with an unweighted mean.
    aggregate_style_transfer = np.mean(
        np.asarray(validation_style_transfer_scores))
    logger.info(
        "Aggregate Style Transfer: {}".format(aggregate_style_transfer))

    aggregate_content_preservation = np.mean(
        np.asarray(validation_content_preservation_scores))
    logger.info("Aggregate Content Preservation: {}".format(
        aggregate_content_preservation))

    aggregate_word_overlap = np.mean(
        np.asarray(validation_word_overlap_scores))
    logger.info(
        "Aggregate Word Overlap: {}".format(aggregate_word_overlap))

    # Append one JSON record per validation run so the file accumulates a
    # training-long history.
    with open(global_config.validation_scores_path,
              'a+') as validation_scores_file:
        validation_record = {
            "epoch": current_epoch,
            "style-transfer": aggregate_style_transfer,
            "content-preservation": aggregate_content_preservation,
            "word-overlap": aggregate_word_overlap
        }
        validation_scores_file.write(json.dumps(validation_record) + "\n")