def test(config):
    """Run inference for ``config.data_split`` with the saved checkpoint.

    Loads the word/char embedding matrices and the eval file of the chosen
    split, restores ``<config.save>/model.pt`` into an ``SPModel`` (when
    ``config.sp_lambda > 0``) or a ``Model``, and hands a dev/test iterator
    to ``predict``, which writes to ``config.prediction_file``.

    Args:
        config: run configuration namespace (attributes read below).

    Raises:
        ValueError: if ``config.data_split`` is neither ``'dev'`` nor
            ``'test'``.
    """
    if config.data_split not in ('dev', 'test'):
        # Fail fast: previously an unknown split only surfaced as a NameError
        # on `dev_buckets`, after embeddings and the checkpoint were loaded.
        raise ValueError(
            "config.data_split must be 'dev' or 'test', got {!r}".format(
                config.data_split))

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)

    eval_path = (config.dev_eval_file if config.data_split == 'dev'
                 else config.test_eval_file)
    with open(eval_path, "r") as fh:
        dev_eval_file = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if config.cuda:
        torch.cuda.manual_seed_all(config.seed)

    if config.data_split == 'dev':
        dev_buckets = get_buckets(config.dev_record_file)
        para_limit = config.para_limit
        ques_limit = config.ques_limit
    else:
        # Test split: no paragraph/question length limit is requested.
        # NOTE(review): assumes DataIterator treats None as "no limit", as the
        # other test() variants in this file rely on — confirm.
        para_limit = None
        ques_limit = None
        dev_buckets = get_buckets(config.test_record_file)

    def build_dev_iterator():
        dev_dataset = HotpotDataset(dev_buckets)
        # Bug fix: use the split-specific limits computed above. Previously
        # config.para_limit / config.ques_limit were always passed, so the
        # per-split values (None for test) were silently ignored.
        return DataIterator(dev_dataset,
                            para_limit,
                            ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            num_workers=2)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    ori_model = model.cuda() if config.cuda else model
    # map_location returning `storage` keeps deserialized tensors on the CPU,
    # so a GPU-saved checkpoint also loads when config.cuda is off.
    ori_model.load_state_dict(
        torch.load(os.path.join(config.save, 'model.pt'),
                   map_location=lambda storage, loc: storage))
    # The model is used directly (not wrapped in nn.DataParallel).
    model = ori_model
    model.eval()
    predict(build_dev_iterator(), model, dev_eval_file, config,
            config.prediction_file)
def test(config):
    """Evaluate the checkpoint stored in ``config.save`` on one data split.

    The split is chosen by ``config.data_split`` ('dev' or 'test'). The
    model is rebuilt (``SPModel`` when ``config.sp_lambda > 0``, otherwise
    ``Model``), moved to the GPU, given the weights from
    ``<config.save>/model.pt``, wrapped in ``nn.DataParallel``, and run
    through ``predict``, which writes ``config.prediction_file``.
    """
    def read_matrix(path):
        # JSON list-of-lists -> float32 NumPy matrix.
        with open(path, "r") as handle:
            return np.array(json.load(handle), dtype=np.float32)

    word_mat = read_matrix(config.word_emb_file)
    char_mat = read_matrix(config.char_emb_file)

    is_dev = config.data_split == 'dev'
    # Any split other than 'dev' reads the test eval file.
    eval_path = config.dev_eval_file if is_dev else config.test_eval_file
    with open(eval_path, 'r') as handle:
        dev_eval_file = json.load(handle)
    with open(config.idx2word_file, 'r') as handle:
        idx2word_dict = json.load(handle)

    def logging(s, print_=True, log_=True):
        # Echo `s` to stdout and/or append it to <config.save>/log.txt.
        if print_:
            print(s)
        if not log_:
            return
        with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
            f_log.write(s + '\n')

    if is_dev:
        dev_buckets = get_buckets(config.dev_record_file)
        para_limit, ques_limit = config.para_limit, config.ques_limit
    elif config.data_split == 'test':
        # No length limits are passed for the test split.
        para_limit = ques_limit = None
        dev_buckets = get_buckets(config.test_record_file)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, para_limit,
                            ques_limit, config.char_limit, False,
                            config.sent_limit)

    model_cls = SPModel if config.sp_lambda > 0 else Model
    ori_model = model_cls(config, word_mat, char_mat).cuda()
    checkpoint = torch.load(os.path.join(config.save, 'model.pt'))
    ori_model.load_state_dict(checkpoint)

    model = nn.DataParallel(ori_model)
    model.eval()
    predict(build_dev_iterator(), model, dev_eval_file, config,
            config.prediction_file)
