def main():
    """Command-line entry point: train or predict with SiameseMatchingBiLSTM.

    In ``train`` mode, a fresh ``DataManager`` is fit on ``--train_file`` and
    used to index ``--val_file``; the fitted manager is pickled into the
    checkpoint directory and the model is trained. In ``predict`` mode, a
    previously pickled ``DataManager`` (``--dataindexer_load_path``) indexes
    ``--test_file``, the model is restored from ``--model_load_dir`` and the
    per-pair duplicate probabilities are written to a CSV in the log directory.

    The parsed arguments (plus ``word_vocab_size``) are dumped as JSON to
    ``<log_dir>/<model_name>/<zero-padded run_id>/<mode>params.json``.
    """
    project_dir = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)

    # Parse config arguments
    argparser = argparse.ArgumentParser(
        description=("Run the Siamese BiLSTM model with an added "
                     "matching layer for paraphrase detection."))
    argparser.add_argument("mode", type=str,
                           choices=["train", "predict"],
                           help=("One of {train|predict}, to "
                                 "indicate what you want the model to do. "
                                 "If you pick \"predict\", then you must also "
                                 "supply the path to a pretrained model and "
                                 "DataIndexer to load."))
    argparser.add_argument("--model_load_dir", type=str,
                           help=("The path to a directory with checkpoints to "
                                 "load for evaluation or prediction. The "
                                 "latest checkpoint will be loaded."))
    argparser.add_argument("--dataindexer_load_path", type=str,
                           help=("The path to the DataIndexer fit on the "
                                 "train data, so we can properly index the "
                                 "test data for evaluation or prediction."))
    argparser.add_argument("--train_file", type=str,
                           default=os.path.join(project_dir,
                                                "data/processed/quora/"
                                                "train_cleaned_train_split.csv"),
                           help="Path to a file to train on.")
    argparser.add_argument("--val_file", type=str,
                           default=os.path.join(project_dir,
                                                "data/processed/quora/"
                                                "train_cleaned_val_split.csv"),
                           help="Path to a file to monitor validation acc. on.")
    argparser.add_argument("--test_file", type=str,
                           default=os.path.join(project_dir,
                                                "data/processed/quora/"
                                                "test_final.csv"),
                           help="Path to a file to predict on.")
    argparser.add_argument("--batch_size", type=int, default=128,
                           help="Number of instances per batch.")
    argparser.add_argument("--num_epochs", type=int, default=10,
                           help=("Number of epochs to perform in "
                                 "training."))
    argparser.add_argument("--early_stopping_patience", type=int, default=0,
                           help=("number of epochs with no validation "
                                 "accuracy improvement after which training "
                                 "will be stopped"))
    argparser.add_argument("--num_sentence_words", type=int, default=30,
                           help=("The maximum length of a sentence. Longer "
                                 "sentences will be truncated, and shorter "
                                 "ones will be padded."))
    argparser.add_argument("--word_embedding_dim", type=int, default=300,
                           help="Dimensionality of the word embedding layer")
    argparser.add_argument("--pretrained_embeddings_file_path", type=str,
                           help="Path to a file with pretrained embeddings.",
                           default=os.path.join(project_dir,
                                                "data/external/",
                                                "glove.6B.300d.txt"))
    argparser.add_argument("--fine_tune_embeddings", action="store_true",
                           help=("Whether to train the embedding layer "
                                 "(if True), or keep it fixed (False)."))
    argparser.add_argument("--rnn_hidden_size", type=int, default=256,
                           help=("The output dimension of the RNN."))
    argparser.add_argument("--share_encoder_weights", action="store_true",
                           help=("Whether to use the same encoder on both "
                                 "input sentences (thus sharing weights), "
                                 "or a different one for each sentence"))
    argparser.add_argument("--output_keep_prob", type=float, default=1.0,
                           help=("The proportion of RNN outputs to keep, "
                                 "where the rest are dropped out."))
    argparser.add_argument("--log_period", type=int, default=10,
                           help=("Number of steps between each summary "
                                 "op evaluation."))
    argparser.add_argument("--val_period", type=int, default=250,
                           help=("Number of steps between each evaluation of "
                                 "validation performance."))
    argparser.add_argument("--log_dir", type=str,
                           default=os.path.join(project_dir,
                                                "logs/"),
                           help=("Directory to save logs to."))
    argparser.add_argument("--save_period", type=int, default=250,
                           help=("Number of steps between each "
                                 "model checkpoint"))
    argparser.add_argument("--save_dir", type=str,
                           default=os.path.join(project_dir,
                                                "models/"),
                           help=("Directory to save model checkpoints to."))
    argparser.add_argument("--run_id", type=str, required=True,
                           help=("Identifying run ID for this run. If "
                                 "predicting, you probably want this "
                                 "to be the same as the train run_id"))
    argparser.add_argument("--model_name", type=str, required=True,
                           help=("Identifying model name for this run. If "
                                 "predicting, you probably want this "
                                 "to be the same as the train model_name"))
    argparser.add_argument("--reweight_predictions_for_kaggle", action="store_true",
                           help=("Only relevant when predicting. Whether to "
                                 "reweight the prediction probabilities to "
                                 "account for class proportion discrepancy "
                                 "between train and test."))

    config = argparser.parse_args()

    model_name = config.model_name
    run_id = config.run_id
    mode = config.mode

    # Predicting needs both a fitted DataManager and trained checkpoints.
    # Fail fast with a usage message (argparse exits with status 2) instead
    # of crashing later with an opaque TypeError from open(None).
    if mode == "predict":
        missing_flags = [flag for flag, value in
                         (("--model_load_dir", config.model_load_dir),
                          ("--dataindexer_load_path",
                           config.dataindexer_load_path))
                         if not value]
        if missing_flags:
            argparser.error("predict mode requires {}".format(
                " and ".join(missing_flags)))

    # Get the data.
    batch_size = config.batch_size
    if mode == "train":
        # Read the train data from a file, and use it to index the
        # validation data
        data_manager = DataManager(STSInstance)
        num_sentence_words = config.num_sentence_words
        get_train_data_gen, train_data_size = data_manager.get_train_data_from_file(
            [config.train_file], max_lengths={"num_sentence_words": num_sentence_words})
        get_val_data_gen, val_data_size = data_manager.get_validation_data_from_file(
            [config.val_file], max_lengths={"num_sentence_words": num_sentence_words})
    else:
        # Load the fitted DataManager, and use it to index the test data.
        # SECURITY: unpickling can execute arbitrary code; only point
        # --dataindexer_load_path at files produced by a trusted train run.
        logger.info("Loading pickled DataManager from {}".format(
            config.dataindexer_load_path))
        with open(config.dataindexer_load_path, "rb") as data_manager_file:
            data_manager = pickle.load(data_manager_file)
        get_test_data_gen, test_data_size = data_manager.get_test_data_from_file(
            [config.test_file])

    vars(config)["word_vocab_size"] = data_manager.data_indexer.get_vocab_size()

    # Log the run parameters. This happens before the embedding matrix is
    # added to the config, so only JSON-serializable values are dumped.
    log_dir = config.log_dir
    log_path = os.path.join(log_dir, model_name, run_id.zfill(2))
    logger.info("Writing logs to {}".format(log_path))
    if not os.path.exists(log_path):
        logger.info("log path {} does not exist, "
                    "creating it".format(log_path))
        os.makedirs(log_path)
    params_path = os.path.join(log_path, mode + "params.json")
    logger.info("Writing params to {}".format(params_path))
    with open(params_path, 'w') as params_file:
        json.dump(vars(config), params_file, indent=4)

    # Get the embeddings.
    embedding_manager = EmbeddingManager(data_manager.data_indexer)
    embedding_matrix = embedding_manager.get_embedding_matrix(
        config.word_embedding_dim,
        config.pretrained_embeddings_file_path)
    vars(config)["word_embedding_matrix"] = embedding_matrix

    # Initialize the model.
    model = SiameseMatchingBiLSTM(vars(config))
    model.build_graph()

    if mode == "train":
        # Train the model.
        num_epochs = config.num_epochs
        num_train_steps_per_epoch = int(math.ceil(train_data_size / batch_size))
        num_val_steps = int(math.ceil(val_data_size / batch_size))
        log_period = config.log_period
        val_period = config.val_period

        save_period = config.save_period
        save_dir = os.path.join(config.save_dir, model_name, run_id.zfill(2) + "/")
        save_path = os.path.join(save_dir, model_name + "-" + run_id.zfill(2))

        logger.info("Checkpoints will be written to {}".format(save_dir))
        if not os.path.exists(save_dir):
            logger.info("save path {} does not exist, "
                        "creating it".format(save_dir))
            os.makedirs(save_dir)

        # Persist the fitted DataManager so predict mode can index test data
        # with the same vocabulary. The context manager guarantees the pickle
        # is flushed and the handle closed before training starts.
        logger.info("Saving fitted DataManager to {}".format(save_dir))
        data_manager_pickle_name = "{}-{}-DataManager.pkl".format(model_name,
                                                                  run_id.zfill(2))
        with open(os.path.join(save_dir, data_manager_pickle_name),
                  "wb") as data_manager_file:
            pickle.dump(data_manager, data_manager_file)

        patience = config.early_stopping_patience
        model.train(get_train_instance_generator=get_train_data_gen,
                    get_val_instance_generator=get_val_data_gen,
                    batch_size=batch_size,
                    num_train_steps_per_epoch=num_train_steps_per_epoch,
                    num_epochs=num_epochs,
                    num_val_steps=num_val_steps,
                    save_path=save_path,
                    log_path=log_path,
                    log_period=log_period,
                    val_period=val_period,
                    save_period=save_period,
                    patience=patience)
    else:
        # Predict with the model
        model_load_dir = config.model_load_dir
        num_test_steps = int(math.ceil(test_data_size / batch_size))
        # Numpy array of shape (num_test_examples, 2)
        raw_predictions = model.predict(get_test_instance_generator=get_test_data_gen,
                                        model_load_dir=model_load_dir,
                                        batch_size=batch_size,
                                        num_test_steps=num_test_steps)

        # Remove the first column, so we're left with just the probabilities
        # that a question is a duplicate.
        is_duplicate_probabilities = np.delete(raw_predictions, 0, 1)

        # The class balance between kaggle train and test seems different.
        # This edits prediction probability to account for the discrepancy.
        # See: https://www.kaggle.com/c/quora-question-pairs/discussion/31179
        # NOTE(review): 0.165 and 0.37 appear to be the assumed positive-class
        # rates of the test and train sets respectively — confirm against
        # the linked discussion before changing them.
        if config.reweight_predictions_for_kaggle:
            positive_weight = 0.165 / 0.37
            negative_weight = (1 - 0.165) / (1 - 0.37)
            is_duplicate_probabilities = ((positive_weight * is_duplicate_probabilities) /
                                          (positive_weight * is_duplicate_probabilities +
                                           negative_weight *
                                           (1 - is_duplicate_probabilities)))

        # Write the predictions to an output submission file
        output_predictions_path = os.path.join(log_path, model_name + "-" +
                                               run_id.zfill(2) +
                                               "-output_predictions.csv")
        logger.info("Writing predictions to {}".format(output_predictions_path))
        is_duplicate_df = pd.DataFrame(is_duplicate_probabilities)
        is_duplicate_df.to_csv(output_predictions_path, index_label="test_id",
                               header=["is_duplicate"])
