コード例 #1
0
def train():
    """Train the model at index 2 of ``available_models`` on dataset index 0.

    Builds the model through ``ModelFactory``, takes its hyperparameters from
    ``net_conf`` unchanged, resolves the dataset parameters via ``params`` and
    hands everything to ``tools.train_model``.
    """
    model_key = available_models[2]
    model = ModelFactory.make_model(model_key)
    model_hparams = net_conf.get_hyperparams(model_key)
    tools.train_model(model,
                      model_hparams,
                      params.get_dataset_params(available_datasets[0]))
コード例 #2
0
def main():
    """Train the model at index 2 of ``available_models`` on dataset index 2.

    Identical to a default training run except that the batch size is forced
    to 1 before training starts.
    """
    model_key = available_models[2]
    model = ModelFactory.make_model(model_key)
    model_hparams = net_conf.get_hyperparams(model_key)
    # Override the configured batch size: one sample per batch.
    model_hparams.batch_size = 1
    tools.train_model(model,
                      model_hparams,
                      params.get_dataset_params(available_datasets[2]))
コード例 #3
0
def apply(model_url=''):
    """Load a saved model and run its generator-based evaluation.

    Uses the model at index 1 of ``available_models`` and dataset index 0 of
    ``available_datasets``.

    Args:
        model_url: Location handed to ``text_match_model.load``. Previously
            this was hard-coded to ``''`` inside the function; it is now a
            parameter so callers can point at a real saved model. The default
            stays ``''`` for backward compatibility with existing callers.
    """
    model_name = available_models[1]
    text_match_model = ModelFactory.make_model(model_name)
    # NOTE(review): an empty model_url is a placeholder; what load('') does
    # depends on the model implementation (not visible here) — pass a real
    # location when calling.
    text_match_model.load(model_url)
    hyperparams = net_conf.get_hyperparams(model_name)
    dataset_name = available_datasets[0]
    dataset_params = params.get_dataset_params(dataset_name)

    # Wire hyperparameters and dataset into the loaded model, then evaluate.
    text_match_model.setup(hyperparams, dataset_params)
    text_match_model.evaluate_generator()
コード例 #4
0
def tune_enc_layer_num_TEBLDModel():
    """Sweep ``layers_num`` over 1..6 for the model at ``available_models[2]``.

    Each candidate gets a freshly built model and freshly fetched
    hyperparameters, so runs do not share state through those objects.
    Every run trains on dataset index 0.
    """
    model_key = available_models[2]
    banner = '============ ' + model_name_abbr_full[model_key] + ' tune enc layer num ============'
    print(banner)
    for n_layers in range(1, 7):
        model = ModelFactory.make_model(model_key)
        hp = net_conf.get_hyperparams(model_key)
        hp.layers_num = n_layers
        tools.train_model(model, hp, params.get_dataset_params(available_datasets[0]))
コード例 #5
0
def tune_dropout_rate_REBLDModel():
    """Sweep a shared dropout rate for the model at ``available_models[3]``.

    The same rate (0 to 0.6 in steps of 0.1) is applied to both the LSTM and
    the dense dropout hyperparameters. Each value triggers a full training
    run on dataset index 0.
    """
    model_key = available_models[3]
    banner = '============ ' + model_name_abbr_full[model_key] + ' tune dropout rate ============'
    print(banner)
    for rate in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]:
        model = ModelFactory.make_model(model_key)
        hp = net_conf.get_hyperparams(model_key)
        # LSTM dropout is assigned first, then dense dropout.
        hp.lstm_p_dropout = hp.dense_p_dropout = rate
        tools.train_model(model, hp, params.get_dataset_params(available_datasets[0]))
コード例 #6
0
def tune_layer_num_SBLDModel():
    """Sweep ``bilstm_retseq_layer_num`` over 0..3 for ``available_models[1]``.

    The original author notes this applies to RNMTPlusEncoderBiLSTMDenseModel
    or StackedBiLSTMDenseModel; which one index 1 refers to is configured
    elsewhere (TODO confirm). Each value triggers a training run on dataset
    index 0.
    """
    model_key = available_models[1]
    banner = '============ ' + model_name_abbr_full[model_key] + ' tune layer num ============'
    print(banner)
    for n_retseq_layers in range(4):
        model = ModelFactory.make_model(model_key)
        hp = net_conf.get_hyperparams(model_key)
        hp.bilstm_retseq_layer_num = n_retseq_layers
        tools.train_model(model, hp, params.get_dataset_params(available_datasets[0]))
コード例 #7
0
def tune_state_dim_SBLDModel():
    """Sweep the LSTM hidden ``state_dim`` (100..700, step 100).

    Targets the model at ``available_models[1]``; per the original note that
    is RNMTPlusEncoderBiLSTMDenseModel or StackedBiLSTMDenseModel (TODO
    confirm). The hidden state size should be chosen relative to the word
    embedding size, since too small a state loses information. Each value
    triggers a training run on dataset index 0.
    """
    model_key = available_models[1]
    banner = '============ ' + model_name_abbr_full[model_key] + ' tune hidden state dim num ============'
    print(banner)
    for dim in range(100, 800, 100):
        model = ModelFactory.make_model(model_key)
        hp = net_conf.get_hyperparams(model_key)
        hp.state_dim = dim
        tools.train_model(model, hp, params.get_dataset_params(available_datasets[0]))
