Пример #1
0
 def do_eval(dmv_model, m_model, pos, languages, language_map, epoch,
             options):
     """Parse the dev set with the current models and print evaluation results.

     Reads the files in ``options.dev`` that belong to the evaluation
     languages. Each batch is decoded with ``eisner_for_dmv.batch_parse``
     using scores from ``dmv_model.evaluate_batch_score``. The parses are
     handed to ``utils.eval_ml`` together with the path
     ``<options.output>/eval_pred<epoch + 1>_<options.sample_idx>``.

     ``m_model.forward_`` is also queried per batch when either of these
     holds:

     * sentence-level prediction is active (``options.sentence_predict``
       and ``epoch > options.non_neural_iter``);
     * ``options.language_predict`` is set.

     Its language predictions are collected per sentence. For
     sentence-level prediction, its output also fills a per-sentence
     transition table that is passed to ``evaluate_batch_score``.

     Language classification accuracy is printed only when that prediction
     actually ran and the training language ids are used
     (``not options.eval_new_language``).

     Args:
         dmv_model: DMV model; switched to eval mode here.
         m_model: neural model; switched to eval mode when
             ``options.use_neural`` is set.
         pos: POS mapping; ``len(pos.keys())`` sizes the transition table.
         languages: language name -> id mapping of the training languages.
         language_map: not used by this function (kept for interface
             compatibility).
         epoch: epoch index; ``epoch + 1`` is used in the output file name.
         options: parsed command-line options.
     """
     print("====================================")
     print('Do evaluation')
     if not options.eval_new_language:
         eval_language_set = languages.keys()
         eval_languages = languages
     else:
         eval_language_set = utils.read_language_list(options.language_path)
         eval_languages = {l: i for i, l in enumerate(eval_language_set)}
     eval_file_list = os.listdir(options.dev)
     eval_file_set = utils.get_file_set(eval_file_list, eval_language_set,
                                        False)
     eval_sentences, eval_language_map = utils.read_multiple_data(
         options.dev, eval_file_set, True)
     dmv_model.eval()
     if options.use_neural:
         m_model.eval()
     devpath = os.path.join(
         options.output,
         'eval_pred' + str(epoch + 1) + '_' + str(options.sample_idx))
     eval_data_list, _, eval_sentence_map = utils.construct_ml_pos_data(
         eval_sentences, pos, eval_languages, eval_language_map)
     eval_batch_data = utils.construct_batch_data(eval_data_list,
                                                  options.batchsize)
     parse_results = {}
     classify_results = np.zeros(len(eval_data_list))

     # Sentence-level prediction only starts after the non-neural epochs.
     # The neural predictor runs whenever either prediction mode is active.
     sentence_level = (options.sentence_predict
                       and epoch > options.non_neural_iter)
     run_predict = sentence_level or options.language_predict
     if sentence_level:
         num_pos = len(pos.keys())
         eval_trans_param = np.zeros(
             (len(eval_data_list), num_pos, num_pos, 2, options.c_valency))
     else:
         eval_trans_param = None

     for one_batch in eval_batch_data:
         eval_batch_pos = np.array([s[0] for s in one_batch])
         eval_batch_lan = np.array([s[1] for s in one_batch])
         eval_batch_sen = np.array([s[2][0] for s in one_batch])
         if run_predict:
             batch_rule_samples = dmv_model.find_predict_samples(
                 eval_batch_pos, eval_batch_lan, eval_batch_sen)
             batch_predict_data = utils.construct_ml_predict_data(
                 batch_rule_samples)
             batch_predict_sentences = batch_predict_data['sentence']
             batch_predict_pos_v = torch.LongTensor(batch_predict_data['pos'])
             batch_predict_dir_v = torch.LongTensor(batch_predict_data['dir'])
             batch_predict_cvalency_v = torch.LongTensor(
                 batch_predict_data['cvalency'])
             batch_predict_lan_v = torch.LongTensor(
                 batch_predict_data['languages'])
             batch_predict_sen_v = torch.LongTensor(
                 [eval_sentence_map[sentence_id]
                  for sentence_id in batch_predict_sentences])
             batch_predicted, batch_predicted_lan = m_model.forward_(
                 batch_predict_pos_v, batch_predict_dir_v,
                 batch_predict_cvalency_v, None, None, True, 'child',
                 batch_predict_lan_v, batch_predict_sen_v, None)
             # Remember the predicted language of every sampled sentence.
             for i, sentence_idx in enumerate(batch_predict_sentences):
                 classify_results[sentence_idx] = batch_predicted_lan[i]
             # The per-sentence table only exists for sentence-level
             # prediction. With language_predict alone (or during the
             # non-neural epochs) it is None and must not be written to.
             if eval_trans_param is not None:
                 eval_trans_param[
                     np.array(batch_predict_sentences),
                     np.array(batch_predict_data['pos']), :,
                     np.array(batch_predict_data['dir']),
                     np.array(batch_predict_data['cvalency'])] = (
                         batch_predicted.detach().numpy())
         batch_score, batch_decision_score = dmv_model.evaluate_batch_score(
             eval_batch_pos, eval_batch_sen, eval_language_map,
             eval_languages, eval_trans_param)
         if options.function_mask:
             batch_score = dmv_model.function_to_mask(batch_score,
                                                      eval_batch_pos)
         # Insert singleton axes (3 and 4 for scores, 2 for decisions) in the
         # layout batch_parse expects -- presumably the valency axes; confirm
         # against eisner_for_dmv.
         batch_score = np.expand_dims(np.expand_dims(batch_score, 3), 4)
         batch_decision_score = np.expand_dims(batch_decision_score, 2)
         batch_parse = eisner_for_dmv.batch_parse(batch_score,
                                                  batch_decision_score,
                                                  dmv_model.dvalency,
                                                  dmv_model.cvalency)
         for i, sentence_id in enumerate(eval_batch_sen):
             parse_results[sentence_id] = (batch_parse[0][i],
                                           batch_parse[1][i])
     utils.eval_ml(parse_results, eval_sentences, devpath,
                   options.log + '_dev' + str(options.sample_idx),
                   eval_language_map, eval_languages, epoch)
     print("====================================")
     # Accuracy is only meaningful when the predictor actually filled
     # classify_results and the gold ids come from the training languages.
     if not options.eval_new_language and run_predict:
         total = len(classify_results)
         if total == 0:
             print("Language classification accuracy: no dev sentences")
         else:
             correct = sum(
                 1 for i, predicted in enumerate(classify_results)
                 if predicted == languages[eval_language_map[i]])
             print("Language classification accuracy " +
                   str(float(correct) / total))
