コード例 #1
0
ファイル: run_predict.py プロジェクト: sun510001/RQ_Chatbot
def cla_predict(input, model_type):
    """Classify each row of ``input`` with a k-fold model ensemble.

    For every fold ``i`` in ``1..PredictContext.FOLD_NUM`` a fresh model is
    built via ``BuildModels(PredictContext.M_VER, model_type).build_model_1()``.
    Its weights are loaded from
    ``'{checkpoint_path}{M_VER}t{i}.h5'`` and its scores for ``input`` are
    summed into one accumulator. The predicted label per row is the arg-max
    of the summed scores.

    Args:
        input: Rows to classify. Must expose ``.shape[0]`` and be accepted
            by ``encode_examples_bert``/``encode_examples_roberta``
            (presumably a pandas DataFrame — TODO confirm).
        model_type: Transformer name. Must appear in
            ``TrainModelConfig.BERT_LIST`` or ``TrainModelConfig.ROBERTA_LIST``.

    Returns:
        numpy array of shape ``(input.shape[0],)`` with the predicted class
        index of each row.

    Raises:
        ValueError: if ``model_type`` is in neither supported list. This is
            checked before any model is built or any weights are loaded.
    """
    if model_type in TrainModelConfig.BERT_LIST:
        tokenizer = BertTokenizer.from_pretrained(PredictContext.TRANS_PATH,
                                                  do_lower_case=True)
        encode = encode_examples_bert
    elif model_type in TrainModelConfig.ROBERTA_LIST:
        tokenizer = RobertaTokenizer.from_pretrained(PredictContext.TRANS_PATH,
                                                     do_lower_case=True)
        encode = encode_examples_roberta
    else:
        # Without this, ``x`` stays unbound and the failure only surfaces
        # after the first fold's model has been built and loaded.
        raise ValueError('unsupported model_type: {!r}'.format(model_type))

    x = encode(input, tokenizer, model_type, train=False).batch(1)

    # 3 = number of output classes; must match the model head built by
    # BuildModels (assumed literal / rq / sarcasm as in training — confirm).
    # Fold scores are summed rather than averaged: the arg-max is the same.
    test_result = np.zeros((input.shape[0], 3))

    for i in range(1, PredictContext.FOLD_NUM + 1):
        # Checkpoint suffix for this fold, e.g. 't1', 't2', ...
        test_count = 't' + str(i)
        print('evaluating {}/{} ...'.format(i, PredictContext.FOLD_NUM))
        model_ver = BuildModels(PredictContext.M_VER, model_type)
        model2 = model_ver.build_model_1()
        model2.load_weights('{}{}{}.h5'.format(PredictContext.checkpoint_path,
                                               PredictContext.M_VER,
                                               test_count))
        test_result += model2.predict(x,
                                      verbose=BertBaseUnCaseV1.SHOW_MODE_NUM)
        # Drop the fold model before building the next one to limit memory.
        del model2
        gc.collect()

    return test_result.argmax(axis=1)
コード例 #2
0
def train_process(fold_count):
    """Train and checkpoint the model for one cross-validation fold.

    Relies on module-level names set by the caller: ``model_type``,
    ``tokenizer``, ``train_data``, ``train_index``, ``val_index`` and
    ``checkpoint_path``. The rows at ``train_index``/``val_index`` are
    encoded, batched with ``BertBaseUnCaseV2.BATCH_SIZE``, and passed to
    ``train_model`` together with the fold tag ``'t{fold_count}'``. The Keras
    session is cleared afterwards.

    Args:
        fold_count: 1-based fold number. Used in the progress message and
            in the ``'t{fold_count}'`` tag handed to ``train_model``.

    Raises:
        ValueError: if ``model_type`` is in neither
            ``TrainModelConfigV2.BERT_LIST`` nor
            ``TrainModelConfigV2.ROBERTA_LIST``.
    """
    test_count = 't' + str(fold_count)

    # Use the same membership lists the caller uses to pick the tokenizer,
    # so every model type that gets a tokenizer can also be trained here.
    if model_type in TrainModelConfigV2.BERT_LIST:
        encode = encode_examples_bert
    elif model_type in TrainModelConfigV2.ROBERTA_LIST:
        encode = encode_examples_roberta
    else:
        raise ValueError('unsupported model_type: {!r}'.format(model_type))

    train, _ = encode(train_data.loc[train_index], tokenizer, model_type)
    val, _ = encode(train_data.loc[val_index], tokenizer, model_type)

    ds_train_encoded = train.batch(BertBaseUnCaseV2.BATCH_SIZE)
    ds_val_encoded = val.batch(BertBaseUnCaseV2.BATCH_SIZE)

    print('training {}/{} ...'.format(fold_count, BertBaseUnCaseV2.FOLD_NUM))

    model_ver = BuildModels(BertBaseUnCaseV2.m_ver, model_type)
    model = model_ver.build_model_1(verbose=BertBaseUnCaseV2.SHOW_MODE)
    train_model(model, ds_train_encoded, ds_val_encoded, test_count,
                model_type, checkpoint_path)

    # Release graph/model memory before the next fold builds a new model.
    K.clear_session()
    del model, model_ver
    gc.collect()