def train(config):
    """Train the HotpotQA model and keep the checkpoint with the best dev F1.

    Loads the word/char embedding matrices and the dev eval file, seeds the
    Python/NumPy/torch RNGs, appends a timestamp to ``config.save`` and
    creates that experiment directory via ``create_exp_dir``, then trains an
    ``SPModel`` (when ``config.sp_lambda > 0``) or a ``Model`` with Adam.

    Every ``config.checkpoint`` steps the model is evaluated on the dev set.
    When dev F1 improves, the unwrapped model's ``state_dict`` is saved to
    ``<config.save>/model.pt``. After ``config.patience`` evaluations without
    improvement the learning rate is halved, and training stops once it falls
    below 1% of ``config.init_lr``.

    Args:
        config: run configuration namespace. ``config.save`` is mutated in
            place to the timestamped directory name.
    """
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    # Loaded but not referenced again in this function.
    with open(config.idx2word_file, 'r') as fh:
        idx2word_dict = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if config.cuda:
        torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        # Print `s` and/or append it as one line to <config.save>/log.txt.
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging(' - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        # Training batches are drawn through a RandomSampler over the dataset.
        train_dataset = HotpotDataset(train_buckets)
        return DataIterator(train_dataset,
                            config.para_limit,
                            config.ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            sampler=RandomSampler(train_dataset),
                            num_workers=2)

    def build_dev_iterator():
        # Dev batches: no sampler is passed (iteration order left to DataIterator).
        dev_dataset = HotpotDataset(dev_buckets)
        return DataIterator(dev_dataset,
                            config.para_limit,
                            config.ques_limit,
                            config.char_limit,
                            config.sent_limit,
                            batch_size=config.batch_size,
                            num_workers=2)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # `ori_model` is the unwrapped module; its state_dict is what gets saved.
    ori_model = model.cuda() if config.cuda else model
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    # Iterators are built once and re-iterated every epoch.
    train_iterator = build_train_iterator()
    dev_iterator = build_dev_iterator()
    for epoch in range(10000):
        for data in train_iterator:
            if config.cuda:
                # Move every batch entry except 'ids' to the GPU.
                data = {
                    k: (data[k].cuda() if k != 'ids' else data[k])
                    for k in data
                }
            context_idxs = data['context_idxs']
            ques_idxs = data['ques_idxs']
            context_char_idxs = data['context_char_idxs']
            ques_char_idxs = data['ques_char_idxs']
            context_lens = data['context_lens']
            y1 = data['y1']
            y2 = data['y2']
            q_type = data['q_type']
            is_support = data['is_support']
            start_mapping = data['start_mapping']
            end_mapping = data['end_mapping']
            all_mapping = data['all_mapping']

            # The 9th positional argument is the maximum over the batch of
            # context_lens summed along dim 1 (presumably the longest context
            # length — TODO confirm against the model's forward signature).
            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                context_lens.sum(1).max().item(),
                return_yp=False)
            # Answer loss: type + start + end terms, summed then divided by batch size.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-fact loss over flattened 2-way predictions.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            # A non-positive max_grad_norm disables clipping in effect (1e10 cap).
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                config.max_grad_norm if config.max_grad_norm > 0 else 1e10)
            optimizer.step()

            total_loss += loss.item()
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | gradnorm: {:6.3}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            grad_norm))
                total_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(dev_iterator, model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        # Plateau: halve the learning rate for all param groups.
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            # Leave the batch loop; the epoch loop checks stop_train.
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train:
            break
    logging('best_dev_F1 {}'.format(best_dev_F1))
def train(config):
    """Train a per-sentence supporting-fact classifier; keep the best checkpoint.

    Each HotpotQA batch is flattened into one row per sentence: the
    sentence's tokens plus a copy of its example's question. The model
    scores every row, and the only loss is ``nll_average`` of those scores
    against the row's ``is_support`` label.

    Every ``config.checkpoint`` steps the model is evaluated on the dev set.
    The weights with the best dev supporting-fact F1 are saved to
    ``<config.save>/model.pt``. After ``config.patience`` evaluations
    without improvement the learning rate is multiplied by 0.75, and
    training stops once it falls below 1% of ``config.init_lr`` (or after
    5 epochs).

    Args:
        config: run configuration namespace. ``config.save`` is rewritten in
            place to a timestamped experiment directory.
    """
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        """Print ``s`` and/or append it to ``<config.save>/log.txt``."""
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging(' - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, False, config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # Single-device training: the model is not wrapped in nn.DataParallel.
    model = model.cuda()

    lr = config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    def empty_sp_metrics():
        # Fresh accumulator for the running supporting-fact metrics.
        return {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}

    train_metrics = empty_sp_metrics()
    best_dev_sp_f1 = 0
    # Bookkeeping counters only; nothing below reads them.
    total_support_facts = 0
    total_contexes = 0

    # Loop-invariant sizes of the flattened per-sentence tensors.
    ques_limit = config.ques_limit
    char_limit = config.char_limit
    # Max tokens stored per sentence row. This is a hard-coded value,
    # distinct from config.sent_limit (which is passed to DataIterator).
    # NOTE(review): a sentence longer than this makes the slice assignment
    # below fail with a shape mismatch.
    max_sentence_tokens = 600

    for epoch in range(5):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            # Flatten the batch: (sentence tokens, question) -> is-support label.
            number_of_sentences = int(
                torch.sum(torch.sum(start_mapping, dim=1)).item())

            sentence_idxs = torch.zeros(number_of_sentences,
                                        max_sentence_tokens,
                                        dtype=torch.long).cuda()
            sentence_char_idxs = torch.zeros(number_of_sentences,
                                             max_sentence_tokens,
                                             char_limit,
                                             dtype=torch.long).cuda()
            sentence_lengths = torch.zeros(number_of_sentences,
                                           dtype=torch.long).cuda()
            is_support_fact = torch.zeros(number_of_sentences,
                                          dtype=torch.long).cuda()
            ques_idxs_sen_impl = torch.zeros(number_of_sentences,
                                             ques_limit,
                                             dtype=torch.long).cuda()
            ques_char_idxs_sen_impl = torch.zeros(number_of_sentences,
                                                  ques_limit,
                                                  char_limit,
                                                  dtype=torch.long).cuda()

            index = 0
            for b in range(all_mapping.size(0)):  # each example in the batch
                for i in range(all_mapping.size(2)):  # each sentence slot
                    # An all-zero mapping column is a padded slot; skip it and
                    # keep scanning this example's remaining slots.
                    if torch.sum(all_mapping[b, :, i]) == 0:
                        continue
                    starting_index = torch.argmax(start_mapping[b, :, i])
                    ending_index = torch.argmax(end_mapping[b, :, i]) + 1
                    sentence_length = int(torch.sum(all_mapping[b, :, i]))

                    sentence_idxs[index, :sentence_length] = \
                        context_idxs[b, starting_index:ending_index]
                    sentence_char_idxs[index, :sentence_length, :] = \
                        context_char_idxs[b, starting_index:ending_index, :]
                    sentence_lengths[index] = sentence_length
                    is_support_fact[index] = is_support[b, i]

                    # Repeat the example's question for every sentence row.
                    ques_len = ques_idxs[b, :].size(0)
                    ques_idxs_sen_impl[index, :ques_len] = ques_idxs[b, :]
                    ques_char_idxs_sen_impl[index, :ques_len, :] = \
                        ques_char_idxs[b, :, :]
                    index += 1

            # Longest sentence in this flattened batch.
            sentence_length = torch.max(sentence_lengths)
            predict_support = model(sentence_idxs, ques_idxs_sen_impl,
                                    sentence_char_idxs,
                                    ques_char_idxs_sen_impl, sentence_length,
                                    start_mapping, end_mapping)
            loss = nll_average(predict_support.contiguous(),
                               is_support_fact.contiguous())
            train_metrics = update_sp(train_metrics, predict_support,
                                      is_support_fact)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            global_step += 1

            if global_step % config.period == 0:
                # Average the accumulated supporting-fact metrics.
                for key in train_metrics:
                    train_metrics[key] /= float(config.period)
                # NOTE(review): the loss is logged as the *sum* over the
                # period, not the mean (the averaging was disabled upstream).
                cur_loss = total_loss
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            train_metrics['sp_em'], train_metrics['sp_f1'],
                            train_metrics['sp_prec'],
                            train_metrics['sp_recall']))
                total_loss = 0
                train_metrics = empty_sp_metrics()
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                eval_metrics = evaluate_batch(build_dev_iterator(), model,
                                              500, dev_eval_file, config)
                model.train()
                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f}| SP EM {:8.3f} | SP f1 {:8.3f} | SP Prec {:8.3f} | SP Recall {:8.3f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time,
                            eval_metrics['loss'], eval_metrics['sp_em'],
                            eval_metrics['sp_f1'], eval_metrics['sp_prec'],
                            eval_metrics['sp_recall']))
                logging('-' * 89)
                if eval_metrics['sp_f1'] > best_dev_sp_f1:
                    best_dev_sp_f1 = eval_metrics['sp_f1']
                    torch.save(model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr *= 0.75
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
                eval_start_time = time.time()

            total_support_facts += int(torch.sum(is_support).item())
            # Bug fix: count (example x sentence-slot) positions. The
            # original multiplied is_support.size(0) by itself.
            total_contexes += is_support.size(0) * is_support.size(1)
        if stop_train:
            break
    logging('best_dev_F1 {}'.format(best_dev_sp_f1))
def train(config):
    """Train the QA + supporting-fact model with 2-batch gradient accumulation.

    Loads embeddings and the dev eval file, seeds all RNGs, creates a
    timestamped experiment directory under ``config.save``, and trains an
    ``SPModel`` (``config.sp_lambda > 0``) or ``Model`` wrapped in
    ``nn.DataParallel`` with SGD. Gradients of two consecutive batches are
    summed before each optimizer step.

    Every ``config.checkpoint`` steps the model is evaluated on the dev set,
    and the best-F1 weights are saved to ``<config.save>/model.pt``. After
    ``config.patience`` evaluations without improvement the learning rate is
    halved; training stops once it drops below 1% of ``config.init_lr``.

    Args:
        config: run configuration namespace. ``config.save`` is rewritten in
            place to the timestamped experiment directory.
    """
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=['run.py', 'model.py', 'util.py',
                                    'sp_model.py'])

    def logging(s, print_=True, log_=True):
        """Print ``s`` and/or append it to ``<config.save>/log.txt``."""
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging(' - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, False, config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    # Optimize only the parameters with requires_grad=True (frozen ones,
    # e.g. fixed embeddings, are filtered out).
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=config.init_lr)
    # Number of consecutive batches whose gradients are summed per update.
    accumulation_steps = 2
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    # Training mode (enables dropout layers).
    model.train()

    for epoch in range(10000):
        # Each `data` is one batch; `idx` restarts at 0 every epoch.
        for idx, data in enumerate(build_train_iterator()):
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            is_support_word = Variable(data['is_support_word'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
                context_lens, start_mapping, end_mapping, all_mapping,
                is_support_word, return_yp=False)
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            # Gradient accumulation: backward() adds into .grad, and the
            # optimizer steps and clears gradients every
            # `accumulation_steps` batches.
            loss.backward()
            # Bug fix: the original tested `(i + 1)`, but no `i` exists here
            # (NameError on the first batch); the batch counter is `idx`.
            # If an epoch has a leftover partial group, its gradients are
            # carried into the first update of the next epoch.
            if (idx + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()
            global_step += 1

            # Periodic training-loss log.
            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging('| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'.format(
                    epoch, global_step, lr, elapsed * 1000 / config.period,
                    cur_loss))
                total_loss = 0
                start_time = time.time()

            # Dev evaluation and checkpointing.
            if global_step % config.checkpoint == 0:
                model.eval()
                # evaluate_batch returns a metrics dict (loss, exact_match, f1).
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging('| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'.format(
                    global_step // config.checkpoint, epoch,
                    time.time() - eval_start_time, metrics['loss'],
                    metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train:
            break
    logging('best_dev_F1 {}'.format(best_dev_F1))
def train(config):
    """Train the QA + supporting-fact model and keep the best dev-F1 weights.

    Loads embeddings and the dev eval file, seeds all RNGs, creates a
    timestamped experiment directory under ``config.save``, and trains an
    ``SPModel`` (``config.sp_lambda > 0``) or ``Model`` wrapped in
    ``nn.DataParallel`` with SGD, using training batches of 24 examples.

    Every ``config.checkpoint`` steps the model is evaluated on the dev set,
    and the best-F1 weights are saved to ``<config.save>/model.pt``. After
    ``config.patience`` evaluations without improvement the learning rate is
    halved; training stops once it drops below 1% of ``config.init_lr``.

    Args:
        config: run configuration namespace. ``config.save`` is rewritten in
            place to the timestamped experiment directory.
    """
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        """Print ``s`` and/or append it to ``<config.save>/log.txt``."""
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging(' - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    # Auxiliary SQuAD records (yxh). NOTE(review): loaded but unused while the
    # mixed HotpotQA/SQuAD training step is disabled; can be passed to
    # build_train_iterator() as `buckets` to re-enable it.
    train_buckets_squad = get_buckets('squad_record.pkl')
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator(buckets=train_buckets,
                             batch_size=config.batch_size):
        # Bug fix: iterate the `buckets` argument. It used to be ignored, so
        # any non-default buckets silently fell back to train_buckets.
        return DataIterator(buckets, batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, True,
                            config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    lr = config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    for epoch in range(10000):
        # Batch size is fixed at 24 here, overriding config.batch_size.
        for data in build_train_iterator(batch_size=24):
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                is_support=is_support,
                return_yp=False)
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate a Python float (not a tensor that keeps a device
            # reference until the next log line).
            total_loss += loss.item()
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train:
            break
    logging('best_dev_F1 {}'.format(best_dev_F1))
def train(config):
    """Train the QA + supporting-fact model, logging metrics to Comet.

    Creates a Comet ``Experiment`` named ``config.run_name``, loads
    embeddings and the dev eval file, creates a timestamped experiment
    directory under ``config.save``, and trains an ``SPModel``
    (``config.sp_lambda > 0``) or ``Model`` wrapped in ``nn.DataParallel``
    with SGD.

    Total, answer and supporting-fact losses are averaged per
    ``config.period`` steps and sent to both the log file and Comet. Every
    ``config.checkpoint`` steps the model is evaluated on the dev set, and
    the best-F1 weights are saved to ``<config.save>/model.pt``. After
    ``config.patience`` evaluations without improvement the learning rate is
    halved; training stops once it drops below 1% of ``config.init_lr``.

    Args:
        config: run configuration namespace. ``config.save`` is rewritten in
            place to the timestamped experiment directory.
    """
    # NOTE(review): the Comet API key is hard-coded in source control;
    # consider supplying it via configuration or the environment instead.
    experiment = Experiment(api_key="Q8LzfxMlAfA3ABWwq9fJDoR6r",
                            project_name="hotpot",
                            workspace="fan-luo")
    experiment.set_name(config.run_name)

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        """Print ``s`` and/or append it to ``<config.save>/log.txt``."""
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging(' - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)
    print("next(model.parameters()).is_cuda: " +
          str(next(model.parameters()).is_cuda))
    print("next(ori_model.parameters()).is_cuda: " +
          str(next(ori_model.parameters()).is_cuda))

    lr = config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=config.init_lr)
    cur_patience = 0
    total_loss = 0
    total_ans_loss = 0
    total_sp_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()

    for epoch in range(10000):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs,
                ques_idxs,
                context_char_idxs,
                ques_char_idxs,
                context_lens,
                start_mapping,
                end_mapping,
                all_mapping,
                return_yp=False)
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Bug fix: `loss.data[0]` indexes a 0-dim tensor, which raises
            # IndexError on PyTorch >= 0.5; `.item()` returns the scalar.
            total_loss += loss.item()
            total_ans_loss += loss_1.item()
            total_sp_loss += loss_2.item()
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                cur_ans_loss = total_ans_loss / config.period
                cur_sp_loss = total_sp_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f} | answer loss {:8.3f} | supporting facts loss {:8.3f} '
                    .format(epoch, global_step, lr,
                            elapsed * 1000 / config.period, cur_loss,
                            cur_ans_loss, cur_sp_loss))
                experiment.log_metrics(
                    {
                        'train loss': cur_loss,
                        'train answer loss': cur_ans_loss,
                        'train supporting facts loss': cur_sp_loss
                    },
                    step=global_step)
                total_loss = 0
                total_ans_loss = 0
                total_sp_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | answer loss {:8.3f} | supporting facts loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['ans_loss'], metrics['sp_loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)
                experiment.log_metrics(
                    {
                        'dev loss': metrics['loss'],
                        'dev answer loss': metrics['ans_loss'],
                        'dev supporting facts loss': metrics['sp_loss'],
                        'EM': metrics['exact_match'],
                        'F1': metrics['f1']
                    },
                    step=global_step)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train:
            break
    logging('best_dev_F1 {}'.format(best_dev_F1))
def test(config):
    """Run inference with a saved checkpoint and write the predictions.

    Loads the embedding matrices and the eval file for ``config.data_split``,
    builds a ``DataIterator`` over that split (with BERT embedding buckets
    when ``config.bert`` is set), restores ``<config.save>/model.pt`` into
    ``SPModel`` (when ``config.sp_lambda > 0``) or ``Model``, and passes the
    iterator, the model and ``config.prediction_file`` to ``predict``.

    Args:
        config: run configuration namespace; only the attributes referenced
            below are read.

    Raises:
        ValueError: if ``config.data_split`` is neither ``'dev'`` nor
            ``'test'``.
    """
    # Validate the split first. An unknown value used to load the test eval
    # file and then fail later with an opaque NameError, because
    # dev_buckets / para_limit / ques_limit were never bound.
    split = config.data_split
    if split not in ('dev', 'test'):
        raise ValueError(
            "config.data_split must be 'dev' or 'test', got {!r}".format(split))
    is_dev = split == 'dev'

    # Pure-BERT runs (bert without bert_with_glove) use no word matrix.
    if config.bert and not config.bert_with_glove:
        word_mat = None
    else:
        with open(config.word_emb_file, "r") as fh:
            word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)

    eval_path = config.dev_eval_file if is_dev else config.test_eval_file
    with open(eval_path, "r") as fh:
        dev_eval_file = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    if is_dev:
        record_file = config.dev_record_file
        example_ids = config.dev_example_ids
        # The dev split keeps the configured length limits.
        para_limit = config.para_limit
        ques_limit = config.ques_limit
    else:
        record_file = config.test_record_file
        example_ids = config.test_example_ids
        # The test split is built with no paragraph/question limits (None).
        para_limit = None
        ques_limit = None
    dev_buckets = get_buckets(record_file,
                              os.path.join(config.bert_dir, example_ids))

    if config.bert:
        if is_dev:
            context_emb = config.dev_bert_emb_context
            ques_emb = config.dev_bert_emb_ques
        else:
            context_emb = config.test_bert_emb_context
            ques_emb = config.test_bert_emb_ques
        dev_context_buckets = get_buckets_bert(
            os.path.join(config.bert_dir, context_emb))
        dev_ques_buckets = get_buckets_bert(
            os.path.join(config.bert_dir, ques_emb))
        iterator_kwargs = {
            'bert': True,
            'bert_buckets': list(zip(dev_context_buckets, dev_ques_buckets)),
        }
    else:
        iterator_kwargs = {'bert': False}

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)
    ori_model = model.cuda()
    # model.pt holds the unwrapped module's state_dict, so load it before
    # wrapping the model in DataParallel.
    ori_model.load_state_dict(
        torch.load(os.path.join(config.save, 'model.pt')))
    model = nn.DataParallel(ori_model)
    model.eval()

    # Positional False = shuffle flag; inference keeps the data order.
    dev_iterator = DataIterator(dev_buckets, config.batch_size, para_limit,
                                ques_limit, config.char_limit, False,
                                config.sent_limit, **iterator_kwargs)
    predict(dev_iterator, model, dev_eval_file, config,
            config.prediction_file)
def _sp_step_labels(supports, num_steps=5):
    """Encode one example's supporting-fact sentence ids as per-step labels.

    The first ``num_steps - 1`` ids are kept in their given order. The step
    right after the last kept id is labelled ``-1``, and every remaining
    step is labelled ``-2``. ``train`` below shifts these by +2, so ``-1``
    becomes 1 and ``-2`` becomes 0, and it masks label 0 out of the step
    losses.

    An example with no supporting facts yields ``[-1, -2, -2, -2, -2]``.
    The previous if/elif chain appended nothing for such an example, which
    left the label rows out of step with the batch.

    >>> _sp_step_labels([7])
    [7, -1, -2, -2, -2]
    >>> _sp_step_labels([1, 2, 3, 4, 5])
    [1, 2, 3, 4, -1]
    >>> _sp_step_labels([])
    [-1, -2, -2, -2, -2]
    """
    labels = list(supports[:num_steps - 1]) + [-1]
    return labels + [-2] * (num_steps - len(labels))


def train(config):
    """Train the (optionally BERT-augmented) multi-step supporting-fact model.

    Builds the train/dev iterators and the model: ``SPModel`` when
    ``config.sp_lambda > 0``, otherwise ``Model``, wrapped in
    ``nn.DataParallel``. It then builds the optimizer chosen by
    ``config.optim`` and the LR scheduler chosen by ``config.scheduler``,
    and trains for ``config.epoch`` epochs. Every ``config.checkpoint``
    steps the dev set is evaluated, and the weights with the best dev F1
    are saved to ``<config.save>/model.pt``.

    Args:
        config: run configuration namespace. ``config.save`` is rewritten
            in place with a timestamp suffix.

    Raises:
        ValueError: if ``config.optim`` is not ``'sgd'``/``'adam'`` or
            ``config.scheduler`` is not ``'cosine'``/``'plateau'``.
    """
    # Fail fast. These used to surface as a NameError on optimizer/scheduler
    # only after all data had been loaded.
    if config.optim not in ('sgd', 'adam'):
        raise ValueError(
            "config.optim must be 'sgd' or 'adam', got {!r}".format(
                config.optim))
    if config.scheduler not in ('cosine', 'plateau'):
        raise ValueError(
            "config.scheduler must be 'cosine' or 'plateau', got {!r}".format(
                config.scheduler))

    # Pure-BERT runs (bert without bert_with_glove) use no word matrix.
    if config.bert and not config.bert_with_glove:
        word_mat = None
    else:
        with open(config.word_emb_file, "r") as fh:
            word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save, scripts_to_save=[
        'run.py', 'model.py', 'util.py', 'sp_model.py', 'macnet_v2.py'
    ])

    def logging(s, print_=True, log_=True):
        """Print ``s`` and/or append it to ``<config.save>/log.txt``."""
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging(' - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)
    # Batches per epoch used for the cosine schedule length. This assumes
    # all training examples live in train_buckets[0]; TODO confirm.
    max_iter = math.ceil(len(train_buckets[0]) / config.batch_size)

    def build_train_iterator():
        """Shuffling train iterator; BERT buckets are reloaded on each call."""
        if config.bert:
            train_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.train_bert_emb_context))
            train_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.train_bert_emb_ques))
            return DataIterator(train_buckets, config.batch_size,
                                config.para_limit, config.ques_limit,
                                config.char_limit, True, config.sent_limit,
                                bert=True,
                                bert_buckets=list(
                                    zip(train_context_buckets,
                                        train_ques_buckets)),
                                new_spans=True)
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit,
                            bert=False, new_spans=True)

    def build_dev_iterator():
        """Non-shuffling dev iterator used for evaluation during training."""
        if config.bert:
            dev_context_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_context))
            dev_ques_buckets = get_buckets_bert(
                os.path.join(config.bert_dir, config.dev_bert_emb_ques))
            return DataIterator(dev_buckets, config.batch_size,
                                config.para_limit, config.ques_limit,
                                config.char_limit, False, config.sent_limit,
                                bert=True,
                                bert_buckets=list(
                                    zip(dev_context_buckets,
                                        dev_ques_buckets)))
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit, bert=False)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.cuda()
    model = nn.DataParallel(ori_model)

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    if config.optim == 'sgd':
        optimizer = optim.SGD(trainable_params, lr=config.init_lr)
    else:  # 'adam' (validated above)
        optimizer = optim.Adam(trainable_params, lr=config.init_lr)

    if config.scheduler == 'cosine':
        # Stepped once per batch, annealing over config.epoch * max_iter steps.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, config.epoch * max_iter, eta_min=0.0001)
    else:  # 'plateau': stepped with the dev F1 after each evaluation
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=config.patience,
            verbose=True)

    # When True, the per-sentence binary support loss (loss_2) is disabled,
    # and supporting facts contribute only through the per-step sp losses.
    GRN_SP_PRED = True  # temporary flag

    # Padding steps are replaced by this target so the step losses skip them.
    # This presumably matches nll_average's ignore_index (PyTorch's NLLLoss
    # default is -100); confirm.
    sp_ignore_index = -100

    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    start_time = time.time()
    eval_start_time = time.time()
    model.train()
    for epoch in range(config.epoch):
        for data in build_train_iterator():
            if config.scheduler == 'cosine':
                # Only the cosine schedule steps per batch:
                # ReduceLROnPlateau.step() requires a metric and used to
                # raise TypeError here on the very first batch.
                scheduler.step()

            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])
            if config.bert:
                bert_kwargs = {
                    'bert': True,
                    'bert_context': Variable(data['bert_context']),
                    'bert_ques': Variable(data['bert_ques']),
                }
            else:
                bert_kwargs = {'bert': False}

            # Group supporting sentence ids by example. nonzero() yields
            # (example index, sentence index) rows for is_support == 1.
            batch_size = context_idxs.shape[0]
            supports_by_example = {idx: [] for idx in range(batch_size)}
            support_rows = (is_support == 1).nonzero().data.cpu().numpy()
            for bid, sent_id in support_rows:
                supports_by_example[int(bid)].append(int(sent_id))

            if config.sp_shuffle:
                for sent_ids in supports_by_example.values():
                    random.shuffle(sent_ids)

            # The model prepends 2 vectors (sp_output_with_end) to the
            # sentence sequence, so every index is shifted by 2.
            sp_step_labels = np.array(
                [_sp_step_labels(supports_by_example[idx])
                 for idx in range(batch_size)],
                dtype=np.int64) + 2
            sp_labels_mod = Variable(torch.from_numpy(sp_step_labels).cuda())
            # Label 0 (padding, -2 before the shift) is excluded from the loss.
            sp_targets = np.where(sp_step_labels > 0, sp_step_labels,
                                  sp_ignore_index)

            (logit1, logit2, _grn_logit1, _grn_logit2, _coarse_att,
             predict_type, predict_support, sp_att_logits) = model(
                 context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
                 context_lens, start_mapping, end_mapping, all_mapping,
                 return_yp=False, support_labels=is_support, is_train=True,
                 sp_labels_mod=sp_labels_mod, **bert_kwargs)

            # Steps 1-3 are always supervised; step 4 only with 4 reasoning
            # steps. The 5th label column is only consumed by the model.
            num_sp_steps = 4 if config.reasoning_steps == 4 else 3
            sp_losses = []
            for step in range(num_sp_steps):
                step_target = Variable(torch.from_numpy(
                    np.ascontiguousarray(sp_targets[:, step])).cuda())
                sp_losses.append(
                    nll_average(sp_att_logits[step].squeeze(-1), step_target))

            # Answer-span loss: average cross-entropy over every gold
            # mention span of each example that has any.
            span_losses = []
            for bid, spans in enumerate(data['gold_mention_spans']):
                if spans is None or len(spans) == 0:
                    continue
                span_tensor = torch.LongTensor(spans)
                try:
                    start_targets = span_tensor[:, 0]
                    end_targets = span_tensor[:, 1]
                except IndexError:
                    raise ValueError(
                        'gold_mention_spans[{}] must be a list of '
                        '[start, end] pairs, got {!r}'.format(bid, spans))
                num_spans = len(spans)
                start_loss = F.cross_entropy(
                    logit1[bid].view(1, -1).expand(num_spans,
                                                   len(logit1[bid])),
                    Variable(start_targets).cuda())
                end_loss = F.cross_entropy(
                    logit2[bid].view(1, -1).expand(num_spans,
                                                   len(logit2[bid])),
                    Variable(end_targets).cuda())
                span_losses.append(start_loss + end_loss)

            type_loss = nll_sum(predict_type, q_type) / context_idxs.size(0)
            if span_losses:
                loss_1 = torch.mean(torch.stack(span_losses)) + type_loss
            else:
                # No example in this batch has a gold span: use the
                # answer-type loss alone instead of crashing in torch.stack.
                loss_1 = type_loss

            if GRN_SP_PRED:
                loss_2 = 0
            else:
                loss_2 = nll_average(predict_support.view(-1, 2),
                                     is_support.view(-1))
            loss = loss_1 + config.sp_lambda * loss_2 + sum(sp_losses)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data[0]
            global_step += 1

            if global_step % config.period == 0:
                cur_loss = total_loss / config.period
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | lr {:05.5f} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step,
                            optimizer.param_groups[0]['lr'],
                            elapsed * 1000 / config.period, cur_loss))
                total_loss = 0
                start_time = time.time()

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if config.scheduler == 'plateau':
                    scheduler.step(dev_F1)
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
    logging('best_dev_F1 {}'.format(best_dev_F1))
