def main(task2_model_id, task2_model_path):
    """Run task-2 (multi-label) inference on the test set and write the results.

    Loads the vocabulary and test documents, encodes them according to the
    selected model, restores the model via ``load_multi_model``, predicts
    multi-labels with ``predict_multi_label`` and writes them through
    ``generate_result_json`` to ``multi_config.result_path``.

    Args:
        task2_model_id: Model identifier forwarded to ``load_multi_model``.
            Id 4 switches the input encoding to ``bd.build_test_data_HAN``
            (sentence-structured input, presumably for the HAN model); every
            other id uses ``bd.build_test_data`` with
            ``multi_config.max_text_len``.
        task2_model_path: Path forwarded to ``load_multi_model`` to restore
            the trained model.
    """
    multi_config = MultiConfig()
    # Evaluation settings: not in training mode and no dropout.
    multi_config.is_training = False
    multi_config.dropout_rate = 0.0

    print("loading data...")
    dict_word2index = bpe.load_pickle(multi_config.word2index_path)
    tests_id, test_data = bd.load_test_data(multi_config.test_path)
    if task2_model_id != 4:
        test_X = bd.build_test_data(test_data, dict_word2index,
                                    multi_config.max_text_len)
    else:
        # Model id 4 needs input grouped into sentences
        # (num_sentences / sequence_length), unlike the flat encoding above.
        test_X = bd.build_test_data_HAN(test_data, dict_word2index,
                                        multi_config.num_sentences,
                                        multi_config.sequence_length)

    testset = MingLueTestData(test_X)
    # shuffle=False keeps predictions in the same order as tests_id, which is
    # passed alongside them to generate_result_json below.
    test_loader = DataLoader(dataset=testset,
                             batch_size=multi_config.batch_size,
                             shuffle=False,
                             num_workers=multi_config.num_workers)

    # The model is built from multi_config, so it must see the real vocab size.
    multi_config.vocab_size = len(dict_word2index)
    print("loading model...")
    model2 = load_multi_model(task2_model_path, task2_model_id, multi_config)

    print("model loaded")

    print("predicting...")
    predicted_multi_labels = predict_multi_label(test_loader, model2, multi_config)
    generate_result_json(tests_id, predicted_multi_labels, multi_config.result_path)
def main(task1_model_id, task1_model_path):
    """Run task-1 (single-label) inference on the test set and write the results.

    Loads the vocabulary and test documents, encodes them according to the
    selected model, restores the model via ``load_model``, predicts labels
    with ``predict`` and writes them through ``generate_result_json`` to
    ``config.result_path``. No multi-label predictions are produced here, so
    an empty placeholder (``[[]]``) is passed in their slot.

    Args:
        task1_model_id: Model identifier forwarded to ``load_model``
            (a former interactive prompt listed 0: fastText, 1: TextCNN,
            2: TextRCNN). Id 4 switches the input encoding to
            ``bd.build_test_data_HAN`` (presumably the HAN model); every
            other id uses ``bd.build_test_data`` with ``config.max_text_len``.
        task1_model_path: Path forwarded to ``load_model`` to restore the
            trained model.
    """
    config = Config()
    # Evaluation settings: not in training mode and no dropout.
    config.is_training = False
    config.dropout_rate = 0.0

    print("loading data...")
    dict_word2index = bpe.load_pickle(config.word2index_path)
    # Loading the raw test set is the same for every model; only the
    # encoding below depends on the model id.
    tests_id, test_data = bd.load_test_data(config.test_path)
    if task1_model_id != 4:
        test_X = bd.build_test_data(test_data, dict_word2index,
                                    config.max_text_len)
    else:
        # Model id 4 needs input grouped into sentences
        # (num_sentences / sequence_length), unlike the flat encoding above.
        test_X = bd.build_test_data_HAN(test_data, dict_word2index,
                                        config.num_sentences,
                                        config.sequence_length)
    testset = MingLueTestData(test_X)
    # shuffle=False keeps predictions in the same order as tests_id, which is
    # passed alongside them to generate_result_json below.
    test_loader = DataLoader(dataset=testset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=config.num_workers)

    # The model is built from config, so it must see the real vocab size.
    config.vocab_size = len(dict_word2index)
    print("loading model...")
    model1 = load_model(task1_model_path, task1_model_id, config)

    print("model loaded")

    print("predicting...")
    predicted_labels = predict(test_loader, model1, config.has_cuda)
    # Task 2 is not run here; pass an empty placeholder for its labels.
    predicted_multi_labels = [[]]
    generate_result_json(tests_id, predicted_labels, predicted_multi_labels,
                         config.result_path)
예제 #3
0
def main(rcnn_model_path, han_model_path):
    """Predict test-set labels using an RCNN model and a HAN model together.

    Both models share a single ``Config``. The test documents are encoded
    twice: once with ``bd.build_test_data`` (flat, ``max_text_len``) for the
    RCNN model and once with ``bd.build_test_data_HAN`` (``num_sentences`` /
    ``sequence_length``) for the HAN model. Each encoding gets its own
    DataLoader. ``predict`` receives both loaders and both models; how it
    combines them is defined elsewhere. Results are written through
    ``generate_result_json`` to ``config.result_path``.

    Args:
        rcnn_model_path: Path restored with model id 2 (TextRCNN).
        han_model_path: Path restored with model id 4 (presumably HAN).
    """
    cfg = Config()
    # Evaluation settings: not in training mode and no dropout.
    cfg.is_training = False
    cfg.dropout_rate = 0.0

    def make_loader(encoded):
        # Wrap one encoding of the test set in a DataLoader; shuffle=False
        # keeps batch order equal to the order of the loaded test ids.
        return DataLoader(dataset=MingLueTestData(encoded),
                          batch_size=cfg.batch_size,
                          shuffle=False,
                          num_workers=cfg.num_workers)

    print("loading data...")
    word2index = bpe.load_pickle(cfg.word2index_path)
    ids, documents = bd.load_test_data(cfg.test_path)

    flat_loader = make_loader(
        bd.build_test_data(documents, word2index, cfg.max_text_len))
    sentence_loader = make_loader(
        bd.build_test_data_HAN(documents, word2index,
                               cfg.num_sentences, cfg.sequence_length))

    # Both models are built from cfg, so it must carry the real vocab size.
    cfg.vocab_size = len(word2index)
    print("loading model...")

    rcnn = load_model(rcnn_model_path, 2, cfg)
    han = load_model(han_model_path, 4, cfg)
    print("model loaded")

    print("predicting...")
    labels = predict(flat_loader, sentence_loader, rcnn, han, cfg)

    generate_result_json(ids, labels, cfg.result_path)