# Example No. 2
# 0
class TestEmbeddingManager(DuplicateTestCase):
    """Tests for ``EmbeddingManager`` construction and ``get_embedding_matrix``.

    Covers the three embedding sources (random, a vectors file, an in-memory
    dict), reproducibility across calls, and input validation.
    """

    @overrides
    def setUp(self):
        super(TestEmbeddingManager, self).setUp()
        # Presumably writes the 3-dimensional vectors file at
        # self.VECTORS_FILE used by the tests below (defined in the base
        # test case).
        self.write_vector_file()
        self.data_indexer = DataIndexer()
        self.data_indexer.add_word_to_index("word1")
        self.data_indexer.add_word_to_index("word2")
        # Mark the indexer as fit so EmbeddingManager accepts it (an
        # unfitted indexer raises ValueError; see the errors test).
        self.data_indexer.is_fit = True
        self.embedding_dict = {"word1": np.array([5.1, 7.2, -0.2]),
                               "word2": np.array([0.8, 0.1, 0.9])}
        self.embedding_manager = EmbeddingManager(self.data_indexer)

    # The assertions below read rows 2 and 3 for "word1" and "word2", so the
    # indexer presumably reserves indices 0 and 1 (e.g. padding / OOV).

    def test_get_embedding_matrix_reads_data_file(self):
        embed_mat = self.embedding_manager.get_embedding_matrix(
            3,
            pretrained_embeddings_file_path=self.VECTORS_FILE)
        assert_allclose(embed_mat[2], np.array([0.0, 1.1, 0.2]))
        assert_allclose(embed_mat[3], np.array([0.1, 0.4, -4.0]))

    def test_get_embedding_matrix_reads_dict(self):
        embed_mat = self.embedding_manager.get_embedding_matrix(
            3,
            pretrained_embeddings_dict=self.embedding_dict)
        assert_allclose(embed_mat[2], np.array([5.1, 7.2, -0.2]))
        assert_allclose(embed_mat[3], np.array([0.8, 0.1, 0.9]))

    def test_get_embedding_matrix_dict_overrides_file(self):
        # When both sources are given, the dict's vectors win.
        embed_mat = self.embedding_manager.get_embedding_matrix(
            3,
            pretrained_embeddings_file_path=self.VECTORS_FILE,
            pretrained_embeddings_dict=self.embedding_dict)
        assert_allclose(embed_mat[2], np.array([5.1, 7.2, -0.2]))
        assert_allclose(embed_mat[3], np.array([0.8, 0.1, 0.9]))

    def test_get_embedding_matrix_reproducible(self):
        # Repeated calls with identical arguments must return identical
        # matrices, including for the randomly-initialized case.
        embed_mat_1_random = self.embedding_manager.get_embedding_matrix(100)
        embed_mat_2_random = self.embedding_manager.get_embedding_matrix(100)
        assert_allclose(embed_mat_1_random, embed_mat_2_random)

        embed_mat_1_file = self.embedding_manager.get_embedding_matrix(
            3,
            pretrained_embeddings_file_path=self.VECTORS_FILE)
        embed_mat_2_file = self.embedding_manager.get_embedding_matrix(
            3,
            pretrained_embeddings_file_path=self.VECTORS_FILE)
        assert_allclose(embed_mat_1_file, embed_mat_2_file)

        embed_mat_1_dict = self.embedding_manager.get_embedding_matrix(
            3,
            pretrained_embeddings_dict=self.embedding_dict)
        embed_mat_2_dict = self.embedding_manager.get_embedding_matrix(
            3,
            pretrained_embeddings_dict=self.embedding_dict)
        assert_allclose(embed_mat_1_dict, embed_mat_2_dict)

    def test_embedding_manager_errors(self):
        # Every fixture is built OUTSIDE its assertRaises context: otherwise
        # a ValueError raised while setting up a case would make that case
        # pass without ever exercising EmbeddingManager.

        # Constructing from an indexer that was never fit.
        unfitted_data_indexer = DataIndexer()
        with self.assertRaises(ValueError):
            EmbeddingManager(unfitted_data_indexer)

        # Random-matrix shapes that are not 2-dimensional.
        with self.assertRaises(ValueError):
            EmbeddingManager.initialize_random_matrix((19,))
        with self.assertRaises(ValueError):
            EmbeddingManager.initialize_random_matrix((19, 100, 100))

        # Embedding dimension given as a non-int.
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(5.0)
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix("5")

        # Wrong container types for the pretrained sources.
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(
                5,
                pretrained_embeddings_file_path=["some_path"])
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(
                5,
                pretrained_embeddings_dict=["list", [0.1, 0.2]])

        # Requested dimension (5) disagrees with the 3-dimensional vectors
        # in the file / dict.
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(
                5,
                pretrained_embeddings_file_path=self.VECTORS_FILE)
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(
                5,
                pretrained_embeddings_dict=self.embedding_dict)

        # Dict whose vectors have inconsistent lengths.
        bad_dict = {"word1": np.array([0.1, 0.2]),
                    "word2": np.array([0.3, 0.4, 0.5])}
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(
                5,
                pretrained_embeddings_dict=bad_dict)

        # Vectors file whose lines have inconsistent lengths.
        bad_vectors_path = self.TEST_DIR + 'bad_vectors_file'
        with codecs.open(bad_vectors_path, 'w', 'utf-8') as vectors_file:
            vectors_file.write("word1 0.0 1.1 0.2\n")
            vectors_file.write("word2 0.1 0.4\n")
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(
                3,
                pretrained_embeddings_file_path=bad_vectors_path)

        # Same, but the file's first line is 1-dimensional, so it also
        # disagrees with the requested dimension of 3.
        with codecs.open(bad_vectors_path, 'w', 'utf-8') as vectors_file:
            vectors_file.write("word0 0.0\n")
            vectors_file.write("word1 0.0 1.1 0.2\n")
            vectors_file.write("word2 0.1 0.4\n")
        with self.assertRaises(ValueError):
            self.embedding_manager.get_embedding_matrix(
                3,
                pretrained_embeddings_file_path=bad_vectors_path)