def train(config):
    """Train a model with periodic, resumable temporary checkpoints.

    Every ``config.period`` steps, the model and optimizer state are written
    to ``temp_model`` / ``temp_optimizer`` in the current working directory.
    If ``temp_model`` exists at start-up, those weights (and the optimizer
    state, if ``temp_optimizer`` also exists) are restored. Step counters,
    the best F1 and patience still restart from scratch.

    Every ``config.checkpoint`` steps the dev set is evaluated. The best-F1
    model and optimizer state go to ``<config.save>/model.pt`` and
    ``<config.save>/optimizer.pt``. Training stops after
    ``config.patience`` consecutive evaluations without improvement, and the
    temporary files are then deleted.

    Args:
        config: run configuration namespace. ``config.save`` is rewritten
            in place with a timestamp suffix.
    """
    temp_model_path = 'temp_model'
    temp_optimizer_path = 'temp_optimizer'

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    config.save = '{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(
        config.save,
        scripts_to_save=['run.py', 'model.py', 'util.py', 'sp_model.py'])

    def logging(s, print_=True, log_=True):
        """Print ``s`` and/or append it to ``<config.save>/log.txt``."""
        if print_:
            print(s)
        if log_:
            with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log:
                f_log.write(s + '\n')

    logging('Config')
    for k, v in config.__dict__.items():
        logging(' - {} : {}'.format(k, v))

    logging("Building model...")
    train_buckets = get_buckets(config.train_record_file)
    dev_buckets = get_buckets(config.dev_record_file)

    def build_train_iterator():
        """Shuffling iterator over the training buckets."""
        return DataIterator(train_buckets, config.batch_size,
                            config.para_limit, config.ques_limit,
                            config.char_limit, True, config.sent_limit)

    def build_dev_iterator():
        """Non-shuffling iterator over the dev buckets."""
        return DataIterator(dev_buckets, config.batch_size, config.para_limit,
                            config.ques_limit, config.char_limit, False,
                            config.sent_limit)

    if config.sp_lambda > 0:
        model = SPModel(config, word_mat, char_mat)
    else:
        model = Model(config, word_mat, char_mat)

    logging('nparams {}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    # .cuda() returns the same module, so model and ori_model are the same
    # (unwrapped) object here.
    ori_model = model.cuda()

    # Adam with its default learning rate; config.init_lr is not used.
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()))

    # Resume from the last periodic snapshot if a previous run was cut off.
    logging('Checking if previous training was interrupted...')
    if os.path.exists(temp_model_path):
        logging(
            'Previous training was interupted, loading model and optimizer state'
        )
        ori_model.load_state_dict(torch.load(temp_model_path))
        # The two files are written one after the other, so the optimizer
        # snapshot may be missing; that used to crash here.
        if os.path.exists(temp_optimizer_path):
            optimizer.load_state_dict(torch.load(temp_optimizer_path))
        else:
            logging('temp_optimizer not found, using a fresh optimizer state')

    cur_patience = 0
    total_loss = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False
    start_time = time.time()
    eval_start_time = time.time()
    model.train()
    for epoch in range(1000):
        for data in build_train_iterator():
            context_idxs = Variable(data['context_idxs'])
            ques_idxs = Variable(data['ques_idxs'])
            context_char_idxs = Variable(data['context_char_idxs'])
            ques_char_idxs = Variable(data['ques_char_idxs'])
            context_lens = Variable(data['context_lens'])
            y1 = Variable(data['y1'])
            y2 = Variable(data['y2'])
            q_type = Variable(data['q_type'])
            is_support = Variable(data['is_support'])
            start_mapping = Variable(data['start_mapping'])
            end_mapping = Variable(data['end_mapping'])
            all_mapping = Variable(data['all_mapping'])
            sentence_scores = Variable(data['sentence_scores'])

            logit1, logit2, predict_type, predict_support = model(
                context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
                context_lens, start_mapping, end_mapping, all_mapping,
                sentence_scores, return_yp=False)
            # Answer loss: type + span start + span end, averaged per example.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-fact loss: binary classification per sentence.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1).long())
            loss = loss_1 + config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.data.item()
            global_step += 1

            if global_step % config.period == 0:
                # NOTE: this reports the loss *summed* over the last
                # config.period batches; it is not divided by config.period.
                cur_loss = total_loss
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:3d} | step {:6d} | ms/batch {:5.2f} | train loss {:8.3f}'
                    .format(epoch, global_step,
                            elapsed * 1000 / config.period, cur_loss))
                total_loss = 0
                start_time = time.time()
                # Crash-recovery snapshot read by the resume logic above.
                torch.save(ori_model.state_dict(), temp_model_path)
                torch.save(optimizer.state_dict(), temp_optimizer_path)

            if global_step % config.checkpoint == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file, config)
                model.train()

                logging('-' * 89)
                logging(
                    '| eval {:6d} in epoch {:3d} | time: {:5.2f}s | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                    .format(global_step // config.checkpoint, epoch,
                            time.time() - eval_start_time, metrics['loss'],
                            metrics['exact_match'], metrics['f1']))
                logging('-' * 89)

                eval_start_time = time.time()

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(config.save, 'model.pt'))
                    torch.save(optimizer.state_dict(),
                               os.path.join(config.save, 'optimizer.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= config.patience:
                        stop_train = True
                        break
        if stop_train:
            break

    logging('best_dev_F1 {}'.format(best_dev_F1))

    # Training finished, so the crash-recovery snapshot is no longer needed.
    print('Deleting last temp model files...')
    for temp_path in (temp_model_path, temp_optimizer_path):
        # The snapshot only exists if at least one period completed. A
        # missing file used to raise here after training had already ended.
        if os.path.exists(temp_path):
            os.remove(temp_path)
print(s) if log_: with open(os.path.join(config.save, 'log.txt'), 'a+') as f_log: f_log.write(s + '\n') if config.data_split == 'dev': # 我们暂时用的是dev # 不知道data_split的含义 dev_buckets = get_buckets(config.dev_record_file) para_limit = config.para_limit ques_limit = config.ques_limit elif config.data_split == 'test': para_limit = None ques_limit = None dev_buckets = get_buckets(config.test_record_file) def build_dev_iterator(): return DataIterator(dev_buckets, config.batch_size, para_limit, ques_limit, config.char_limit, False, config.sent_limit) if config.sp_lambda > 0: model = SPModel(config, word_mat, char_mat) else: model = Model(config, word_mat, char_mat) ori_model = model.cuda() ori_model.load_state_dict(torch.load(os.path.join(config.save, 'model.pt'))) model = nn.DataParallel(ori_model) model.eval() predict(build_dev_iterator(), model, dev_eval_file, config, config.prediction_file)
def train():
    """Train the model configured by the module-level ``Config``.

    Loads the embeddings, the dev eval file and the preprocessed train/dev
    records (``torch.load``). It trains with SGD at ``Config.init_lr`` and
    evaluates on dev every 10 steps, saving the best-F1 weights to
    ``<Config.save>/model.pt``. After ``Config.patience`` evaluations
    without improvement the learning rate is halved. Training stops once it
    falls below 1% of ``Config.init_lr``.
    """
    # 1. Load the datasets.
    with open(Config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(Config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(Config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    train_buckets = [torch.load(Config.train_record_file)]
    dev_buckets = [torch.load(Config.dev_record_file)]

    # DataIterator(buckets, bsz, para_limit, ques_limit, char_limit, shuffle,
    #              sent_limit)
    def build_train_iterator():
        """Shuffling iterator over the training records."""
        return DataIterator(train_buckets, Config.batch_size,
                            Config.para_limit, Config.ques_limit,
                            Config.char_limit, True, Config.sent_limit)

    def build_dev_iterator():
        """Non-shuffling iterator over the dev records."""
        return DataIterator(dev_buckets, Config.batch_size, Config.para_limit,
                            Config.ques_limit, Config.char_limit, False,
                            Config.sent_limit)

    if Config.sp_lambda > 0:
        model = SPModel(word_mat, char_mat)
    else:
        model = Model(word_mat, char_mat)

    # Prints the number of trainable parameters (the original note recorded
    # 235636 for the author's configuration).
    print('需要更新的参数量:{}'.format(
        sum([p.nelement() for p in model.parameters() if p.requires_grad])))
    ori_model = model.to(Config.device)

    lr = Config.init_lr
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=Config.init_lr)

    cur_patience = 0
    global_step = 0
    best_dev_F1 = None
    stop_train = False

    model.train()
    for epoch in range(10000):
        for data in build_train_iterator():
            global_step += 1
            # Example shapes recorded from a debug run with batch size 2:
            #   context_idxs (2, 942), context_char_idxs (2, 942, 16),
            #   ques_idxs (2, 18), ques_char_idxs (2, 18, 16),
            #   y1 / y2 / q_type (2,), is_support (2, 49),
            #   start/end/all_mapping (2, 942, 49).
            context_idxs = data['context_idxs'].to(Config.device)
            ques_idxs = data['ques_idxs'].to(Config.device)
            context_char_idxs = data['context_char_idxs'].to(Config.device)
            ques_char_idxs = data['ques_char_idxs'].to(Config.device)
            context_lens = data['context_lens'].to(Config.device)
            y1 = data['y1'].to(Config.device)
            y2 = data['y2'].to(Config.device)
            q_type = data['q_type'].to(Config.device)
            is_support = data['is_support'].to(Config.device)
            start_mapping = data['start_mapping'].to(Config.device)
            end_mapping = data['end_mapping'].to(Config.device)
            all_mapping = data['all_mapping'].to(Config.device)

            logit1, logit2, predict_type, predict_support = model(
                context_idxs, ques_idxs, context_char_idxs, ques_char_idxs,
                context_lens, start_mapping, end_mapping, all_mapping,
                return_yp=False)
            # Answer loss: type + span start + span end, averaged per example.
            loss_1 = (nll_sum(predict_type, q_type) + nll_sum(logit1, y1) +
                      nll_sum(logit2, y2)) / context_idxs.size(0)
            # Supporting-fact loss: binary classification per sentence.
            loss_2 = nll_average(predict_support.view(-1, 2),
                                 is_support.view(-1))
            loss = loss_1 + Config.sp_lambda * loss_2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(
                '| epoch {:3d} | step {:6d} | lr {:05.5f} | train loss {:8.3f}'
                .format(epoch, global_step, lr, loss.item()))

            if global_step % 10 == 0:
                model.eval()
                metrics = evaluate_batch(build_dev_iterator(), model, 0,
                                         dev_eval_file)
                model.train()
                # The old format string had 3 placeholders for 4 arguments:
                # epoch was printed as "dev loss" and F1 was dropped.
                print('| epoch {:3d} | dev loss {:8.3f} | EM {:.4f} | F1 {:.4f}'
                      .format(epoch, metrics['loss'], metrics['exact_match'],
                              metrics['f1']))

                dev_F1 = metrics['f1']
                if best_dev_F1 is None or dev_F1 > best_dev_F1:
                    best_dev_F1 = dev_F1
                    torch.save(ori_model.state_dict(),
                               os.path.join(Config.save, 'model.pt'))
                    cur_patience = 0
                else:
                    cur_patience += 1
                    if cur_patience >= Config.patience:
                        # Halve the LR; stop once it drops below 1% of init.
                        lr /= 2.0
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                        if lr < Config.init_lr * 1e-2:
                            stop_train = True
                            break
                        cur_patience = 0
        if stop_train:
            break
    print('best_dev_F1 {}'.format(best_dev_F1))