コード例 #3
0
ファイル: run_predict.py プロジェクト: sun510001/RQ_Chatbot
def predict_result(input, model_type):
    """Return the index of the input row with the highest class-1 score.

    Encodes ``input`` with the tokenizer matching ``model_type``. It then
    builds the model via ``BuildModels(PredictContextV2.M_VER, model_type)``
    and loads weights from
    ``'{checkpoint_path}{M_VER}{MODEL_COUNT}.h5'``. Finally it predicts all
    rows and picks the row whose column-1 score is largest.

    Args:
        input: Candidate rows accepted by
            ``encode_examples_bert``/``encode_examples_roberta``
            (presumably a pandas DataFrame — TODO confirm).
        model_type: Transformer name. Must appear in
            ``TrainModelConfigV2.BERT_LIST`` or
            ``TrainModelConfigV2.ROBERTA_LIST``.

    Returns:
        Integer position (numpy integer) of the row in ``input`` with the
        largest score in output column 1.

    Raises:
        ValueError: if ``model_type`` is in neither supported list.
    """
    if model_type in TrainModelConfigV2.BERT_LIST:
        tokenizer = BertTokenizer.from_pretrained(PredictContextV2.TRANS_PATH,
                                                  do_lower_case=True)
        encode = encode_examples_bert
    elif model_type in TrainModelConfigV2.ROBERTA_LIST:
        # A BertTokenizer cannot load a RoBERTa (BPE) vocabulary. Use the
        # same RobertaTokenizer arguments the V2 training code uses so
        # inference tokenization matches training.
        tokenizer = RobertaTokenizer.from_pretrained(PredictContextV2.TRANS_PATH,
                                                     lowercase=True,
                                                     add_prefix_space=True)
        encode = encode_examples_roberta
    else:
        raise ValueError('unsupported model_type: {!r}'.format(model_type))

    x = encode(input, tokenizer, model_type, train=False).batch(1)

    model_ver = BuildModels(PredictContextV2.M_VER, model_type)
    model2 = model_ver.build_model_1()
    model2.load_weights('{}{}{}.h5'.format(PredictContextV2.checkpoint_path,
                                           PredictContextV2.M_VER,
                                           PredictContextV2.MODEL_COUNT))
    test_result = model2.predict(x, verbose=BertBaseUnCaseV2.SHOW_MODE_NUM)
    # Row (not class) arg-max: which input scores highest for class 1.
    return test_result[:, 1].argmax()
コード例 #4
0
                df_train_load = triple_label_v3(pd.read_csv(BertBaseUnCaseV2.PATH), model_type)
            else:
                df_train_load = triple_label_v1(pd.read_csv(BertBaseUnCaseV2.PATH), model_type)

            # values_count = df_train_load['tri_label'].value_counts()
            # print("Literal, rq, sarcasm count:", values_count)

            bert_init = BertBaseUnCaseV2(model_type, BertBaseUnCaseV2.VER)
            model_name, trans_path, version, checkpoint_path = bert_init.model_init()
            train_data, test_data = split_dataset(df_train_load)

            # get tokenizer
            if model_type in TrainModelConfigV2.BERT_LIST:
                tokenizer = BertTokenizer.from_pretrained("{}{}".format(PATH_TRANS_INPUT, model_type),
                                                          do_lower_case=True)
                test, test_label = encode_examples_bert(test_data, tokenizer, model_type)
            elif model_type in TrainModelConfigV2.ROBERTA_LIST:
                tokenizer = RobertaTokenizer.from_pretrained("{}{}".format(PATH_TRANS_INPUT, model_type),
                                                             lowercase=True, add_prefix_space=True)
                test, test_label = encode_examples_roberta(test_data, tokenizer, model_type)

            ds_test_encoded = test.batch(BertBaseUnCaseV2.BATCH_SIZE)

            print("Starting training ... ")

            for each in TrainModelConfigV2.MODEL_LIST:
                test_result = np.zeros((len(test_label), BertBaseUnCaseV2.N_CLASS))
                BertBaseUnCaseV2.m_ver = each
                print(model_type, BertBaseUnCaseV2.m_ver)
                # processs = []
                if is_train == 0: