Example #1
def test_load_rte_model_2():
    vocab, embeddings = load_whole_glove(
        "../../resources/embeddings/glove/glove.6B.300d.txt")
    estimator = ESIMrte(name='esim_verify',
                        activation='relu',
                        batch_size=64,
                        lstm_layers=1,
                        n_outputs=3,
                        num_neurons=[250, 180, 900, 550, 180],
                        show_progress=1,
                        embedding=embeddings,
                        vocab_size=len(vocab))
    estimator.restore_model("../models/rte/claim_verification_esim.ckpt")
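A hedged sketch of driving the restored estimator, following the predict pattern used in the main() examples further down; test_set and embeddings here are hypothetical and would come from one of the embed_data_set_* helpers and load_whole_glove.
    # Hypothetical continuation of the test above (not in the original):
    # pack inputs the way the main() examples below do, then predict without
    # restoring parameters again (they were restored by restore_model).
    x_dict = {'X_test': test_set['data'], 'embedding': embeddings}
    predictions = estimator.predict(x_dict, False)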
def embed_data_set_with_glove_2(data_set_path: str,
                                db: Union[str, FeverDocDB],
                                glove_path: str = None,
                                vocab_dict: Dict[str, int] = None,
                                glove_embeddings=None,
                                predicted: bool = True,
                                threshold_b_sent_num=None,
                                threshold_b_sent_size=50,
                                threshold_h_sent_size=50):
    if vocab_dict is None or glove_embeddings is None:
        vocab, glove_embeddings = load_whole_glove(glove_path)
        vocab_dict = vocab_map(vocab)
    logger = LogHelper.get_logger("embed_data_set_given_vocab")
    datas, labels = read_data_set_from_jsonl(data_set_path, db, predicted)
    heads_ids = single_sentence_set_2_ids_given_vocab(datas['h'], vocab_dict)
    logger.debug("Finished sentence to IDs for claims")
    bodies_ids = multi_sentence_set_2_ids_given_vocab(datas['b'], vocab_dict)
    logger.debug("Finished sentence to IDs for evidences")
    h_np, h_sent_sizes = ids_padding_for_single_sentence_set_given_size(
        heads_ids, threshold_h_sent_size)
    logger.debug("Finished padding claims")
    b_np, b_sizes, b_sent_sizes = ids_padding_for_multi_sentences_set(
        bodies_ids, threshold_b_sent_num, threshold_b_sent_size)
    logger.debug("Finished padding evidences")
    processed_data_set = {
        'data': {
            'h_np': h_np,
            'b_np': b_np,
            'h_sent_sizes': h_sent_sizes,
            'b_sent_sizes': b_sent_sizes,
            'b_sizes': b_sizes
        },
        'id': datas['id']
    }
    if 'paths' in datas:
        padded_paths_np = pad_paths(datas['paths'], threshold_b_sent_num)
        processed_data_set['data']['paths'] = padded_paths_np
    if labels is not None and len(labels) == len(processed_data_set['id']):
        processed_data_set['label'] = labels
    return processed_data_set, vocab_dict, glove_embeddings, threshold_b_sent_num, threshold_b_sent_size
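A short usage sketch (not part of the original listing), mirroring how this helper is called in the test branches below, with a vocabulary and embedding matrix loaded up front:

vocab, embeddings = load_whole_glove(Config.glove_path)
vocab = vocab_map(vocab)
test_set, _, _, _, _ = embed_data_set_with_glove_2(
    Config.test_set_file,
    Config.db_path,
    vocab_dict=vocab,
    glove_embeddings=embeddings,
    threshold_b_sent_num=Config.max_sentences,
    threshold_b_sent_size=Config.max_sentence_size,
    threshold_h_sent_size=Config.max_sentence_size)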
def embed_data_set_with_glove_and_fasttext(data_set_path: str,
                                           db: Union[str, FeverDocDB],
                                           fasttext_model: Union[str,
                                                                 FastText],
                                           glove_path: str = None,
                                           vocab_dict: Dict[str, int] = None,
                                           glove_embeddings=None,
                                           predicted: bool = True,
                                           threshold_b_sent_num=None,
                                           threshold_b_sent_size=50,
                                           threshold_h_sent_size=50,
                                           is_snopes=False):
    assert (vocab_dict is not None and glove_embeddings is not None) \
        or glove_path is not None, \
        "Either both vocab_dict and glove_embeddings, or glove_path, must not be None"
    if vocab_dict is None or glove_embeddings is None:
        vocab, glove_embeddings = load_whole_glove(glove_path)
        vocab_dict = vocab_map(vocab)
    logger = LogHelper.get_logger("embed_data_set_given_vocab")
    datas, labels = read_data_set_from_jsonl(data_set_path,
                                             db,
                                             predicted,
                                             is_snopes=is_snopes)
    heads_ft_embeddings, fasttext_model = single_sentence_set_2_fasttext_embedded(
        datas['h'], fasttext_model)
    logger.debug("Finished sentence to FastText embeddings for claims")
    heads_ids = single_sentence_set_2_ids_given_vocab(datas['h'], vocab_dict)
    logger.debug("Finished sentence to IDs for claims")
    bodies_ft_embeddings, fasttext_model = multi_sentence_set_2_fasttext_embedded(
        datas['b'], fasttext_model)
    logger.debug("Finished sentence to FastText embeddings for evidences")
    bodies_ids = multi_sentence_set_2_ids_given_vocab(datas['b'], vocab_dict)
    logger.debug("Finished sentence to IDs for evidences")
    h_ft_np = fasttext_padding_for_single_sentence_set_given_size(
        heads_ft_embeddings, threshold_h_sent_size)
    logger.debug(
        "Finished padding FastText embeddings for claims. Shape of h_ft_np: {}"
        .format(str(h_ft_np.shape)))
    b_ft_np = fasttext_padding_for_multi_sentences_set(bodies_ft_embeddings,
                                                       threshold_b_sent_num,
                                                       threshold_b_sent_size)
    logger.debug(
        "Finished padding FastText embeddings for evidences. Shape of b_ft_np: {}"
        .format(str(b_ft_np.shape)))
    h_np, h_sent_sizes = ids_padding_for_single_sentence_set_given_size(
        heads_ids, threshold_h_sent_size)
    logger.debug("Finished padding claims")
    b_np, b_sizes, b_sent_sizes = ids_padding_for_multi_sentences_set(
        bodies_ids, threshold_b_sent_num, threshold_b_sent_size)
    logger.debug("Finished padding evidences")
    processed_data_set = {
        'data': {
            'h_np': h_np,
            'b_np': b_np,
            'h_ft_np': h_ft_np,
            'b_ft_np': b_ft_np,
            'h_sent_sizes': h_sent_sizes,
            'b_sent_sizes': b_sent_sizes,
            'b_sizes': b_sizes
        },
        'id': datas['id']
    }
    if labels is not None and len(labels) == len(processed_data_set['id']):
        processed_data_set['label'] = labels
    return processed_data_set, fasttext_model, vocab_dict, glove_embeddings, threshold_b_sent_num, threshold_b_sent_size
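The assert above allows two call styles; both are sketched here, based on the train and dev calls in the main() examples below:

# 1) Let the helper load GloVe itself and return vocab/embeddings for reuse:
train_set, ft_model, vocab, embeddings, _, _ = embed_data_set_with_glove_and_fasttext(
    Config.training_set_file, Config.db_path(), fasttext_model,
    glove_path=Config.glove_path)
# 2) Reuse the already-loaded vocabulary and embedding matrix:
dev_set, _, _, _, _, _ = embed_data_set_with_glove_and_fasttext(
    Config.dev_set_file, Config.db_path(), ft_model,
    vocab_dict=vocab, glove_embeddings=embeddings)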
        valid_set['data']['h_np'] = np.expand_dims(valid_set['data']['h_np'],
                                                   1)
        # valid_set['data']['h_ft_np'] = np.expand_dims(valid_set['data']['h_ft_np'], 1)

        X_dict = {
            'X_train': training_set['data'],
            'X_valid': valid_set['data'],
            'y_valid': valid_set['label'],
            'embedding': embeddings
        }
        estimator = get_estimator(Config.estimator_name, Config.ckpt_folder)
        estimator.fit(X_dict, training_set['label'])
        save_model(estimator, Config.model_folder, Config.pickle_name)
    elif args.mode == 'test':
        # testing mode
        estimator = load_model(Config.model_folder, Config.pickle_name)
        vocab, embeddings = load_whole_glove(Config.glove_path)
        vocab = vocab_map(vocab)
        # test_set, _, _, _, _, _ = embed_data_set_with_glove_and_fasttext(Config.test_set_file, Config.db_path,
        #                                                                 fasttext_model, vocab_dict=vocab,
        #                                                                 glove_embeddings=embeddings,
        #                                                                 threshold_b_sent_num=Config.max_sentences,
        #                                                                 threshold_b_sent_size=Config.max_sentence_size,
        #                                                                 threshold_h_sent_size=Config.max_sentence_size,
        #                                                                 is_snopes=is_snopes)
        test_set, _, _, _, _ = embed_data_set_with_glove_2(
            Config.test_set_file,
            Config.db_path,
            vocab_dict=vocab,
            glove_embeddings=embeddings,
            threshold_b_sent_num=Config.max_sentences,
            threshold_b_sent_size=Config.max_sentence_size,
            threshold_h_sent_size=Config.max_sentence_size)
def main(mode, config, estimator=None):
    LogHelper.setup()
    logger = LogHelper.get_logger(os.path.splitext(os.path.basename(__file__))[0] + "_" + mode)
    logger.info("model: " + mode + ", config: " + str(config))
    logger.info("scorer type: " + Config.estimator_name)
    logger.info("random seed: " + str(Config.seed))
    logger.info("ESIM arguments: " + str(Config.esim_hyper_param))
    # loading FastText takes a long time, so better pickle the loaded FastText model
    if os.path.splitext(Config.fasttext_path)[1] == '.p':
        with open(Config.fasttext_path, "rb") as ft_file:
            fasttext_model = pickle.load(ft_file)
    else:
        fasttext_model = Config.fasttext_path
    if mode == 'train':
        # training mode
        training_set, fasttext_model, vocab, embeddings, _, _ = embed_data_set_with_glove_and_fasttext(
            Config.training_set_file, Config.db_path(), fasttext_model, glove_path=Config.glove_path,
            threshold_b_sent_num=Config.max_sentences, threshold_b_sent_size=Config.max_sentence_size,
            threshold_h_sent_size=Config.max_claim_size)
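        # The estimator consumes claims in the same layout as evidence sets:
        # give each claim a singleton sentence axis (h_np, h_ft_np,
        # h_sent_sizes) and a per-claim sentence count of 1 (h_sizes).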
        h_sent_sizes = training_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        training_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        training_set['data']['h_sizes'] = h_sizes
        training_set['data']['h_np'] = np.expand_dims(training_set['data']['h_np'], 1)
        training_set['data']['h_ft_np'] = np.expand_dims(training_set['data']['h_ft_np'], 1)

        valid_set, _, _, _, _, _ = embed_data_set_with_glove_and_fasttext(Config.dev_set_file, Config.db_path(),
                                                                          fasttext_model, vocab_dict=vocab,
                                                                          glove_embeddings=embeddings,
                                                                          threshold_b_sent_num=Config.max_sentences,
                                                                          threshold_b_sent_size=Config.max_sentence_size,
                                                                          threshold_h_sent_size=Config.max_claim_size)
        del fasttext_model
        h_sent_sizes = valid_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        valid_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        valid_set['data']['h_sizes'] = h_sizes
        valid_set['data']['h_np'] = np.expand_dims(valid_set['data']['h_np'], 1)
        valid_set['data']['h_ft_np'] = np.expand_dims(valid_set['data']['h_ft_np'], 1)

        X_dict = {
            'X_train': training_set['data'],
            'X_valid': valid_set['data'],
            'y_valid': valid_set['label'],
            'embedding': embeddings
        }
        if estimator is None:
            estimator = get_estimator(Config.estimator_name, Config.ckpt_folder)
        estimator.fit(X_dict, training_set['label'])
        save_model(estimator, Config.model_folder, Config.pickle_name, logger)
    elif mode == 'test':
        # testing mode
        restore_param_required = estimator is None
        if estimator is None:
            estimator = load_model(Config.model_folder, Config.pickle_name)
            if estimator is None:
                estimator = get_estimator(Config.estimator_name, Config.ckpt_folder)
        vocab, embeddings = load_whole_glove(Config.glove_path)
        vocab = vocab_map(vocab)
        test_set, _, _, _, _, _ = embed_data_set_with_glove_and_fasttext(Config.test_set_file, Config.db_path(),
                                                                         fasttext_model, vocab_dict=vocab,
                                                                         glove_embeddings=embeddings,
                                                                         threshold_b_sent_num=Config.max_sentences,
                                                                         threshold_b_sent_size=Config.max_sentence_size,
                                                                         threshold_h_sent_size=Config.max_claim_size)
        del fasttext_model
        h_sent_sizes = test_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        test_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        test_set['data']['h_sizes'] = h_sizes
        test_set['data']['h_np'] = np.expand_dims(test_set['data']['h_np'], 1)
        test_set['data']['h_ft_np'] = np.expand_dims(test_set['data']['h_ft_np'], 1)
        x_dict = {
            'X_test': test_set['data'],
            'embedding': embeddings
        }
        predictions = estimator.predict(x_dict, restore_param_required)
        generate_submission(predictions, test_set['id'], Config.test_set_file, Config.submission_file())
        if 'label' in test_set:
            print_metrics(test_set['label'], predictions, logger)
    else:
        logger.error("Invalid argument --mode: " + mode + " Argument --mode should be either 'train’ or ’test’")
Example #6
def fever_app(caller):
    #parser = ArgumentParser()
    #parser.add_argument("--db-path", default="/local/fever-common/data/fever/fever.db")
    #parser.add_argument("--random-seed", default=1234)
    #parser.add_argument("--sentence-model", default="model/esim_0/sentence_retrieval_ensemble")
    #parser.add_argument("--words-cache", default="model/sentence")
    #parser.add_argument("--c-max-length", default=20)
    #parser.add_argument("--s-max-length", default=60)
    #parser.add_argument("--fasttext-path", default="data/fasttext/wiki.en.bin")
    #parser.add_argument("--train-data", default="data/fever/train.wiki7.jsonl")
    #parser.add_argument("--dev-data", default="data/fever/dev.wiki7.jsonl")
    #parser.add_argument("--test-data", default="data/fever/test.wiki7.jsonl")
    #parser.add_argument("--add-claim", default=True)

    args = Struct(
        **{
            "db_path": "/local/fever-common/data/fever/fever.db",
            "random_seed": 1234,
            "sentence_model": "model/esim_0/sentence_retrieval_ensemble",
            "words_cache": "model/sentence",
            "c_max_length": 20,
            "s_max_length": 60,
            "fasttext_path": "data/fasttext/wiki.en.bin",
            "train_data": "data/fever/train.wiki7.jsonl",
            "dev_data": "data/fever/dev.wiki7.jsonl",
            "test_data": "data/fever/test.wiki7.jsonl",
            "add_claim": True
        })

    # Setup logging
    LogHelper.setup()
    logger = LogHelper.get_logger("setup")
    logger.info("Logging started")

    # Set seeds
    logger.info("Set Seeds")
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Load GloVe
    logger.info("Load GloVe")
    vocab, embeddings = load_whole_glove(Config.glove_path)
    vocab = vocab_map(vocab)

    # Document Retrieval
    logger.info("Setup document retrieval")
    retrieval = Doc_Retrieval(database_path=args.db_path,
                              add_claim=args.add_claim,
                              k_wiki_results=k_wiki)

    # Sentence Selection
    logger.info("Setup sentence loader")
    #words, iwords = get_iwords(args, retrieval)

    sentence_loader = SentenceDataLoader(fasttext_path=args.fasttext_path,
                                         db_filepath=args.db_path,
                                         h_max_length=args.c_max_length,
                                         s_max_length=args.s_max_length,
                                         reserve_embed=True)
    sentence_loader.load_models(vocab,
                                sentence_loader.inverse_word_dict(vocab))

    sargs = Config.sentence_retrieval_ensemble_param
    sargs.update(vars(args))
    sargs = Struct(**sargs)

    logger.info("Sentence ESIM ensemble")
    selections = [
        SentenceESIM(h_max_length=sargs.c_max_length,
                     s_max_length=sargs.s_max_length,
                     learning_rate=sargs.learning_rate,
                     batch_size=sargs.batch_size,
                     num_epoch=sargs.num_epoch,
                     model_store_dir=sargs.sentence_model,
                     embedding=sentence_loader.embed,
                     word_dict=sentence_loader.word_dict,
                     dropout_rate=sargs.dropout_rate,
                     num_units=sargs.num_lstm_units,
                     share_rnn=False,
                     activation=tf.nn.tanh,
                     namespace="model_{}".format(i))
        for i in range(sargs.num_model)
    ]

    for i in range(sargs.num_model):
        logger.info("Restore Model {}".format(i))
        model_store_path = os.path.join(args.sentence_model,
                                        "model{}".format(i + 1))
        if not os.path.exists(model_store_path):
            raise Exception("model must be trained before testing")
        selections[i].restore_model(
            os.path.join(model_store_path, "best_model.ckpt"))

    logger.info("Load FastText")
    fasttext_model = FastText.load_fasttext_format(Config.fasttext_path)

    # RTE
    logger.info("Setup RTE")
    rte_predictor = get_estimator(Config.estimator_name, Config.ckpt_folder)
    rte_predictor.embedding = embeddings

    logger.info("Restore RTE Model")
    rte_predictor.restore_model(rte_predictor.ckpt_path)

    def get_docs_line(line):
        nps, wiki_results, pages = retrieval.exact_match(line)
        line['noun_phrases'] = nps
        line['predicted_pages'] = pages
        line['wiki_results'] = wiki_results
        return line

    def get_docs(lines):
        return list(map(get_docs_line, lines))

    def get_sents(lines):
        indexes, location_indexes = sentence_loader.get_indexes(lines)
        all_predictions = []

        for i in range(sargs.num_model):
            predictions = []

            selection_model = selections[i]

            for test_index in indexes:
                prediction = selection_model.predict(test_index)
                predictions.append(prediction)

            all_predictions.append(predictions)

        ensembled_predictions = scores_processing(all_predictions, args)
        processed_predictions, scores = post_processing(
            ensembled_predictions, location_indexes)
        final_predictions = prediction_processing_no_reload(
            lines, processed_predictions)

        return final_predictions

    def run_rte(lines):
        test_set, _, _, _, _, _ = embed_claims(
            lines,
            args.db_path,
            fasttext_model,
            vocab_dict=vocab,
            glove_embeddings=embeddings,
            threshold_b_sent_num=Config.max_sentences,
            threshold_b_sent_size=Config.max_sentence_size,
            threshold_h_sent_size=Config.max_claim_size)

        h_sent_sizes = test_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        test_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        test_set['data']['h_sizes'] = h_sizes
        test_set['data']['h_np'] = np.expand_dims(test_set['data']['h_np'], 1)
        test_set['data']['h_ft_np'] = np.expand_dims(
            test_set['data']['h_ft_np'], 1)

        x_dict = {'X_test': test_set['data'], 'embedding': embeddings}

        predictions = rte_predictor.predict(x_dict, False)
        return predictions

    def process_claims(claims):
        print("CLAIMS LEN {}".format(len(claims)))
        claims = get_docs(claims)

        print("CLAIMS LEN {}".format(len(claims)))
        claims = get_sents(claims)

        print("CLAIMS LEN {}".format(len(claims)))
        predictions = run_rte(claims)

        print("PREDICTIONS LEN {}".format(len(predictions)))

        ret = []
        for idx in range(len(claims)):
            claim = claims[idx]
            prediction = predictions[idx]

            return_line = {
                "predicted_label": prediction_2_label(prediction),
                "predicted_evidence": claim["predicted_evidence"]
            }
            ret.append(return_line)
        return ret

    return caller(process_claims)
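The caller contract, sketched under assumptions: fever_app hands process_claims to the supplied caller, which feeds it a list of claim dicts and gets back one dict per claim with predicted_label and predicted_evidence. A hypothetical caller:

def debug_caller(predict_fn):
    # Assumed input shape; the real harness supplies claims from the shared task.
    claims = [{'id': 1, 'claim': 'Some claim text.'}]
    return predict_fn(claims)

# fever_app(debug_caller) runs document retrieval, sentence selection and RTE
# end to end over the supplied claims.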
def main(mode: RTERunPhase, config=None, estimator=None):
    LogHelper.setup()
    logger = LogHelper.get_logger(
        os.path.splitext(os.path.basename(__file__))[0] + "_" + str(mode))
    if config is not None and isinstance(config, str):
        logger.info("model: " + str(mode) + ", config: " + str(config))
        Config.load_config(config)
    logger.info("scorer type: " + Config.estimator_name)
    logger.info("random seed: " + str(Config.seed))
    logger.info("ESIM arguments: " + str(Config.esim_hyper_param))
    # loading FastText takes a long time, so better pickle the loaded FastText model
    if os.path.splitext(Config.fasttext_path)[1] == '.p':
        with open(Config.fasttext_path, "rb") as ft_file:
            fasttext_model = pickle.load(ft_file)
    else:
        fasttext_model = Config.fasttext_path
    if mode == RTERunPhase.train:
        # training mode
        training_set, fasttext_model, vocab, embeddings = embed_data_set_with_glove_and_fasttext_claim_only(
            Config.training_set_file,
            fasttext_model,
            glove_path=Config.glove_path,
            threshold_h_sent_size=Config.max_sentence_size)
        h_sent_sizes = training_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        training_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        training_set['data']['h_sizes'] = h_sizes
        training_set['data']['h_np'] = np.expand_dims(
            training_set['data']['h_np'], 1)
        training_set['data']['h_ft_np'] = np.expand_dims(
            training_set['data']['h_ft_np'], 1)

        valid_set, _, _, _ = embed_data_set_with_glove_and_fasttext_claim_only(
            Config.dev_set_file,
            fasttext_model,
            vocab_dict=vocab,
            glove_embeddings=embeddings,
            threshold_h_sent_size=Config.max_sentence_size)
        del fasttext_model
        h_sent_sizes = valid_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        valid_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        valid_set['data']['h_sizes'] = h_sizes
        valid_set['data']['h_np'] = np.expand_dims(valid_set['data']['h_np'],
                                                   1)
        valid_set['data']['h_ft_np'] = np.expand_dims(
            valid_set['data']['h_ft_np'], 1)

        X_dict = {
            'X_train': training_set['data'],
            'X_valid': valid_set['data'],
            'y_valid': valid_set['label'],
            'embedding': embeddings
        }
        if estimator is None:
            estimator = get_estimator(Config.estimator_name,
                                      Config.ckpt_folder)
        if 'CUDA_VISIBLE_DEVICES' not in os.environ or not str(
                os.environ['CUDA_VISIBLE_DEVICES']).strip():
            os.environ['CUDA_VISIBLE_DEVICES'] = str(
                GPUtil.getFirstAvailable(maxLoad=1.0,
                                         maxMemory=1.0 -
                                         Config.max_gpu_memory)[0])
        estimator.fit(X_dict, training_set['label'])
        save_model(estimator, Config.model_folder, Config.pickle_name, logger)
    elif mode == RTERunPhase.test:
        # testing mode
        restore_param_required = estimator is None
        if estimator is None:
            estimator = load_model(Config.model_folder, Config.pickle_name)
            if estimator is None:
                estimator = get_estimator(Config.estimator_name,
                                          Config.ckpt_folder)
        vocab, embeddings = load_whole_glove(Config.glove_path)
        vocab = vocab_map(vocab)
        test_set, _, _, _ = embed_data_set_with_glove_and_fasttext_claim_only(
            Config.test_set_file,
            fasttext_model,
            vocab_dict=vocab,
            glove_embeddings=embeddings,
            threshold_h_sent_size=Config.max_sentence_size)
        del fasttext_model
        h_sent_sizes = test_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        test_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        test_set['data']['h_sizes'] = h_sizes
        test_set['data']['h_np'] = np.expand_dims(test_set['data']['h_np'], 1)
        test_set['data']['h_ft_np'] = np.expand_dims(
            test_set['data']['h_ft_np'], 1)
        x_dict = {'X_test': test_set['data'], 'embedding': embeddings}
        if 'CUDA_VISIBLE_DEVICES' not in os.environ or not str(
                os.environ['CUDA_VISIBLE_DEVICES']).strip():
            os.environ['CUDA_VISIBLE_DEVICES'] = str(
                GPUtil.getFirstAvailable(maxLoad=1.0,
                                         maxMemory=1.0 -
                                         Config.max_gpu_memory)[0])
        predictions = estimator.predict(x_dict, restore_param_required)
        generate_submission(predictions, test_set['id'], Config.test_set_file,
                            Config.submission_file)
        if 'label' in test_set:
            print_metrics(test_set['label'], predictions, logger)
    return estimator
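A hedged invocation sketch; RTERunPhase is assumed to be an Enum with train and test members, as the comparisons above suggest, and the config path is illustrative:

estimator = main(RTERunPhase.train, config='conf/claim_only_esim.json')
main(RTERunPhase.test, estimator=estimator)  # reuse the fitted estimator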
Example #8
def main(mode, config, estimator=None):
    LogHelper.setup()
    logger = LogHelper.get_logger(
        os.path.splitext(os.path.basename(__file__))[0] + "_" + mode)
    logger.info("model: " + mode + ", config: " + str(config))
    use_inter_evidence_comparison = getattr(Config,
                                            'use_inter_evidence_comparison',
                                            False)
    use_claim_evidences_comparison = getattr(Config,
                                             'use_claim_evidences_comparison',
                                             False)
    use_extra_features = getattr(Config, 'use_extra_features', False)
    use_numeric_feature = getattr(Config, 'use_numeric_feature', False)
    logger.info("scorer type: " + Config.estimator_name)
    logger.info("random seed: " + str(Config.seed))
    logger.info("ESIM arguments: " + str(Config.esim_end_2_end_hyper_param))
    logger.info("use_inter_sentence_comparison: " +
                str(use_inter_evidence_comparison))
    logger.info("use_extra_features: " + str(use_extra_features))
    logger.info("use_numeric_feature: " + str(use_numeric_feature))
    logger.info("use_claim_evidences_comparison: " +
                str(use_claim_evidences_comparison))
    if mode == 'train':
        # training mode
        if hasattr(Config, 'training_dump') and os.path.exists(
                Config.training_dump):
            with open(Config.training_dump, 'rb') as f:
                (X_dict, y_train) = pickle.load(f)
        else:
            training_set, vocab, embeddings, _, _ = embed_data_set_with_glove_2(
                Config.training_set_file,
                Config.db_path,
                glove_path=Config.glove_path,
                threshold_b_sent_num=Config.max_sentences,
                threshold_b_sent_size=Config.max_sentence_size,
                threshold_h_sent_size=Config.max_sentence_size)
            h_sent_sizes = training_set['data']['h_sent_sizes']
            h_sizes = np.ones(len(h_sent_sizes), np.int32)
            training_set['data']['h_sent_sizes'] = np.expand_dims(
                h_sent_sizes, 1)
            training_set['data']['h_sizes'] = h_sizes
            training_set['data']['h_np'] = np.expand_dims(
                training_set['data']['h_np'], 1)

            valid_set, _, _, _, _ = embed_data_set_with_glove_2(
                Config.dev_set_file,
                Config.db_path,
                vocab_dict=vocab,
                glove_embeddings=embeddings,
                threshold_b_sent_num=Config.max_sentences,
                threshold_b_sent_size=Config.max_sentence_size,
                threshold_h_sent_size=Config.max_sentence_size)
            h_sent_sizes = valid_set['data']['h_sent_sizes']
            h_sizes = np.ones(len(h_sent_sizes), np.int32)
            valid_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
            valid_set['data']['h_sizes'] = h_sizes
            valid_set['data']['h_np'] = np.expand_dims(
                valid_set['data']['h_np'], 1)
            if use_extra_features:
                assert hasattr(
                    Config, 'feature_path'
                ), "Config should have feature_path if use_extra_features is True"
                training_claim_features, training_evidence_features = load_feature_by_data_set(
                    Config.training_set_file, Config.feature_path,
                    Config.max_sentences)
                valid_claim_features, valid_evidence_features = load_feature_by_data_set(
                    Config.dev_set_file, Config.feature_path,
                    Config.max_sentences)
                training_set['data']['h_feats'] = training_claim_features
                training_set['data']['b_feats'] = training_evidence_features
                valid_set['data']['h_feats'] = valid_claim_features
                valid_set['data']['b_feats'] = valid_evidence_features
            if use_numeric_feature:
                training_num_feat = number_feature(Config.training_set_file,
                                                   Config.db_path,
                                                   Config.max_sentences)
                valid_num_feat = number_feature(Config.dev_set_file,
                                                Config.db_path,
                                                Config.max_sentences)
                training_set['data']['num_feat'] = training_num_feat
                valid_set['data']['num_feat'] = valid_num_feat
            if use_inter_evidence_comparison:
                training_concat_sent_indices, training_concat_sent_sizes = generate_concat_indices_for_inter_evidence(
                    training_set['data']['b_np'],
                    training_set['data']['b_sent_sizes'],
                    Config.max_sentence_size, Config.max_sentences)
                training_set['data'][
                    'b_concat_indices'] = training_concat_sent_indices
                training_set['data'][
                    'b_concat_sizes'] = training_concat_sent_sizes
                valid_concat_sent_indices, valid_concat_sent_sizes = generate_concat_indices_for_inter_evidence(
                    valid_set['data']['b_np'],
                    valid_set['data']['b_sent_sizes'],
                    Config.max_sentence_size, Config.max_sentences)
                valid_set['data'][
                    'b_concat_indices'] = valid_concat_sent_indices
                valid_set['data']['b_concat_sizes'] = valid_concat_sent_sizes
            if use_claim_evidences_comparison:
                training_all_evidences_indices, training_all_evidences_sizes = generate_concat_indices_for_claim(
                    training_set['data']['b_np'],
                    training_set['data']['b_sent_sizes'],
                    Config.max_sentence_size, Config.max_sentences)
                training_set['data'][
                    'b_concat_indices_for_h'] = training_all_evidences_indices
                training_set['data'][
                    'b_concat_sizes_for_h'] = training_all_evidences_sizes
                valid_all_evidences_indices, valid_all_evidences_sizes = generate_concat_indices_for_claim(
                    valid_set['data']['b_np'],
                    valid_set['data']['b_sent_sizes'],
                    Config.max_sentence_size, Config.max_sentences)
                valid_set['data'][
                    'b_concat_indices_for_h'] = valid_all_evidences_indices
                valid_set['data'][
                    'b_concat_sizes_for_h'] = valid_all_evidences_sizes
            X_dict = {
                'X_train': training_set['data'],
                'X_valid': valid_set['data'],
                'y_valid': valid_set['label'],
                'embedding': embeddings
            }
            y_train = training_set['label']
            if hasattr(Config, 'training_dump'):
                with open(Config.training_dump, 'wb') as f:
                    pickle.dump((X_dict, y_train),
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        if estimator is None:
            estimator = get_estimator(Config.estimator_name,
                                      Config.ckpt_folder)
        estimator.fit(X_dict, y_train)
        save_model(estimator, Config.model_folder, Config.pickle_name, logger)
    elif mode == 'test':
        # testing mode
        restore_param_required = estimator is None
        if estimator is None:
            estimator = load_model(Config.model_folder, Config.pickle_name)
        vocab, embeddings = load_whole_glove(Config.glove_path)
        vocab = vocab_map(vocab)
        test_set, _, _, _, _ = embed_data_set_with_glove_2(
            Config.test_set_file,
            Config.db_path,
            vocab_dict=vocab,
            glove_embeddings=embeddings,
            threshold_b_sent_num=Config.max_sentences,
            threshold_b_sent_size=Config.max_sentence_size,
            threshold_h_sent_size=Config.max_sentence_size)
        h_sent_sizes = test_set['data']['h_sent_sizes']
        h_sizes = np.ones(len(h_sent_sizes), np.int32)
        test_set['data']['h_sent_sizes'] = np.expand_dims(h_sent_sizes, 1)
        test_set['data']['h_sizes'] = h_sizes
        test_set['data']['h_np'] = np.expand_dims(test_set['data']['h_np'], 1)
        if use_extra_features:
            assert hasattr(
                Config, 'feature_path'
            ), "Config should have feature_path if use_extra_features is True"
            test_claim_features, test_evidence_features = load_feature_by_data_set(
                Config.test_set_file, Config.feature_path,
                Config.max_sentences)
            test_set['data']['h_feats'] = test_claim_features
            test_set['data']['b_feats'] = test_evidence_features
        if use_numeric_feature:
            test_num_feat = number_feature(Config.test_set_file,
                                           Config.db_path,
                                           Config.max_sentences)
            test_set['data']['num_feat'] = test_num_feat
        x_dict = {'X_test': test_set['data'], 'embedding': embeddings}
        if use_inter_evidence_comparison:
            test_concat_sent_indices, test_concat_sent_sizes = generate_concat_indices_for_inter_evidence(
                test_set['data']['b_np'], test_set['data']['b_sent_sizes'],
                Config.max_sentence_size, Config.max_sentences)
            test_set['data']['b_concat_indices'] = test_concat_sent_indices
            test_set['data']['b_concat_sizes'] = test_concat_sent_sizes
        if use_claim_evidences_comparison:
            test_all_evidences_indices, test_all_evidences_sizes = generate_concat_indices_for_claim(
                test_set['data']['b_np'], test_set['data']['b_sent_sizes'],
                Config.max_sentence_size, Config.max_sentences)
            test_set['data'][
                'b_concat_indices_for_h'] = test_all_evidences_indices
            test_set['data']['b_concat_sizes_for_h'] = test_all_evidences_sizes
        predictions = estimator.predict(x_dict, restore_param_required)
        generate_submission(predictions, test_set['id'], Config.test_set_file,
                            Config.submission_file)
        if 'label' in test_set:
            print_metrics(test_set['label'], predictions, logger)
    else:
        logger.error("Invalid argument --mode: " + mode +
                     " Argument --mode should be either 'train’ or ’test’")
    return estimator
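A hedged illustration of the optional Config switches read at the top of this main(); attribute names come from the code above, the values are arbitrary, and any switch left unset defaults to False:

class Config:  # sketch only; the real Config also carries paths and hyper-parameters
    use_inter_evidence_comparison = True    # adds b_concat_indices / b_concat_sizes
    use_claim_evidences_comparison = False  # adds b_concat_indices_for_h / b_concat_sizes_for_h
    use_extra_features = False              # requires Config.feature_path
    use_numeric_feature = True              # adds num_feat via number_feature()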