Пример #2
0
    # Training setup (interior of a larger function whose header is not
    # shown): read the data, save parameters, build batches and the model.
    #
    # Read the training files. Returns the POS mapping, the sentences, the
    # language name -> id mapping and a language mapping that is presumably
    # keyed by sentence id (verify against utils.read_multiple_data).
    pos, sentences, languages, language_map = utils.read_multiple_data(
        options.train, file_set, False)
    # NOTE(review): not used in the visible code -- possibly dead.
    sentence_language_map = {}
    if options.concat_all:
        # Collapse every training language into one pseudo-language 'all'
        # with id 0, and relabel every entry of language_map accordingly.
        languages = {'all': 0}
        for s in language_map.keys():
            language_map[s] = 'all'
    print 'Data read'
    # Save (pos, options) to <options.output>/<options.params>_<sample_idx>.
    # NOTE(review): the file is opened in text mode 'w'. That works with
    # Python 2's default ASCII pickle protocol, but Python 3's pickle.dump
    # requires a binary file ('wb').
    with open(
            os.path.join(options.output,
                         options.params + '_' + str(options.sample_idx)),
            'w') as paramsfp:
        pickle.dump((pos, options), paramsfp)
    print 'Parameters saved'

    # Turn the sentences into model input records plus a sentence map, then
    # group the records into update batches of options.batchsize.
    data_list, data_pos, sentence_map = utils.construct_ml_pos_data(
        sentences, pos, languages, language_map)
    batch_data = utils.construct_update_batch_data(data_list,
                                                   options.batchsize)
    print 'Batch data constructed'
    data_size = len(data_list)

    ml_dmv_model = MLDMV(pos, sentence_map, languages, language_map, data_size,
                         options)

    print 'Model constructed'

    # Initialise the model parameters from the training sentences.
    ml_dmv_model.init_param(sentences)

    print 'Parameters initialized'
    # Select the requested GPU only when one is configured (gpu >= 0) and
    # CUDA is actually available.
    if options.gpu >= 0 and torch.cuda.is_available():
        torch.cuda.set_device(options.gpu)