# Standard-library and third-party imports used below; project-local modules
# (Options, adversarial_autoencoder, bleu_scorer, content_preservation,
# data_processor, global_config, log_initializer, mconf, style_transfer,
# tf_session_helper) are assumed to be imported elsewhere in the original module.
import argparse
import json
import os
import pickle
import sys

import numpy as np
import tensorflow as tf


def execute_post_inference_operations(actual_word_lists, generated_sequences,
                                      final_sequence_lengths,
                                      inverse_word_index,
                                      timestamped_file_suffix, label):
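    """Trim and detokenize the generated sequences, score them against the
    reference sentences with corpus BLEU, and write both the generated and
    the actual sentences to per-label files under output/.
    """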
    try:
        logger.debug("Minimum generated sentence length: {}".format(
            min(final_sequence_lengths)))
    except ValueError:
        logger.warning("{} label has no sequences!".format(label))
        return

    # Trim each generated sentence to the length returned by the decoder,
    # then strip any <eos> tokens.
    eos_index = global_config.predefined_word_index[global_config.eos_token]
    trimmed_generated_sequences = \
        [[index for index in sequence[:(length - 1)] if index != eos_index]
         for sequence, length in zip(generated_sequences, final_sequence_lengths)]

    generated_word_lists = \
        [data_processor.generate_words_from_indices(x, inverse_word_index)
         for x in trimmed_generated_sequences]

    # Evaluate model scores
    bleu_scores = bleu_scorer.get_corpus_bleu_scores(
        [[x] for x in actual_word_lists], generated_word_lists)
    logger.info("bleu_scores: {}".format(bleu_scores))

    generated_sentences = [" ".join(x) for x in generated_word_lists]
    output_file_path = "output/{}-inference/generated_sentences_{}.txt".format(
        timestamped_file_suffix, label)
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
    with open(output_file_path, 'w') as output_file:
        for sentence in generated_sentences:
            output_file.write(sentence + "\n")

    actual_sentences = [" ".join(x) for x in actual_word_lists]
    output_file_path = "output/{}-inference/actual_sentences_{}.txt".format(
        timestamped_file_suffix, label)
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
    with open(output_file_path, 'w') as output_file:
        for sentence in actual_sentences:
            output_file.write(sentence + "\n")


def main(argv):
    options = Options()

    parser = argparse.ArgumentParser()
    parser.add_argument("--logging-level", type=str, default="INFO")
    run_mode = parser.add_mutually_exclusive_group(required=True)
    run_mode.add_argument("--train-model", action="store_true", default=False)
    run_mode.add_argument("--transform-text",
                          action="store_true",
                          default=False)
    run_mode.add_argument("--generate-novel-text",
                          action="store_true",
                          default=False)

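    # First pass: determine the run mode so that mode-specific arguments can
    # be registered before the second parse below.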
    parser.parse_known_args(args=argv, namespace=options)
    if options.train_model:
        parser.add_argument("--vocab-size", type=int, default=1000)
        parser.add_argument("--training-epochs", type=int, default=10)
        parser.add_argument("--text-file-path", type=str, required=True)
        parser.add_argument("--label-file-path", type=str, required=True)
        parser.add_argument("--validation-text-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--validation-label-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--training-embeddings-file-path", type=str)
        parser.add_argument("--validation-embeddings-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--dump-embeddings",
                            action="store_true",
                            default=False)
        parser.add_argument("--classifier-saved-model-path",
                            type=str,
                            required=True)
    if options.transform_text:
        parser.add_argument("--saved-model-path", type=str, required=True)
        parser.add_argument("--evaluation-text-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--evaluation-label-file-path",
                            type=str,
                            required=True)
    if options.generate_novel_text:
        parser.add_argument("--saved-model-path", type=str, required=True)
        parser.add_argument("--num-sentences-to-generate",
                            type=int,
                            default=1000,
                            required=True)
        parser.add_argument("--label-index",
                            type=int,
                            default=1000,
                            required=False)

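    # Second pass: parse the full argument set for the chosen run mode.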
    parser.parse_known_args(args=argv, namespace=options)

    global logger
    logger = log_initializer.setup_custom_logger(global_config.logger_name,
                                                 options.logging_level)

    if not (options.train_model or options.transform_text
            or options.generate_novel_text):
        logger.info("Nothing to do. Exiting ...")
        sys.exit(0)

    global_config.training_epochs = options.training_epochs
    logger.info("experiment_timestamp: {}".format(
        global_config.experiment_timestamp))

    # Train and save model
    if options.train_model:
        os.makedirs(global_config.save_directory)
        with open(global_config.model_config_file_path,
                  'w') as model_config_file:
            json.dump(obj=mconf.__dict__, fp=model_config_file, indent=4)
        logger.info("Saved model config to {}".format(
            global_config.model_config_file_path))

        # Retrieve all data
        logger.info("Reading data ...")
        [
            word_index, padded_sequences, text_sequence_lengths,
            one_hot_labels, num_labels, text_tokenizer, inverse_word_index
        ] = get_data(options)
        data_size = padded_sequences.shape[0]

        encoder_embedding_matrix, decoder_embedding_matrix = \
            get_word_embeddings(options.training_embeddings_file_path, word_index)

        # Build model
        logger.info("Building model architecture ...")
        network = adversarial_autoencoder.AdversarialAutoencoder()
        network.build_model(word_index, encoder_embedding_matrix,
                            decoder_embedding_matrix, num_labels)

        logger.info("Training model ...")
        sess = tf_session_helper.get_tensorflow_session()

        [_, validation_actual_word_lists, validation_sequences, validation_sequence_lengths] = \
            data_processor.get_test_sequences(
                options.validation_text_file_path, text_tokenizer, word_index, inverse_word_index)
        [_, validation_labels] = \
            data_processor.get_test_labels(options.validation_label_file_path, global_config.save_directory)

        network.train(sess, data_size, padded_sequences, text_sequence_lengths,
                      one_hot_labels, num_labels, word_index,
                      encoder_embedding_matrix, decoder_embedding_matrix,
                      validation_sequences, validation_sequence_lengths,
                      validation_labels, inverse_word_index,
                      validation_actual_word_lists, options)
        sess.close()

        logger.info("Training complete!")

    elif options.transform_text:
        # Enforce a particular style embedding and regenerate text
        logger.info("Transforming text style ...")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.model_config_file),
                'r') as json_file:
            model_config_dict = json.load(json_file)
            mconf.init_from_dict(model_config_dict)
            logger.info("Restored model config from saved JSON")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.vocab_save_file), 'r') as json_file:
            word_index = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.index_to_label_dict_file),
                'r') as json_file:
            index_to_label_map = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.average_label_embeddings_file),
                'rb') as pickle_file:
            average_label_embeddings = pickle.load(pickle_file)

        global_config.vocab_size = len(word_index)

        num_labels = len(index_to_label_map)
        text_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=global_config.vocab_size,
            filters=global_config.tokenizer_filters)
        text_tokenizer.word_index = word_index

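        # Invert the vocabulary to map token indices back to words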
        inverse_word_index = {v: k for k, v in word_index.items()}
        [actual_sequences, _, padded_sequences, text_sequence_lengths] = \
            data_processor.get_test_sequences(
                options.evaluation_text_file_path, text_tokenizer, word_index, inverse_word_index)
        [label_sequences, _] = \
            data_processor.get_test_labels(options.evaluation_label_file_path, options.saved_model_path)

        logger.info("Building model architecture ...")
        network = adversarial_autoencoder.AdversarialAutoencoder()
        encoder_embedding_matrix, decoder_embedding_matrix = get_word_embeddings(
            None, word_index)
        network.build_model(word_index, encoder_embedding_matrix,
                            decoder_embedding_matrix, num_labels)

        sess = tf_session_helper.get_tensorflow_session()

        total_nll = 0
        for i in range(num_labels):
            logger.info("Style chosen: {}".format(i))

            filtered_actual_sequences = list()
            filtered_padded_sequences = list()
            filtered_text_sequence_lengths = list()
            for k in range(len(actual_sequences)):
                if label_sequences[k] != i:
                    filtered_actual_sequences.append(actual_sequences[k])
                    filtered_padded_sequences.append(padded_sequences[k])
                    filtered_text_sequence_lengths.append(
                        text_sequence_lengths[k])

            style_embedding = np.asarray(average_label_embeddings[i])
            [generated_sequences, final_sequence_lengths, _, _, _, cross_entropy_scores] = \
                network.transform_sentences(
                    sess, filtered_padded_sequences, filtered_text_sequence_lengths, style_embedding,
                    num_labels, os.path.join(options.saved_model_path, global_config.model_save_file))
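            # Average the per-sentence cross-entropy scores and negate them
            # to report a negative log-likelihood for this label.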
            nll = -np.mean(a=cross_entropy_scores, axis=0)
            total_nll += nll
            logger.info("NLL: {}".format(nll))

            actual_word_lists = \
                [data_processor.generate_words_from_indices(x, inverse_word_index)
                 for x in filtered_actual_sequences]

            execute_post_inference_operations(
                actual_word_lists, generated_sequences, final_sequence_lengths,
                inverse_word_index, global_config.experiment_timestamp, i)

            logger.info("Generation complete for label {}".format(i))

        logger.info("Mean NLL: {}".format(total_nll / num_labels))

        logger.info("Predicting labels from latent spaces ...")
        _, _, overall_label_predictions, style_label_predictions, adversarial_label_predictions, _ = \
            network.transform_sentences(
                sess, padded_sequences, text_sequence_lengths, average_label_embeddings[0],
                num_labels, os.path.join(options.saved_model_path, global_config.model_save_file))

        # write label predictions to file
        output_file_path = "output/{}-inference/overall_labels_prediction.txt".format(
            global_config.experiment_timestamp)
        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
        with open(output_file_path, 'w') as output_file:
            for one_hot_label in overall_label_predictions:
                output_file.write("{}\n".format(
                    one_hot_label.tolist().index(1)))

        output_file_path = "output/{}-inference/style_labels_prediction.txt".format(
            global_config.experiment_timestamp)
        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
        with open(output_file_path, 'w') as output_file:
            for one_hot_label in style_label_predictions:
                output_file.write("{}\n".format(
                    one_hot_label.tolist().index(1)))

        output_file_path = "output/{}-inference/adversarial_labels_prediction.txt".format(
            global_config.experiment_timestamp)
        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
        with open(output_file_path, 'w') as output_file:
            for one_hot_label in adversarial_label_predictions:
                output_file.write("{}\n".format(
                    one_hot_label.tolist().index(1)))

        logger.info("Inference run complete")

        sess.close()

    elif options.generate_novel_text:
        logger.info("Generating novel text")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.model_config_file),
                'r') as json_file:
            model_config_dict = json.load(json_file)
            mconf.init_from_dict(model_config_dict)
            logger.info("Restored model config from saved JSON")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.vocab_save_file), 'r') as json_file:
            word_index = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.index_to_label_dict_file),
                'r') as json_file:
            index_to_label_map = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.average_label_embeddings_file),
                'rb') as pickle_file:
            average_label_embeddings = pickle.load(pickle_file)

        global_config.vocab_size = len(word_index)
        inverse_word_index = {v: k for k, v in word_index.items()}

        num_labels = len(index_to_label_map)
        text_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=global_config.vocab_size,
            filters=global_config.tokenizer_filters)
        text_tokenizer.word_index = word_index
        data_processor.populate_word_blacklist(word_index)

        logger.info("Building model architecture ...")
        network = adversarial_autoencoder.AdversarialAutoencoder()
        encoder_embedding_matrix, decoder_embedding_matrix = get_word_embeddings(
            None, word_index)
        network.build_model(word_index, encoder_embedding_matrix,
                            decoder_embedding_matrix, num_labels)

        sess = tf_session_helper.get_tensorflow_session()

        for label_index in index_to_label_map:
            # JSON keys are strings, so compare label indices as integers
            if options.label_index is not None and \
                    int(label_index) != options.label_index:
                continue

            style_embedding = np.asarray(
                average_label_embeddings[int(label_index)])
            generated_sequences, final_sequence_lengths = \
                network.generate_novel_sentences(
                    sess, style_embedding, options.num_sentences_to_generate, num_labels,
                    os.path.join(options.saved_model_path, global_config.model_save_file))

            # Trim each generated sentence to the length returned by the decoder,
            # then strip any <eos> tokens.
            eos_index = global_config.predefined_word_index[global_config.eos_token]
            trimmed_generated_sequences = \
                [[index for index in sequence[:(length - 1)] if index != eos_index]
                 for sequence, length in zip(generated_sequences, final_sequence_lengths)]

            generated_word_lists = \
                [data_processor.generate_words_from_indices(x, inverse_word_index)
                 for x in trimmed_generated_sequences]

            generated_sentences = [" ".join(x) for x in generated_word_lists]
            output_file_path = "output/{}-generation/generated_sentences_{}.txt".format(
                global_config.experiment_timestamp, label_index)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
            with open(output_file_path, 'w') as output_file:
                for sentence in generated_sentences:
                    output_file.write(sentence + "\n")

            logger.info("Generated {} sentences of label {} at path {}".format(
                options.num_sentences_to_generate,
                index_to_label_map[label_index], output_file_path))

        sess.close()
        logger.info("Generation run complete")
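

# Context (assumed): run_validation below is a method of the model class
# instantiated above, presumably adversarial_autoencoder.AdversarialAutoencoder.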
    def run_validation(self, options, num_labels, validation_sequences,
                       validation_sequence_lengths, validation_labels,
                       validation_actual_word_lists, all_style_embeddings,
                       shuffled_one_hot_labels, inverse_word_index,
                       current_epoch, sess):
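        """Transfer validation sentences into each target style, write the
        generated sentences to disk, and score them for style transfer,
        content preservation, and word overlap; per-label and aggregate
        results are logged and appended to the validation scores file.
        """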

        logger.info("Running Validation {}:".format(
            current_epoch // global_config.validation_interval))

        glove_model = content_preservation.load_glove_model(
            options.validation_embeddings_file_path)

        validation_style_transfer_scores = list()
        validation_content_preservation_scores = list()
        validation_word_overlap_scores = list()
        for i in range(num_labels):

            logger.info("validating label {}".format(i))

            label_embeddings = list()
            validation_sequences_to_transfer = list()
            validation_labels_to_transfer = list()
            validation_sequence_lengths_to_transfer = list()

            for k in range(len(all_style_embeddings)):
                if shuffled_one_hot_labels[k].tolist().index(1) == i:
                    label_embeddings.append(all_style_embeddings[k])

            for k in range(len(validation_sequences)):
                if validation_labels[k].tolist().index(1) != i:
                    validation_sequences_to_transfer.append(
                        validation_sequences[k])
                    validation_labels_to_transfer.append(validation_labels[k])
                    validation_sequence_lengths_to_transfer.append(
                        validation_sequence_lengths[k])

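            # Condition generation on the mean style embedding of label i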
            style_embedding = np.mean(np.asarray(label_embeddings), axis=0)

            validation_batches = len(
                validation_sequences_to_transfer) // mconf.batch_size
            if len(validation_sequences_to_transfer) % mconf.batch_size:
                validation_batches += 1

            validation_generated_sequences = list()
            validation_generated_sequence_lengths = list()
            for val_batch_number in range(validation_batches):
                (start_index, end_index) = self.get_batch_indices(
                    batch_number=val_batch_number,
                    data_limit=len(validation_sequences_to_transfer))

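                # Repeat the style embedding once per sentence in the batch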
                conditioning_embedding = np.tile(A=style_embedding,
                                                 reps=(end_index - start_index,
                                                       1))

                [validation_generated_sequences_batch, validation_sequence_lengths_batch] = \
                    self.run_batch(
                        sess, start_index, end_index,
                        [self.inference_output, self.final_sequence_lengths],
                        validation_sequences_to_transfer, validation_labels_to_transfer,
                        validation_sequence_lengths_to_transfer,
                        conditioning_embedding, True, False, 0, 0, current_epoch)
                validation_generated_sequences.extend(
                    validation_generated_sequences_batch)
                validation_generated_sequence_lengths.extend(
                    validation_sequence_lengths_batch)

            # Trim each generated sentence to the length returned by the decoder,
            # then strip any <eos> tokens.
            eos_index = global_config.predefined_word_index[global_config.eos_token]
            trimmed_generated_sequences = \
                [[index for index in sequence[:(length - 1)] if index != eos_index]
                 for sequence, length in zip(validation_generated_sequences,
                                             validation_generated_sequence_lengths)]

            generated_word_lists = \
                [data_processor.generate_words_from_indices(x, inverse_word_index)
                 for x in trimmed_generated_sequences]

            generated_sentences = [" ".join(x) for x in generated_word_lists]

            output_file_path = "output/{}-training/validation_sentences_{}.txt".format(
                global_config.experiment_timestamp, i)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
            with open(output_file_path, 'w') as output_file:
                for sentence in generated_sentences:
                    output_file.write(sentence + "\n")

            [style_transfer_score,
             confusion_matrix] = style_transfer.get_style_transfer_score(
                 options.classifier_saved_model_path, output_file_path, str(i),
                 None)
            logger.debug(
                "style_transfer_score: {}".format(style_transfer_score))
            logger.debug("confusion_matrix:\n{}".format(confusion_matrix))

            content_preservation_score = content_preservation.get_content_preservation_score(
                validation_actual_word_lists, generated_word_lists,
                glove_model)
            logger.debug("content_preservation_score: {}".format(
                content_preservation_score))

            word_overlap_score = content_preservation.get_word_overlap_score(
                validation_actual_word_lists, generated_word_lists)
            logger.debug("word_overlap_score: {}".format(word_overlap_score))

            validation_style_transfer_scores.append(style_transfer_score)
            validation_content_preservation_scores.append(
                content_preservation_score)
            validation_word_overlap_scores.append(word_overlap_score)

        aggregate_style_transfer = np.mean(
            np.asarray(validation_style_transfer_scores))
        logger.info(
            "Aggregate Style Transfer: {}".format(aggregate_style_transfer))

        aggregate_content_preservation = np.mean(
            np.asarray(validation_content_preservation_scores))
        logger.info("Aggregate Content Preservation: {}".format(
            aggregate_content_preservation))

        aggregate_word_overlap = np.mean(
            np.asarray(validation_word_overlap_scores))
        logger.info(
            "Aggregate Word Overlap: {}".format(aggregate_word_overlap))

        with open(global_config.validation_scores_path,
                  'a+') as validation_scores_file:
            validation_record = {
                "epoch": current_epoch,
                "style-transfer": aggregate_style_transfer,
                "content-preservation": aggregate_content_preservation,
                "word-overlap": aggregate_word_overlap
            }
            validation_scores_file.write(json.dumps(validation_record) + "\n")