import os
import sys
import time

import torch

# Project-local helpers; the module paths below are assumptions based on the
# names used in this script and may differ in the actual repository layout.
import conlleval
from data_utils import (data_generator, getNERdata, readTokenEmbeddings,
                        getCharIdx, ExtractLabelsFromTokens)
from logger import Logger
from model import Bilstm_LabelEmbedding


def evaluate(model, data, batch_size, log):
    """Run `model` over `data` in eval mode and score with conlleval."""
    was_training = model.training
    model.eval()

    pred = []
    gold = []
    with torch.no_grad():
        for pa in data_generator(data, batch_size):
            x, y, slot = pa[0], pa[1], pa[2]
            _x, _y, p = model(x, y, slot, 'test')
            gold += _y
            pred += p

    if was_training:
        model.train()

    # Flatten per-sentence label sequences into flat tag lists for scoring.
    _gold = [tag for sent in gold for tag in sent]
    _pred = [tag for sent in pred for tag in sent]
    return conlleval.evaluate(_gold, _pred, log, verbose=True)
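
# `data_generator` is imported from elsewhere in the repo. As a minimal
# sketch of the contract that `evaluate` above and `train` below rely on
# (column-wise mini-batches such as ([x...], [y...], [slot...])), it might
# look like the following; the real implementation may well differ (e.g. it
# may sort by length or pad sequences).
def _data_generator_sketch(data, batch_size, shuffle=True):
    """Yield mini-batches of `data` as tuples of parallel column lists."""
    import random
    order = list(range(len(data)))
    if shuffle:
        random.shuffle(order)
    for start in range(0, len(order), batch_size):
        rows = [data[i] for i in order[start:start + batch_size]]
        # Transpose row tuples (x, y, slot) into column lists.
        yield tuple(map(list, zip(*rows)))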
def train(config):
    dataDict = getNERdata(dataSetName=config.dataset,
                          dataDir=config.data_dir,
                          desc_path=config.description_path,
                          cross_domain=config.cross_domain,
                          exemplar_num=config.exemplar_num,
                          target_domain=config.target_domain)
    emb, word2Idx = readTokenEmbeddings(config.embed_file)
    char2Idx = getCharIdx()

    # The source domain defines the training label set; dev and test use the
    # target-domain label sets, which can differ under cross-domain transfer.
    label2Idx = ExtractLabelsFromTokens(dataDict['source']['train'])
    label2IdxForDev = ExtractLabelsFromTokens(dataDict['target']['dev'])
    label2IdxForTest = ExtractLabelsFromTokens(dataDict['target']['test'])

    # Pre-build label embeddings for the dev/test label sets so they can be
    # swapped in at evaluation time.
    DevLabelEmbedding = Bilstm_LabelEmbedding.BuildLabelEmbedding(
        emb, word2Idx, label2IdxForDev, dataDict['description'],
        dataDict['exemplar_dev'], config.embedding_method,
        config.encoder_method, config.device)
    TestLabelEmbedding = Bilstm_LabelEmbedding.BuildLabelEmbedding(
        emb, word2Idx, label2IdxForTest, dataDict['description'],
        dataDict['exemplar_test'], config.embedding_method,
        config.encoder_method, config.device)

    model = Bilstm_LabelEmbedding(config, emb, word2Idx, label2Idx, char2Idx,
                                  dataDict['description'],
                                  dataDict['exemplar_train'])
    model.train()
    model = model.to(config.device)

    hist_valid_scores = []
    patience = num_trial = 0
    train_iter = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    train_time = time.time()

    config.save_dir = config.save_dir + config.target_domain + '/'
    if not os.path.exists(config.save_dir):
        os.mkdir(config.save_dir)
    if os.path.exists(os.path.join(config.save_dir, 'params')):
        os.remove(os.path.join(config.save_dir, 'params'))
    log = Logger(os.path.join(config.save_dir, '_src.txt'), level='info')

    for epoch in range(config.epoch):
        for da in data_generator(dataDict['source']['train'],
                                 config.batch_size):
            train_iter += 1
            x, y = da[0], da[1]

            loss = model(x, y, 'train')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if train_iter % config.log_every == 0:
                log.logger.info(
                    'epoch %d, iter %d, loss %.2f, time elapsed %.2f sec' %
                    (epoch, train_iter, loss, time.time() - train_time))
                train_time = time.time()

            if train_iter % config.log_valid == 0:
                # Stash the source-domain label set, evaluate on dev and test
                # with the target-domain label embeddings, then restore it.
                trainLabelEmbedding = model.LabelEmbedding
                trainLabel2Idx = model.label2Idx

                model.label2Idx = label2IdxForDev
                model.LabelEmbedding = DevLabelEmbedding
                if config.crf:
                    model.crf.labelembedding = model.crf.buildCRFLabelEmbedding(
                        model.LabelEmbedding)
                    model.crf.num_tags = model.LabelEmbedding.size(0)
                (valid_metric_pre, valid_metric_rec, valid_metric_f1), d = \
                    evaluate(model, dataDict['target']['dev'],
                             config.batch_size, log)

                model.label2Idx = label2IdxForTest
                model.LabelEmbedding = TestLabelEmbedding
                if config.crf:
                    model.crf.labelembedding = model.crf.buildCRFLabelEmbedding(
                        model.LabelEmbedding)
                    model.crf.num_tags = model.LabelEmbedding.size(0)
                (test_metric_pre, test_metric_rec, test_metric_f1), d = \
                    evaluate(model, dataDict['target']['test'],
                             config.batch_size, log)

                model.label2Idx = trainLabel2Idx
                model.LabelEmbedding = trainLabelEmbedding
                if config.crf:
                    model.crf.labelembedding = model.crf.buildCRFLabelEmbedding(
                        model.LabelEmbedding)
                    model.crf.num_tags = model.LabelEmbedding.size(0)

                log.logger.info(
                    'val_pre : %.4f, val_rec : %.4f, val_f1 : %.4f' %
                    (valid_metric_pre, valid_metric_rec, valid_metric_f1))
                log.logger.info(
                    'test_pre : %.4f, test_rec : %.4f, test_f1 : %.4f' %
                    (test_metric_pre, test_metric_rec, test_metric_f1))

                # Model selection on dev F1.
                is_better = (len(hist_valid_scores) == 0
                             or valid_metric_f1 > max(hist_valid_scores))
                hist_valid_scores.append(valid_metric_f1)

                if is_better:
                    patience = 0
                    log.logger.info('save currently the best model to [%s]' %
                                    (config.save_dir + 'model'))
                    model.save(config.save_dir + 'model')
                    # Also save the optimizer state so training can resume.
                    torch.save(optimizer.state_dict(),
                               config.save_dir + 'optim')
                elif patience < config.patience:
                    patience += 1
                    log.logger.info('hit patience %d' % patience)

                    if patience == int(config.patience):
                        num_trial += 1
                        log.logger.info('hit #%d trial' % num_trial)
                        if num_trial == config.max_num_trial:
                            log.logger.info('early stop!')
                            sys.exit(0)

                        # Decay the learning rate and restart from the best
                        # checkpoint.
                        lr = optimizer.param_groups[0]['lr'] * config.lr_decay
                        log.logger.info(
                            'load previously best model and decay learning '
                            'rate to %f' % lr)

                        params = torch.load(
                            config.save_dir + 'model',
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(config.device)

                        log.logger.info('restore parameters of the optimizers')
                        optimizer.load_state_dict(
                            torch.load(config.save_dir + 'optim'))

                        # Apply the decayed learning rate to every parameter
                        # group.
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        # Reset patience for the new trial.
                        patience = 0
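

# A hedged example of driving `train` from the command line. The `Config`
# class and every value below are illustrative only; the real repository
# presumably builds its config elsewhere (e.g. via argparse). The field
# names mirror the attributes accessed in `train` above.
if __name__ == '__main__':

    class Config:
        dataset = 'conll2003'             # assumed dataset name
        data_dir = './data/'
        description_path = './data/descriptions.txt'
        cross_domain = True
        exemplar_num = 5
        target_domain = 'news'            # assumed domain name
        embed_file = './embeddings/glove.6B.100d.txt'
        embedding_method = 'description'  # assumed option value
        encoder_method = 'bilstm'         # assumed option value
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        batch_size = 32
        learning_rate = 1e-3
        save_dir = './checkpoints/'
        epoch = 50
        log_every = 100
        log_valid = 500
        crf = True
        patience = 5
        max_num_trial = 3
        lr_decay = 0.5

    train(Config())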