class TestSiameseBiLSTM(DuplicateTestCase):
    """Smoke tests: SiameseBiLSTM trains, checkpoints, reloads and predicts.

    Each test tweaks ``self.config_dict`` and then runs the shared
    train → reload → predict round trip in ``_train_then_predict``.
    """

    @overrides
    def setUp(self):
        super(TestSiameseBiLSTM, self).setUp()
        self.write_duplicate_questions_train_file()
        self.write_duplicate_questions_validation_file()
        self.write_duplicate_questions_test_file()
        self.data_manager = DataManager(STSInstance)
        self.batch_size = 2
        self.get_train_gen, self.train_size = self.data_manager.get_train_data_from_file(
            [self.TRAIN_FILE])
        self.get_val_gen, self.val_size = self.data_manager.get_validation_data_from_file(
            [self.VALIDATION_FILE])
        self.get_test_gen, self.test_size = self.data_manager.get_test_data_from_file(
            [self.TEST_FILE])

        # Tiny dimensions keep the graph small so the tests run quickly.
        self.embedding_manager = EmbeddingManager(self.data_manager.data_indexer)
        self.word_embedding_dim = 5
        self.embedding_matrix = self.embedding_manager.get_embedding_matrix(
            self.word_embedding_dim)
        self.rnn_hidden_size = 6
        self.rnn_output_mode = "last"
        self.output_keep_prob = 1.0
        self.share_encoder_weights = True
        self.config_dict = {
            "mode": "train",
            "word_vocab_size": self.data_manager.data_indexer.get_vocab_size(),
            "word_embedding_dim": self.word_embedding_dim,
            "fine_tune_embeddings": False,
            "word_embedding_matrix": self.embedding_matrix,
            "rnn_hidden_size": self.rnn_hidden_size,
            "rnn_output_mode": self.rnn_output_mode,
            "output_keep_prob": self.output_keep_prob,
            "share_encoder_weights": self.share_encoder_weights
        }
        self.num_train_steps_per_epoch = int(math.ceil(self.train_size / self.batch_size))
        self.num_val_steps = int(math.ceil(self.val_size / self.batch_size))
        self.num_test_steps = int(math.ceil(self.test_size / self.batch_size))

    def _train_then_predict(self):
        """Train with ``self.config_dict``, then reload the checkpoint and predict.

        Mutates ``self.config_dict`` (switches ``mode`` to ``"test"`` and
        removes ``word_embedding_matrix``); setUp rebuilds it per test.
        """
        # Initialize and train the model for two epochs, checkpointing into
        # the test directory.
        model = SiameseBiLSTM(self.config_dict)
        model.build_graph()
        model.train(get_train_instance_generator=self.get_train_gen,
                    get_val_instance_generator=self.get_val_gen,
                    batch_size=self.batch_size,
                    num_train_steps_per_epoch=self.num_train_steps_per_epoch,
                    num_epochs=2,
                    num_val_steps=self.num_val_steps,
                    save_path=self.TEST_DIR,
                    log_path=self.TEST_DIR,
                    log_period=2,
                    val_period=2,
                    save_period=2,
                    patience=0)

        # Start from an empty graph so the reloaded model does not collide
        # with the variables of the model that was just trained.
        tf.reset_default_graph()

        # Load and predict with the model. No initial embedding matrix is
        # passed; presumably the embeddings are restored from the checkpoint.
        self.config_dict["mode"] = "test"
        del self.config_dict["word_embedding_matrix"]
        loaded_model = SiameseBiLSTM(self.config_dict)
        loaded_model.build_graph()
        loaded_model.predict(get_test_instance_generator=self.get_test_gen,
                             model_load_dir=self.TEST_DIR,
                             batch_size=self.batch_size,
                             num_test_steps=self.num_test_steps)

    def test_default_does_not_crash(self):
        self._train_then_predict()

    def test_mean_pool_does_not_crash(self):
        self.config_dict["rnn_output_mode"] = "mean_pool"
        self._train_then_predict()

    def test_non_sharing_encoders_does_not_crash(self):
        self.config_dict["share_encoder_weights"] = False
        self._train_then_predict()
