def main(my_arg):
    """Train the NER model for run `my_arg` and keep the best checkpoint."""
    log_dir = 'ner_logs' + str(my_arg)
    logger = Logger(log_dir)
    emb = LoadEmbedding('res/embedding.txt')
    if config['label_emb'] or config['question_alone']:
        onto_emb = LoadEmbedding('res/onto_embedding.txt')
    print('finish loading embedding')

    # Build one training batch getter per entity type, then merge them.
    # batch_getter = BatchGetter('data/train', 'GPE_NAM', config['batch_size'])
    batch_getter_lst = []
    if config['bioes']:
        if config['data'] == 'conll':
            pernam_batch_getter = ConllBatchGetter(
                'data/conll2003/bioes_eng.train', 'PER', 1, True)
            batch_getter_lst.append(pernam_batch_getter)
            loc_batch_getter = ConllBatchGetter(
                'data/conll2003/bioes_eng.train', 'LOC', 1, True)
            batch_getter_lst.append(loc_batch_getter)
            if not config['drop_misc']:
                misc_batch_getter = ConllBatchGetter(
                    'data/conll2003/bioes_eng.train', 'MISC', 1, True)
                batch_getter_lst.append(misc_batch_getter)
            org_batch_getter = ConllBatchGetter(
                'data/conll2003/bioes_eng.train', 'ORG', 1, True)
            batch_getter_lst.append(org_batch_getter)
        elif config['data'] == 'OntoNotes':
            # onto_notes_data = TrainOntoNotesGetter('data/OntoNotes/train.json', 1, True)
            onto_notes_data = OntoNotesGetter(
                'data/OntoNotes/leaf_train.json',
                ['/person', '/organization', '/location', '/other'], 1, True)
            batch_getter_lst.append(onto_notes_data)
    else:
        pernam_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.train',
                                               'PER', 1, True)
        batch_getter_lst.append(pernam_batch_getter)
        loc_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.train',
                                            'LOC', 1, True)
        batch_getter_lst.append(loc_batch_getter)
        if not config['drop_misc']:
            misc_batch_getter = ConllBatchGetter(
                'data/conll2003/bio_eng.train', 'MISC', 1, True)
            batch_getter_lst.append(misc_batch_getter)
        org_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.train',
                                            'ORG', 1, True)
        batch_getter_lst.append(org_batch_getter)
    batch_getter = MergeBatchGetter(batch_getter_lst, config['batch_size'],
                                    True, data_name=config['data'])
    print('finish loading train data')

    # if config['data'] == 'OntoNotes':
    #     emb_onto = True
    # else:
    #     emb_onto = False
    embedding_layer = EmbeddingLayer(emb)
    if config['label_emb']:
        q_word_embedding = nn.Embedding(onto_emb.get_voc_size(),
                                        onto_emb.get_emb_size())
        q_word_embedding.weight.data.copy_(onto_emb.get_embedding_tensor())
        q_word_embedding.weight.requires_grad = False
    else:
        q_word_embedding = None
    d = config['hidden_size']
    if config['question_alone']:
        q_emb_layer = QLabel(onto_emb)
    else:
        q_emb_layer = None
    att_layer = AttentionFlowLayer(2 * d)
    model_layer = ModelingLayer(8 * d, d, 2)
    ner_hw_layer = NerHighway(2 * d, 8 * d, 1)
    ner_out_layer = NerOutLayer(10 * d, len(config['Tags']))
    crf = CRF(config, config['Tags'], 10 * d)

    if config['USE_CUDA']:
        att_layer.cuda(config['cuda_num'])
        embedding_layer.cuda(config['cuda_num'])
        if config['label_emb']:
            q_word_embedding.cuda(config['cuda_num'])
        model_layer.cuda(config['cuda_num'])
        ner_hw_layer.cuda(config['cuda_num'])
        ner_out_layer.cuda(config['cuda_num'])
        crf.cuda(config['cuda_num'])
        if config['question_alone']:
            q_emb_layer.cuda(config['cuda_num'])

    squad_model_dir = 'mr_model1'
    if not config['not_pretrain']:
        # Optionally warm-start the shared layers from the pretrained machine-reading model.
        att_layer.load_state_dict(
            torch.load(squad_model_dir + '/early_att_layer.pkl',
                       map_location=lambda storage, loc: storage))
        model_layer.load_state_dict(
            torch.load(squad_model_dir + '/early_model_layer.pkl',
                       map_location=lambda storage, loc: storage))
        embedding_layer.load_state_dict(
            torch.load(squad_model_dir + '/early_embedding_layer.pkl',
                       map_location=lambda storage, loc: storage))

    if config['freeze']:
        # Freeze the pretrained layers: no gradients and no optimizers for them.
        for param in att_layer.parameters():
            param.requires_grad = False
        for param in model_layer.parameters():
            param.requires_grad = False
        for param in embedding_layer.parameters():
            param.requires_grad = False
        embedding_layer.eval()
        model_layer.eval()
        att_layer.eval()
        emb_opt = None
        att_opt = None
        model_opt = None
    else:
        if config['not_pretrain']:
            emb_opt = torch.optim.Adam(
                filter(lambda param: param.requires_grad,
                       embedding_layer.parameters()))
            att_opt = torch.optim.Adam(
                filter(lambda param: param.requires_grad,
                       att_layer.parameters()))
            model_opt = torch.optim.Adam(
                filter(lambda param: param.requires_grad,
                       model_layer.parameters()))
        else:
            # Fine-tune the pretrained layers with a smaller learning rate.
            emb_opt = torch.optim.Adam(
                filter(lambda param: param.requires_grad,
                       embedding_layer.parameters()), lr=1e-4)
            att_opt = torch.optim.Adam(
                filter(lambda param: param.requires_grad,
                       att_layer.parameters()), lr=1e-4)
            model_opt = torch.optim.Adam(
                filter(lambda param: param.requires_grad,
                       model_layer.parameters()), lr=1e-4)
            # model_opt = torch.optim.Adam(filter(lambda param: param.requires_grad, model_layer.parameters()))
    ner_hw_opt = torch.optim.Adam(
        filter(lambda param: param.requires_grad, ner_hw_layer.parameters()))
    ner_out_opt = torch.optim.Adam(
        filter(lambda param: param.requires_grad, ner_out_layer.parameters()))
    crf_opt = torch.optim.Adam(
        filter(lambda param: param.requires_grad, crf.parameters()))
    if config['question_alone']:
        q_emb_opt = torch.optim.Adam(
            filter(lambda param: param.requires_grad,
                   q_emb_layer.parameters()))
    else:
        q_emb_opt = None

    log_file = open('{}/log_file'.format(log_dir), 'w')
    f_max = 0
    low_epoch = 0
    ex_iterations = 0
    model_dir = 'ner_model' + str(my_arg)
    time0 = time.time()
    for epoch in range(config['max_epoch']):
        embedding_layer.train()
        att_layer.train()
        model_layer.train()
        ner_hw_layer.train()
        ner_out_layer.train()
        crf.train()
        if config['question_alone']:
            q_emb_layer.train()
        # f, p, r = evaluate_all(my_arg, False)
        for iteration, this_batch in enumerate(batch_getter):
            if (ex_iterations + iteration) % 100 == 0:
                print('epoch: {}, iteration: {}'.format(
                    epoch, ex_iterations + iteration))
            train_iteration(logger, ex_iterations + iteration, embedding_layer,
                            q_word_embedding, q_emb_layer, att_layer,
                            model_layer, ner_hw_layer, ner_out_layer, crf,
                            emb_opt, q_emb_opt, att_opt, model_opt, ner_hw_opt,
                            ner_out_opt, crf_opt, this_batch)
            if (ex_iterations + iteration) % 100 == 0:
                time1 = time.time()
                print('this iteration time: ', time1 - time0, '\n')
                time0 = time1
            if (ex_iterations + iteration) % config['save_freq'] == 0:
                torch.save(embedding_layer.state_dict(),
                           model_dir + '/embedding_layer.pkl')
                torch.save(att_layer.state_dict(), model_dir + '/att_layer.pkl')
                torch.save(model_layer.state_dict(),
                           model_dir + '/model_layer.pkl')
                torch.save(ner_hw_layer.state_dict(),
                           model_dir + '/ner_hw_layer.pkl')
                torch.save(ner_out_layer.state_dict(),
                           model_dir + '/ner_out_layer.pkl')
                torch.save(crf.state_dict(), model_dir + '/crf.pkl')
                if config['question_alone']:
                    torch.save(q_emb_layer.state_dict(),
                               model_dir + '/q_emb_layer.pkl')

        ex_iterations += iteration + 1
        batch_getter.reset()

        # Evaluate on the dev set after each epoch with dropout disabled.
        config['use_dropout'] = False
        f, p, r = evaluate_all(my_arg, False)
        config['use_dropout'] = True
        log_file.write('epoch: {} f: {} p: {} r: {}\n'.format(epoch, f, p, r))
        log_file.flush()
        if f >= f_max:
            f_max = f
            low_epoch = 0
            os.system('cp {}/embedding_layer.pkl {}/early_embedding_layer.pkl'.
                      format(model_dir, model_dir))
            os.system('cp {}/att_layer.pkl {}/early_att_layer.pkl'.format(
                model_dir, model_dir))
            os.system('cp {}/model_layer.pkl {}/early_model_layer.pkl'.format(
                model_dir, model_dir))
            os.system(
                'cp {}/ner_hw_layer.pkl {}/early_ner_hw_layer.pkl'.format(
                    model_dir, model_dir))
            os.system(
                'cp {}/ner_out_layer.pkl {}/early_ner_out_layer.pkl'.format(
                    model_dir, model_dir))
            os.system('cp {}/crf.pkl {}/early_crf.pkl'.format(
                model_dir, model_dir))
            if config['question_alone']:
                os.system(
                    'cp {}/q_emb_layer.pkl {}/early_q_emb_layer.pkl'.format(
                        model_dir, model_dir))
        else:
            low_epoch += 1
            log_file.write('low' + str(low_epoch) + '\n')
            log_file.flush()
        if low_epoch >= config['early_stop']:
            break
    log_file.close()
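
# --- Illustrative helper (not called by the original script) -----------------
# The best-checkpoint copies in main() shell out to `cp` via os.system, which
# only works on POSIX systems. Below is a minimal sketch of a portable
# alternative, assuming the same <name>.pkl / early_<name>.pkl naming scheme;
# the helper name and its argument list are hypothetical.
import os
import shutil


def copy_best_checkpoints(model_dir,
                          names=('embedding_layer', 'att_layer', 'model_layer',
                                 'ner_hw_layer', 'ner_out_layer', 'crf',
                                 'q_emb_layer')):
    """Copy <name>.pkl to early_<name>.pkl inside model_dir (sketch only)."""
    for name in names:
        src = os.path.join(model_dir, '{}.pkl'.format(name))
        dst = os.path.join(model_dir, 'early_{}.pkl'.format(name))
        if os.path.exists(src):
            shutil.copyfile(src, dst)
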

def free_evaluate_all(my_arg, pr=True):
    """Evaluate a saved model (in ner_model8) on the data set selected by my_arg."""
    emb = LoadEmbedding('res/emb.txt')
    if config['label_emb'] or config['question_alone']:
        onto_emb = LoadEmbedding('res/onto_embedding.txt')
    print('finish loading embedding')

    batch_getter_lst = []
    if my_arg == 0:
        # CoNLL-2003 dev set (testa), MISC only.
        # pernam_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.testa', 'PER', 1, False)
        # batch_getter_lst.append(pernam_batch_getter)
        # loc_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.testa', 'LOC', 1, False)
        # batch_getter_lst.append(loc_batch_getter)
        misc_batch_getter = ConllBatchGetter(
            'data/conll2003/bio_eng.testa', 'MISC', 1, False)
        batch_getter_lst.append(misc_batch_getter)
        # org_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.testa', 'ORG', 1, False)
        # batch_getter_lst.append(org_batch_getter)
    if my_arg == 1:
        # CoNLL-2003 test set (testb), BIO tags.
        # pernam_batch_getter = ConllBatchGetter('data/ttt', 'PER', 1, False)
        # batch_getter_lst.append(pernam_batch_getter)
        # pernam_batch_getter = ConllBatchGetter('data/ttt', 'singer', 1, False)
        # batch_getter_lst.append(pernam_batch_getter)
        pernam_batch_getter = ConllBatchGetter(
            'data/conll2003/bio_eng.testb', 'PER', 1, False)
        batch_getter_lst.append(pernam_batch_getter)
        loc_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.testb',
                                            'LOC', 1, False)
        batch_getter_lst.append(loc_batch_getter)
        misc_batch_getter = ConllBatchGetter(
            'data/conll2003/bio_eng.testb', 'MISC', 1, False)
        batch_getter_lst.append(misc_batch_getter)
        org_batch_getter = ConllBatchGetter('data/conll2003/bio_eng.testb',
                                            'ORG', 1, False)
        batch_getter_lst.append(org_batch_getter)
    if my_arg == 2:
        # CoNLL-2003 test set (testb), BIOES tags.
        # pernam_batch_getter = ConllBatchGetter('data/conll2003/bioes_eng.testb', 'food', 1, False)
        # batch_getter_lst.append(pernam_batch_getter)
        pernam_batch_getter = ConllBatchGetter(
            'data/conll2003/bioes_eng.testb', 'PER', 1, False)
        batch_getter_lst.append(pernam_batch_getter)
        loc_batch_getter = ConllBatchGetter(
            'data/conll2003/bioes_eng.testb', 'LOC', 1, False)
        batch_getter_lst.append(loc_batch_getter)
        misc_batch_getter = ConllBatchGetter(
            'data/conll2003/bioes_eng.testb', 'MISC', 1, False)
        batch_getter_lst.append(misc_batch_getter)
        org_batch_getter = ConllBatchGetter(
            'data/conll2003/bioes_eng.testb', 'ORG', 1, False)
        batch_getter_lst.append(org_batch_getter)
    if my_arg == 3:
        # OntoNotes test set.
        # onto_notes = OntoNotesGetter('data/OntoNotes/test.json', '/person', 1, False)
        # batch_getter_lst.append(onto_notes)
        onto_notes_data = OntoNotesGetter('data/OntoNotes/test.json',
                                          utils.get_ontoNotes_type_lst(), 1,
                                          False)
        batch_getter_lst.append(onto_notes_data)

    batch_size = 100
    batch_getter = MergeBatchGetter(batch_getter_lst, batch_size, False,
                                    data_name=config['data'])
    print('finish loading dev data')

    # if config['data'] == 'OntoNotes':
    #     emb_onto = True
    # else:
    #     emb_onto = False
    embedding_layer = EmbeddingLayer(emb, 0)
    if config['label_emb']:
        q_word_embedding = nn.Embedding(onto_emb.get_voc_size(),
                                        onto_emb.get_emb_size())
        q_word_embedding.weight.data.copy_(onto_emb.get_embedding_tensor())
        q_word_embedding.weight.requires_grad = False
    else:
        q_word_embedding = None
    d = config['hidden_size']
    if config['question_alone']:
        q_emb_layer = QLabel(onto_emb, 0)
    else:
        q_emb_layer = None
    att_layer = AttentionFlowLayer(2 * d)
    model_layer = ModelingLayer(8 * d, d, 2, 0)
    ner_hw_layer = NerHighway(2 * d, 8 * d, 1)
    ner_out_layer = NerOutLayer(10 * d, len(config['Tags']), 0)
    # NOTE: training builds the CRF with feature size 10 * d; if loading the
    # saved state dict fails here, this argument likely needs to match.
    crf = CRF(config, config['Tags'], len(config['Tags']))

    if config['USE_CUDA']:
        att_layer.cuda(config['cuda_num'])
        embedding_layer.cuda(config['cuda_num'])
        if config['label_emb']:
            q_word_embedding.cuda(config['cuda_num'])
        model_layer.cuda(config['cuda_num'])
        ner_hw_layer.cuda(config['cuda_num'])
        ner_out_layer.cuda(config['cuda_num'])
        crf.cuda(config['cuda_num'])
        if config['question_alone']:
            q_emb_layer.cuda(config['cuda_num'])

    model_dir = 'ner_model8'
    att_layer.load_state_dict(
        torch.load(model_dir + '/early_att_layer.pkl',
                   map_location=lambda storage, loc: storage))
    model_layer.load_state_dict(
        torch.load(model_dir + '/early_model_layer.pkl',
                   map_location=lambda storage, loc: storage))
    ner_hw_layer.load_state_dict(
        torch.load(model_dir + '/early_ner_hw_layer.pkl',
                   map_location=lambda storage, loc: storage))
    ner_out_layer.load_state_dict(
        torch.load(model_dir + '/early_ner_out_layer.pkl',
                   map_location=lambda storage, loc: storage))
    crf.load_state_dict(
        torch.load(model_dir + '/early_crf.pkl',
                   map_location=lambda storage, loc: storage))
    embedding_layer.load_state_dict(
        torch.load(model_dir + '/early_embedding_layer.pkl',
                   map_location=lambda storage, loc: storage))
    if config['question_alone']:
        q_emb_layer.load_state_dict(
            torch.load(model_dir + '/q_emb_layer.pkl',
                       map_location=lambda storage, loc: storage))
    else:
        q_emb_layer = None

    if config['question_alone']:
        q_emb_layer.eval()
    embedding_layer.eval()
    att_layer.eval()
    model_layer.eval()
    ner_hw_layer.eval()
    ner_out_layer.eval()
    crf.eval()

    ner_tag = Vocab('res/ner_xx',
                    unk_id=config['UNK_token'],
                    pad_id=config['PAD_token'])
    if my_arg == 3:
        evaluator = ConllBoundaryPerformance(ner_tag, onto_notes_data)
    else:
        evaluator = ConllBoundaryPerformance(ner_tag)
    evaluator.reset()
    out_file = codecs.open('data/eva_result' + str(my_arg),
                           mode='wb',
                           encoding='utf-8')
    # writer.add_embedding(embedding_layer.word_embedding.weight.data.cpu())
    # return
    all_emb = None
    all_metadata = []
    ex_iterations = 0
    summary_emb = False
    for iteration, this_batch in enumerate(batch_getter):
        # if iteration >= 15:
        #     break
        if summary_emb:
            top_path, all_emb, all_metadata, q = evaluate_one(
                ex_iterations + iteration, embedding_layer, q_word_embedding,
                q_emb_layer, att_layer, model_layer, ner_hw_layer,
                ner_out_layer, crf, this_batch, summary_emb, all_emb,
                all_metadata)
        else:
            top_path = evaluate_one(ex_iterations + iteration, embedding_layer,
                                    q_word_embedding, q_emb_layer, att_layer,
                                    model_layer, ner_hw_layer, ner_out_layer,
                                    crf, this_batch)
        for batch_no, path in enumerate(top_path):
            evaluator.evaluate(
                iteration * batch_size + batch_no,
                remove_end_tag(this_batch[1].numpy()[batch_no, :].tolist()),
                path, out_file, pr)
        if (iteration + 1) * batch_size % 100 == 0:
            print('{} sentences processed'.format(
                (iteration + 1) * batch_size))
            evaluator.get_performance()

    if summary_emb:
        writer.add_embedding(torch.cat([q, all_emb], 0),
                             metadata=['question'] + all_metadata)
    return evaluator.get_performance()
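
# --- Hedged usage sketch ------------------------------------------------------
# How these two entry points might be driven from the command line. The
# argument name `run_id` and the `--eval` flag are illustrative assumptions,
# not part of the original script.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('run_id', type=int,
                        help='run id: suffix of ner_logs<id> / ner_model<id> '
                             'for training, data-set selector for evaluation')
    parser.add_argument('--eval', action='store_true',
                        help='run free_evaluate_all() instead of training')
    args = parser.parse_args()

    if args.eval:
        free_evaluate_all(args.run_id)
    else:
        main(args.run_id)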