コード例 #8
0
def tune_dropout_rate_SBLDModel():
    """Grid-search LSTM and dense dropout rates for ``available_models[1]``.

    Earlier experiments suggested LSTM dropout >= 0.5 and dense dropout in
    [0, 0.2]; larger rates discard too much information. The grid is swept
    with the LSTM rate as the outer axis, giving 3 x 3 = 9 training runs on
    dataset index 0.
    """
    model_key = available_models[1]
    banner = '============ ' + model_name_abbr_full[model_key] + ' tune dropout rate ============'
    print(banner)
    grid = [(lstm_rate, dense_rate)
            for lstm_rate in [0.5, 0.6, 0.7]
            for dense_rate in [0, 0.1, 0.2]]
    for lstm_rate, dense_rate in grid:
        model = ModelFactory.make_model(model_key)
        hp = net_conf.get_hyperparams(model_key)
        hp.lstm_p_dropout = lstm_rate
        hp.dense_p_dropout = dense_rate
        tools.train_model(model, hp, params.get_dataset_params(available_datasets[0]))
コード例 #9
0
def tune_l2_lambda_SBLDModel():
    """Grid-search the four L2 regularisation lambdas for ``available_models[1]``.

    Sweeps kernel x recurrent x bias x activity lambdas (2 x 2 x 2 x 3 = 24
    runs). The kernel lambda varies slowest and the activity lambda fastest.
    Every combination gets a fresh model and hyperparameter object and trains
    on dataset index 0.
    """
    model_key = available_models[1]
    banner = '============ ' + model_name_abbr_full[model_key] + ' tune l2 lambda ============'
    print(banner)
    grid = [(k, r, b, a)
            for k in [1e-5, 1e-4]
            for r in [1e-5, 1e-4]
            for b in [1e-5, 1e-4]
            for a in [0, 1e-5, 1e-4]]
    for k, r, b, a in grid:
        model = ModelFactory.make_model(model_key)
        hp = net_conf.get_hyperparams(model_key)
        hp.kernel_l2_lambda = k
        hp.recurrent_l2_lambda = r
        hp.bias_l2_lambda = b
        hp.activity_l2_lambda = a
        tools.train_model(model, hp, params.get_dataset_params(available_datasets[0]))
コード例 #10
0
        for in_out_pair in generate_in_out_pair_file(fname, tokenizer):
            # 每次生成一个批的数据,每次返回固定相同数目的样本
            if batch_samples_count < batch_size - 1:
                in_out_pairs.append(in_out_pair)
                batch_samples_count += 1
            else:
                in_out_pairs.append(in_out_pair)
                X, y = process_format_model_in(in_out_pairs, max_len, batch_size, pad, cut)
                yield X, y
                in_out_pairs = list()
                batch_samples_count = 0


if __name__ == '__main__':
    # Manual smoke tests against dataset index 0 (referred to as cikm_en here).
    cikm_en = params.get_dataset_params(available_datasets[0])

    # ---- _load_vectors(): dump the first 50 pretrained word vectors ----
    needed_word2vec = _load_vectors(cikm_en.fastText_en_pretrained_wiki_word_vecs_url,
                                    head_n=50)
    for token, components in needed_word2vec.items():
        # One line per token: the token, then each component, space separated,
        # with a trailing space before the newline.
        print(token, *components, end=' \n')

    # ---- _read_words(): list distinct words of the processed train file ----
    all_distinct_words = _read_words(cikm_en.processed_en_train_url)
    print(''.join(str(token) + ' | ' for token in all_distinct_words), end='')
    print('\ntotal distinct words number:', len(all_distinct_words))
コード例 #11
0
                elif label == 0:
                    negative_count += 1
                    total_count += 1

    tmp = fname.split(os.path.sep)
    fname = tmp[len(tmp) - 1]
    print('==========', fname, '==========')
    print('total count:', total_count)
    print('positive count:', positive_count)
    print('negative count:', negative_count)


if __name__ == '__main__':
    # Disabled example for the symbol-removal helper (the header of the
    # original snippet named it _remove_symbols, while the call below uses
    # remove_symbols — check which name is current before re-enabling):
    # str1 = "'Random Number' is what I don't like at all."
    # str1 = "I don't like 'Random Number'."
    # str1 = "I don't like 'Random Number' at all"
    # print(remove_symbols(str1, params.MATCH_SINGLE_QUOTE_STR))

    # Print data_statistic() output for each split of dataset index 0, in the
    # order raw -> train -> val -> test. Each attribute is read right before
    # its data_statistic() call.
    cikm_en = params.get_dataset_params(available_datasets[0])
    for url_attr in ('raw_url', 'train_url', 'val_url', 'test_url'):
        fname = getattr(cikm_en, url_attr)
        data_statistic(fname)