# Example No. 4
# 0
def main():
    project_dir = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)

    # Parse config arguments
    argparser = argparse.ArgumentParser(
        description=("Run a baseline Siamese BiLSTM model "
                     "for paraphrase identification."))
    argparser.add_argument("mode",
                           type=str,
                           choices=["train", "predict"],
                           help=("One of {train|predict}, to "
                                 "indicate what you want the model to do. "
                                 "If you pick \"predict\", then you must also "
                                 "supply the path to a pretrained model and "
                                 "DataIndexer to load."))
    argparser.add_argument("--model_load_dir",
                           type=str,
                           help=("The path to a directory with checkpoints to "
                                 "load for evaluation or prediction. The "
                                 "latest checkpoint will be loaded."))
    argparser.add_argument("--dataindexer_load_path",
                           type=str,
                           help=("The path to the dataindexer fit on the "
                                 "train data, so we can properly index the "
                                 "test data for evaluation or prediction."))
    argparser.add_argument("--train_file",
                           type=str,
                           default=os.path.join(
                               project_dir, "data/processed/quora/"
                               "train_cleaned_train_split.csv"),
                           help="Path to a file to train on.")
    argparser.add_argument(
        "--val_file",
        type=str,
        default=os.path.join(
            project_dir, "data/processed/quora/"
            "train_cleaned_val_split.csv"),
        help="Path to a file to monitor validation acc. on.")
    argparser.add_argument("--test_file",
                           type=str,
                           default=os.path.join(
                               project_dir, "data/processed/quora/"
                               "test_final.csv"))
    argparser.add_argument("--batch_size",
                           type=int,
                           default=128,
                           help="Number of instances per batch.")
    argparser.add_argument("--num_epochs",
                           type=int,
                           default=10,
                           help=("Number of epochs to perform in "
                                 "training."))
    argparser.add_argument("--early_stopping_patience",
                           type=int,
                           default=0,
                           help=("number of epochs with no validation "
                                 "accuracy improvement after which training "
                                 "will be stopped"))
    argparser.add_argument("--num_sentence_words",
                           type=int,
                           default=30,
                           help=("The maximum length of a sentence. Longer "
                                 "sentences will be truncated, and shorter "
                                 "ones will be padded."))
    argparser.add_argument("--word_embedding_dim",
                           type=int,
                           default=300,
                           help="Dimensionality of the word embedding layer")
    argparser.add_argument("--pretrained_embeddings_file_path",
                           type=str,
                           help="Path to a file with pretrained embeddings.",
                           default=os.path.join(project_dir, "data/external/",
                                                "glove.6B.300d.txt"))
    argparser.add_argument("--fine_tune_embeddings",
                           action="store_true",
                           help=("Whether to train the embedding layer "
                                 "(if True), or keep it fixed (False)."))
    argparser.add_argument("--rnn_hidden_size",
                           type=int,
                           default=256,
                           help=("The output dimension of the RNN."))
    argparser.add_argument("--share_encoder_weights",
                           action="store_true",
                           help=("Whether to use the same encoder on both "
                                 "input sentences (thus sharing weights), "
                                 "or a different one for each sentence"))
    argparser.add_argument("--rnn_output_mode",
                           type=str,
                           default="last",
                           choices=["mean_pool", "last"],
                           help=("How to calculate the final sentence "
                                 "representation from the RNN outputs. "
                                 "\"mean_pool\" indicates that the outputs "
                                 "will be averaged (with respect to padding), "
                                 "and \"last\" indicates that the last "
                                 "relevant output will be used as the "
                                 "sentence representation."))
    argparser.add_argument("--output_keep_prob",
                           type=float,
                           default=1.0,
                           help=("The proportion of RNN outputs to keep, "
                                 "where the rest are dropped out."))
    argparser.add_argument("--log_period",
                           type=int,
                           default=10,
                           help=("Number of steps between each summary "
                                 "op evaluation."))
    argparser.add_argument("--val_period",
                           type=int,
                           default=250,
                           help=("Number of steps between each evaluation of "
                                 "validation performance."))
    argparser.add_argument("--log_dir",
                           type=str,
                           default=os.path.join(project_dir, "logs/"),
                           help=("Directory to save logs to."))
    argparser.add_argument("--save_period",
                           type=int,
                           default=250,
                           help=("Number of steps between each "
                                 "model checkpoint"))
    argparser.add_argument("--save_dir",
                           type=str,
                           default=os.path.join(project_dir, "models/"),
                           help=("Directory to save model checkpoints to."))
    argparser.add_argument("--run_id",
                           type=str,
                           required=True,
                           help=("Identifying run ID for this run. If "
                                 "predicting, you probably want this "
                                 "to be the same as the train run_id"))
    argparser.add_argument("--model_name",
                           type=str,
                           required=True,
                           help=("Identifying model name for this run. If"
                                 "predicting, you probably want this "
                                 "to be the same as the train run_id"))
    argparser.add_argument("--reweight_predictions_for_kaggle",
                           action="store_true",
                           help=("Only relevant when predicting. Whether to"
                                 "reweight the prediction probabilities to "
                                 "account for class proportion discrepancy "
                                 "between train and test."))

    config = argparser.parse_args()

    model_name = config.model_name
    run_id = config.run_id
    mode = config.mode

    # Get the data.
    batch_size = config.batch_size
    if mode == "train":
        # Read the train data from a file, and use it to index the validation data
        data_manager = DataManager(STSInstance)
        num_sentence_words = config.num_sentence_words
        get_train_data_gen, train_data_size = data_manager.get_train_data_from_file(
            [config.train_file],
            max_lengths={"num_sentence_words": num_sentence_words})
        get_val_data_gen, val_data_size = data_manager.get_validation_data_from_file(
            [config.val_file],
            max_lengths={"num_sentence_words": num_sentence_words})
    else:
        # Load the fitted DataManager, and use it to index the test data
        logger.info("Loading pickled DataManager "
                    "from {}".format(config.dataindexer_load_path))
        data_manager = pickle.load(open(config.dataindexer_load_path, "rb"))
        test_data_gen, test_data_size = data_manager.get_test_data_from_file(
            [config.test_file])

    vars(config)["word_vocab_size"] = data_manager.data_indexer.get_vocab_size(
    )

    # Log the run parameters.
    log_dir = config.log_dir
    log_path = os.path.join(log_dir, model_name, run_id.zfill(2))
    logger.info("Writing logs to {}".format(log_path))
    if not os.path.exists(log_path):
        logger.info("log path {} does not exist, "
                    "creating it".format(log_path))
        os.makedirs(log_path)
    params_path = os.path.join(log_path, mode + "params.json")
    logger.info("Writing params to {}".format(params_path))
    with open(params_path, 'w') as params_file:
        json.dump(vars(config), params_file, indent=4)

    # Get the embeddings.
    embedding_manager = EmbeddingManager(data_manager.data_indexer)
    embedding_matrix = embedding_manager.get_embedding_matrix(
        config.word_embedding_dim, config.pretrained_embeddings_file_path)
    vars(config)["word_embedding_matrix"] = embedding_matrix

    # Initialize the model.
    model = SiameseBiLSTM(vars(config))
    model.build_graph()

    if mode == "train":
        # Train the model.
        num_epochs = config.num_epochs
        num_train_steps_per_epoch = int(math.ceil(train_data_size /
                                                  batch_size))
        num_val_steps = int(math.ceil(val_data_size / batch_size))
        log_period = config.log_period
        val_period = config.val_period

        save_period = config.save_period
        save_dir = os.path.join(config.save_dir, model_name,
                                run_id.zfill(2) + "/")
        save_path = os.path.join(save_dir, model_name + "-" + run_id.zfill(2))

        logger.info("Checkpoints will be written to {}".format(save_dir))
        if not os.path.exists(save_dir):
            logger.info("save path {} does not exist, "
                        "creating it".format(save_dir))
            os.makedirs(save_dir)

        logger.info("Saving fitted DataManager to {}".format(save_dir))
        data_manager_pickle_name = "{}-{}-DataManager.pkl".format(
            model_name, run_id.zfill(2))
        pickle.dump(
            data_manager,
            open(os.path.join(save_dir, data_manager_pickle_name), "wb"))

        patience = config.early_stopping_patience
        model.train(get_train_instance_generator=get_train_data_gen,
                    get_val_instance_generator=get_val_data_gen,
                    batch_size=batch_size,
                    num_train_steps_per_epoch=num_train_steps_per_epoch,
                    num_epochs=num_epochs,
                    num_val_steps=num_val_steps,
                    save_path=save_path,
                    log_path=log_path,
                    log_period=log_period,
                    val_period=val_period,
                    save_period=save_period,
                    patience=patience)
    else:
        # Predict with the model
        model_load_dir = config.model_load_dir
        num_test_steps = int(math.ceil(test_data_size / batch_size))
        # Numpy array of shape (num_test_examples, 2)
        raw_predictions = model.predict(
            get_test_instance_generator=test_data_gen,
            model_load_dir=model_load_dir,
            batch_size=batch_size,
            num_test_steps=num_test_steps)
        # Remove the first column, so we're left with just the probabilities
        # that a question is a duplicate.
        is_duplicate_probabilities = np.delete(raw_predictions, 0, 1)

        # The class balance between kaggle train and test seems different.
        # This edits prediction probability to account for the discrepancy.
        # See: https://www.kaggle.com/c/quora-question-pairs/discussion/31179
        if config.reweight_predictions_for_kaggle:
            positive_weight = 0.165 / 0.37
            negative_weight = (1 - 0.165) / (1 - 0.37)
            is_duplicate_probabilities = (
                (positive_weight * is_duplicate_probabilities) /
                (positive_weight * is_duplicate_probabilities +
                 negative_weight * (1 - is_duplicate_probabilities)))

        # Write the predictions to an output submission file
        output_predictions_path = os.path.join(
            log_path,
            model_name + "-" + run_id.zfill(2) + "-output_predictions.csv")
        logger.info(
            "Writing predictions to {}".format(output_predictions_path))
        is_duplicate_df = pd.DataFrame(is_duplicate_probabilities)
        is_duplicate_df.to_csv(output_predictions_path,
                               index_label="test_id",
                               header=["is_duplicate"])