# NOTE: standard-library and third-party imports below are inferred from usage;
# the project-local modules (data_processor, global_config, tf_session_helper,
# adversarial_autoencoder, log_initializer, mconf, TextCNN, Options, get_data,
# get_word_embeddings, execute_post_inference_operations) are assumed to be
# importable from this package -- adjust the paths to the actual layout.
import argparse
import datetime
import json
import os
import pickle
import sys

import numpy as np
import tensorflow as tf
from sklearn import metrics


def train_classifier_model(options):
    # Load data
    logger.info("Loading data...")

    [word_index, x, _, _, _] = \
        data_processor.get_text_sequences(
            options['text_file_path'], options['vocab_size'], global_config.classifier_vocab_save_path)

    x = np.asarray(x)

    [y, _] = data_processor.get_labels(options['label_file_path'], False)

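    # Shuffle examples and labels together so the dev split below is not
    # biased by the ordering of the input files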
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # Split train/dev: hold out the last 1% of the shuffled data for evaluation
    dev_sample_index = -1 * int(0.01 * float(len(y)))
    x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[
        dev_sample_index:]
    y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[
        dev_sample_index:]

    del x, y, x_shuffled, y_shuffled

    logger.info("Vocabulary Size: {:d}".format(global_config.vocab_size))
    logger.info("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))

    # Training
    sess = tf_session_helper.get_tensorflow_session()
    with sess.as_default():
        cnn = TextCNN(sequence_length=x_train.shape[1],
                      num_classes=y_train.shape[1],
                      vocab_size=options['vocab_size'],
                      embedding_size=128,
                      filter_sizes=[3, 4, 5],
                      num_filters=128,
                      l2_reg_lambda=0.0)

        # Define Training procedure
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars,
                                             global_step=global_step)
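        # Gradients are computed and applied separately (rather than calling
        # optimizer.minimize) so per-variable gradient summaries could be
        # attached; the merged gradient summary is left commented out below.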

        # Output directory for models and summaries
        out_dir = global_config.classifier_save_directory
        logger.info("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries
        # train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
        train_summary_op = tf.summary.merge([loss_summary, acc_summary])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir,
                                                     sess.graph)

        # Dev summaries
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=1)

        # Write vocabulary
        with open(global_config.classifier_vocab_save_path, 'w') as json_file:
            json.dump(word_index, json_file)
            logger.info("Saved vocabulary")

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """
            A single training step
            """
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: 0.5
            }
            _, step, summaries, loss, accuracy = sess.run([
                train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy
            ], feed_dict)
            logger.info("step {}: loss {:g}, acc {:g}".format(
                step, loss, accuracy))
            train_summary_writer.add_summary(summaries, step)

        def dev_step(x_batch, y_batch, writer=None):
            """
            Evaluates model on a dev set
            """
            feed_dict = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: 1.0
            }
            step, summaries, loss, accuracy = sess.run(
                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            logger.info("{}: step {}, loss {:g}, acc {:g}".format(
                time_str, step, loss, accuracy))
            if writer:
                writer.add_summary(summaries, step)

        # Generate batches
        batches = data_processor.batch_iter(list(zip(x_train, y_train)),
                                            mconf.batch_size,
                                            options['training_epochs'])
        # Training loop. For each batch...
        for batch in batches:
            x_batch, y_batch = zip(*batch)
            train_step(x_batch, y_batch)
            current_step = tf.train.global_step(sess, global_step)
            if current_step % 100 == 0:
                logger.info("\nEvaluation:")
                dev_step(x_dev, y_dev, writer=dev_summary_writer)
                logger.info("")
                path = saver.save(sess,
                                  checkpoint_prefix,
                                  global_step=current_step)
                logger.info("Saved model checkpoint to {}\n".format(path))


def main(argv):
    options = Options()

    parser = argparse.ArgumentParser()
    parser.add_argument("--logging-level", type=str, default="INFO")
    run_mode = parser.add_mutually_exclusive_group(required=True)
    run_mode.add_argument("--train-model", action="store_true", default=False)
    run_mode.add_argument("--transform-text",
                          action="store_true",
                          default=False)
    run_mode.add_argument("--generate-novel-text",
                          action="store_true",
                          default=False)

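    # First pass: parse only the run-mode flags so that mode-specific
    # arguments can be registered before the second pass below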
    parser.parse_known_args(args=argv, namespace=options)
    if options.train_model:
        parser.add_argument("--vocab-size", type=int, default=1000)
        parser.add_argument("--training-epochs", type=int, default=10)
        parser.add_argument("--text-file-path", type=str, required=True)
        parser.add_argument("--label-file-path", type=str, required=True)
        parser.add_argument("--validation-text-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--validation-label-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--training-embeddings-file-path", type=str)
        parser.add_argument("--validation-embeddings-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--dump-embeddings",
                            action="store_true",
                            default=False)
        parser.add_argument("--classifier-saved-model-path",
                            type=str,
                            required=True)
    if options.transform_text:
        parser.add_argument("--saved-model-path", type=str, required=True)
        parser.add_argument("--evaluation-text-file-path",
                            type=str,
                            required=True)
        parser.add_argument("--evaluation-label-file-path",
                            type=str,
                            required=True)
    if options.generate_novel_text:
        parser.add_argument("--saved-model-path", type=str, required=True)
        parser.add_argument("--num-sentences-to-generate",
                            type=int,
                            default=1000,
                            required=True)
        parser.add_argument("--label-index",
                            type=int,
                            default=1000,
                            required=False)

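    # Second pass: parse the full argument set, including the mode-specific
    # flags registered above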
    parser.parse_known_args(args=argv, namespace=options)

    global logger
    logger = log_initializer.setup_custom_logger(global_config.logger_name,
                                                 options.logging_level)

    if not (options.train_model or options.transform_text
            or options.generate_novel_text):
        logger.info("Nothing to do. Exiting ...")
        sys.exit(0)

    if options.train_model:
        global_config.training_epochs = options.training_epochs
    logger.info("experiment_timestamp: {}".format(
        global_config.experiment_timestamp))

    # Train and save model
    if options.train_model:
        os.makedirs(global_config.save_directory)
        with open(global_config.model_config_file_path,
                  'w') as model_config_file:
            json.dump(obj=mconf.__dict__, fp=model_config_file, indent=4)
        logger.info("Saved model config to {}".format(
            global_config.model_config_file_path))

        # Retrieve all data
        logger.info("Reading data ...")
        [
            word_index, padded_sequences, text_sequence_lengths,
            one_hot_labels, num_labels, text_tokenizer, inverse_word_index
        ] = get_data(options)
        data_size = padded_sequences.shape[0]

        encoder_embedding_matrix, decoder_embedding_matrix = \
            get_word_embeddings(options.training_embeddings_file_path, word_index)

        # Build model
        logger.info("Building model architecture ...")
        network = adversarial_autoencoder.AdversarialAutoencoder()
        network.build_model(word_index, encoder_embedding_matrix,
                            decoder_embedding_matrix, num_labels)

        logger.info("Training model ...")
        sess = tf_session_helper.get_tensorflow_session()

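        # Load the held-out validation set used for periodic evaluation
        # during training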
        [_, validation_actual_word_lists, validation_sequences, validation_sequence_lengths] = \
            data_processor.get_test_sequences(
                options.validation_text_file_path, text_tokenizer, word_index, inverse_word_index)
        [_, validation_labels] = \
            data_processor.get_test_labels(options.validation_label_file_path, global_config.save_directory)

        network.train(sess, data_size, padded_sequences, text_sequence_lengths,
                      one_hot_labels, num_labels, word_index,
                      encoder_embedding_matrix, decoder_embedding_matrix,
                      validation_sequences, validation_sequence_lengths,
                      validation_labels, inverse_word_index,
                      validation_actual_word_lists, options)
        sess.close()

        logger.info("Training complete!")

    elif options.transform_text:
        # Enforce a particular style embedding and regenerate text
        logger.info("Transforming text style ...")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.model_config_file),
                'r') as json_file:
            model_config_dict = json.load(json_file)
            mconf.init_from_dict(model_config_dict)
            logger.info("Restored model config from saved JSON")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.vocab_save_file), 'r') as json_file:
            word_index = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.index_to_label_dict_file),
                'r') as json_file:
            index_to_label_map = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.average_label_embeddings_file),
                'rb') as pickle_file:
            average_label_embeddings = pickle.load(pickle_file)

        global_config.vocab_size = len(word_index)

        num_labels = len(index_to_label_map)
        text_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=global_config.vocab_size,
            filters=global_config.tokenizer_filters)
        text_tokenizer.word_index = word_index

        inverse_word_index = {v: k for k, v in word_index.items()}
        [actual_sequences, _, padded_sequences, text_sequence_lengths] = \
            data_processor.get_test_sequences(
                options.evaluation_text_file_path, text_tokenizer, word_index, inverse_word_index)
        [label_sequences, _] = \
            data_processor.get_test_labels(options.evaluation_label_file_path, options.saved_model_path)

        logger.info("Building model architecture ...")
        network = adversarial_autoencoder.AdversarialAutoencoder()
        encoder_embedding_matrix, decoder_embedding_matrix = get_word_embeddings(
            None, word_index)
        network.build_model(word_index, encoder_embedding_matrix,
                            decoder_embedding_matrix, num_labels)

        sess = tf_session_helper.get_tensorflow_session()

        total_nll = 0
        for i in range(num_labels):
            logger.info("Style chosen: {}".format(i))

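            # Keep only sentences whose source label differs from the target
            # style i, so every retained sentence is actually transformed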
            filtered_actual_sequences = list()
            filtered_padded_sequences = list()
            filtered_text_sequence_lengths = list()
            for k in range(len(actual_sequences)):
                if label_sequences[k] != i:
                    filtered_actual_sequences.append(actual_sequences[k])
                    filtered_padded_sequences.append(padded_sequences[k])
                    filtered_text_sequence_lengths.append(
                        text_sequence_lengths[k])

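            # Condition the decoder on the average style embedding learned
            # for target label i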
            style_embedding = np.asarray(average_label_embeddings[i])
            [generated_sequences, final_sequence_lengths, _, _, _, cross_entropy_scores] = \
                network.transform_sentences(
                    sess, filtered_padded_sequences, filtered_text_sequence_lengths, style_embedding,
                    num_labels, os.path.join(options.saved_model_path, global_config.model_save_file))
            nll = -np.mean(a=cross_entropy_scores, axis=0)
            total_nll += nll
            logger.info("NLL: {}".format(nll))

            actual_word_lists = \
                [data_processor.generate_words_from_indices(x, inverse_word_index)
                 for x in filtered_actual_sequences]

            execute_post_inference_operations(
                actual_word_lists, generated_sequences, final_sequence_lengths,
                inverse_word_index, global_config.experiment_timestamp, i)

            logger.info("Generation complete for label {}".format(i))

        logger.info("Mean NLL: {}".format(total_nll / num_labels))

        logger.info("Predicting labels from latent spaces ...")
        _, _, overall_label_predictions, style_label_predictions, adversarial_label_predictions, _ = \
            network.transform_sentences(
                sess, padded_sequences, text_sequence_lengths, average_label_embeddings[0],
                num_labels, os.path.join(options.saved_model_path, global_config.model_save_file))

        # Write each set of label predictions to its own file
        prediction_sets = [
            ("overall_labels_prediction.txt", overall_label_predictions),
            ("style_labels_prediction.txt", style_label_predictions),
            ("adversarial_labels_prediction.txt", adversarial_label_predictions),
        ]
        for file_name, label_predictions in prediction_sets:
            output_file_path = "output/{}-inference/{}".format(
                global_config.experiment_timestamp, file_name)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
            with open(output_file_path, 'w') as output_file:
                for one_hot_label in label_predictions:
                    output_file.write("{}\n".format(
                        one_hot_label.tolist().index(1)))

        logger.info("Inference run complete")

        sess.close()

    elif options.generate_novel_text:
        logger.info("Generating novel text")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.model_config_file),
                'r') as json_file:
            model_config_dict = json.load(json_file)
            mconf.init_from_dict(model_config_dict)
            logger.info("Restored model config from saved JSON")

        with open(
                os.path.join(options.saved_model_path,
                             global_config.vocab_save_file), 'r') as json_file:
            word_index = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.index_to_label_dict_file),
                'r') as json_file:
            index_to_label_map = json.load(json_file)
        with open(
                os.path.join(options.saved_model_path,
                             global_config.average_label_embeddings_file),
                'rb') as pickle_file:
            average_label_embeddings = pickle.load(pickle_file)

        global_config.vocab_size = len(word_index)
        inverse_word_index = {v: k for k, v in word_index.items()}

        num_labels = len(index_to_label_map)
        text_tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=global_config.vocab_size,
            filters=global_config.tokenizer_filters)
        text_tokenizer.word_index = word_index
        data_processor.populate_word_blacklist(word_index)

        logger.info("Building model architecture ...")
        network = adversarial_autoencoder.AdversarialAutoencoder()
        encoder_embedding_matrix, decoder_embedding_matrix = get_word_embeddings(
            None, word_index)
        network.build_model(word_index, encoder_embedding_matrix,
                            decoder_embedding_matrix, num_labels)

        sess = tf_session_helper.get_tensorflow_session()

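        # index_to_label_map keys come from JSON, so they are strings; cast
        # to int before comparing against the integer --label-index flag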
        for label_index in index_to_label_map:
            if options.label_index is not None and \
                    int(label_index) != options.label_index:
                continue

            style_embedding = np.asarray(
                average_label_embeddings[int(label_index)])
            generated_sequences, final_sequence_lengths = \
                network.generate_novel_sentences(
                    sess, style_embedding, options.num_sentences_to_generate, num_labels,
                    os.path.join(options.saved_model_path, global_config.model_save_file))

            # First trim each generated sentence to the length the decoder
            # returned, then strip any <eos> tokens
            trimmed_generated_sequences = \
                [[index for index in sequence
                  if index != global_config.predefined_word_index[global_config.eos_token]]
                 for sequence in [x[:(y - 1)] for (x, y) in zip(generated_sequences, final_sequence_lengths)]]

            generated_word_lists = \
                [data_processor.generate_words_from_indices(x, inverse_word_index)
                 for x in trimmed_generated_sequences]

            generated_sentences = [" ".join(x) for x in generated_word_lists]
            output_file_path = "output/{}-generation/generated_sentences_{}.txt".format(
                global_config.experiment_timestamp, label_index)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
            with open(output_file_path, 'w') as output_file:
                for sentence in generated_sentences:
                    output_file.write(sentence + "\n")

            logger.info("Generated {} sentences of label {} at path {}".format(
                options.num_sentences_to_generate,
                index_to_label_map[label_index], output_file_path))

        sess.close()
        logger.info("Generation run complete")
def get_style_transfer_score(classifier_saved_model_path, text_file_path,
                             label):
    with open(
            os.path.join(classifier_saved_model_path,
                         global_config.vocab_save_file), 'r') as json_file:
        word_index = json.load(json_file)
    vocab_size = len(word_index)

    text_tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=vocab_size,
        filters=global_config.tokenizer_filters)
    text_tokenizer.word_index = word_index

    with open(text_file_path) as text_file:
        actual_sequences = text_tokenizer.texts_to_sequences(text_file)
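    # Replace out-of-vocabulary word indices with the <unk> token index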
    trimmed_sequences = [[
        x if x < vocab_size else word_index[global_config.unk_token]
        for x in sequence
    ] for sequence in actual_sequences]
    text_sequences = tf.keras.preprocessing.sequence.pad_sequences(
        trimmed_sequences,
        maxlen=global_config.max_sequence_length,
        padding='post',
        truncating='post',
        value=word_index[global_config.eos_token])

    x_test = np.asarray(text_sequences)
    y_test = np.asarray([int(label)] * len(text_sequences))

    checkpoint_file = tf.train.latest_checkpoint(
        os.path.join(classifier_saved_model_path, "checkpoints"))
    graph = tf.Graph()
    with graph.as_default():
        sess = tf_session_helper.get_tensorflow_session()
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            # input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]

            # Tensors we want to evaluate
            predictions = graph.get_operation_by_name(
                "output/predictions").outputs[0]

            # Generate batches for one epoch
            batches = data_processor.batch_iter(list(x_test),
                                                mconf.batch_size,
                                                1,
                                                shuffle=False)

            # Collect the predictions here
            all_predictions = []

            for x_test_batch in batches:
                batch_predictions = sess.run(predictions, {
                    input_x: x_test_batch,
                    dropout_keep_prob: 1.0
                })
                all_predictions = np.concatenate(
                    [all_predictions, batch_predictions])

        sess.close()

    # Compute accuracy and the confusion matrix against the target label
    if y_test is not None:
        correct_predictions = float(sum(all_predictions == y_test))
        accuracy = correct_predictions / float(len(y_test))
        # f1_score = metrics.f1_score(y_true=y_test, y_pred=all_predictions)
        confusion_matrix = metrics.confusion_matrix(y_true=y_test,
                                                    y_pred=all_predictions)
        return [accuracy, confusion_matrix]

    logger.info("Nothing to evaluate")
    return [0.0, None]
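
# Example (hypothetical paths; `label` is the target style the generated
# text is expected to carry):
#   accuracy, confusion = get_style_transfer_score(
#       "saved-models/classifier", "output/generated_sentences_1.txt", label=1)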