import json
import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Repo-local helpers (get_loader, get_data_loader, get_ext_data_loader, EMA,
# train, test, eval_qa, cal_running_avg_loss, progress_bar, eta) and the model
# classes (BiDAF, QANet) are assumed to be importable from the repo's own
# modules; BertTokenizer comes from the BERT package the repo depends on
# (e.g. transformers). `device` used in train_entry() is assumed to be a
# module-level torch.device.


def train_entry(config):
    from models import BiDAF

    # load the pre-built word/char embedding matrices and the dev eval file
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")
    train_dataset = get_loader(config.train_record_file, config.batch_size)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    c_vocab_size, c_emb_size = char_mat.shape
    model = BiDAF(word_mat,
                  w_embedding_size=300,
                  c_embeding_size=c_emb_size,
                  c_vocab_size=c_vocab_size,
                  hidden_size=100,
                  drop_prob=0.2).to(device)
    if config.pretrained:
        print("load pre-trained model")
        state_dict = torch.load(config.save_path, map_location="cpu")
        model.load_state_dict(state_dict)

    # register all trainable parameters with the exponential moving average
    ema = EMA(config.decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adadelta(lr=0.5, params=parameters)

    best_f1 = 0
    best_em = 0
    patience = 0
    for iter in range(config.num_epoch):
        train(model, optimizer, train_dataset, dev_dataset,
              dev_eval_file, iter, ema)

        # evaluate with the EMA weights, then restore the training weights
        ema.assign(model)
        metrics = test(model, dev_dataset, dev_eval_file,
                       (iter + 1) * len(train_dataset))
        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        if dev_f1 < best_f1 and dev_em < best_em:
            # no improvement on either metric: count towards early stopping
            patience += 1
            if patience > config.early_stop:
                break
        else:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
            fn = os.path.join(
                config.save_dir,
                "model_{}_{:.2f}_{:.2f}.pt".format(iter, best_f1, best_em))
            torch.save(model.state_dict(), fn)
        ema.resume(model)
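# The EMA helper used above is not shown in this section. Below is a minimal
# sketch of an exponential-moving-average tracker matching the register()/
# assign()/resume() interface called in train_entry(); the class body and its
# update() hook are assumptions, not the repo's actual implementation. Note
# that main() below constructs EMA differently (callable, model-first), so the
# repo presumably ships its own variant.
class EMA:
    """Keeps shadow copies of trainable parameters as an exponential moving average."""

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}   # name -> EMA weights
        self.backup = {}   # name -> training weights saved during assign()

    def register(self, name, val):
        # store an initial copy of the parameter
        self.shadow[name] = val.clone()

    def update(self, name, val):
        # shadow = decay * shadow + (1 - decay) * current
        new_avg = self.decay * self.shadow[name] + (1.0 - self.decay) * val
        self.shadow[name] = new_avg.clone()

    def assign(self, model):
        # swap EMA weights into the model (e.g. for evaluation)
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def resume(self, model):
        # restore the original training weights after evaluation
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}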
def main(args):
    save_dir = os.path.join("./save", time.strftime("%m%d%H%M%S"))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    if args.all_data:
        data_loader = get_ext_data_loader(tokenizer, "./data/train/",
                                          shuffle=True, args=args)
    else:
        data_loader, _, _ = get_data_loader(tokenizer, "./data/train-v1.1.json",
                                            shuffle=True, args=args)

    vocab_size = len(tokenizer.vocab)
    if args.bidaf:
        print("train bidaf")
        model = BiDAF(embedding_size=args.embedding_size,
                      vocab_size=vocab_size,
                      hidden_size=args.hidden_size,
                      drop_prob=args.dropout)
    else:
        ntokens = len(tokenizer.vocab)
        model = QANet(ntokens,
                      embedding=args.embedding,
                      embedding_size=args.embedding_size,
                      hidden_size=args.hidden_size,
                      num_head=args.num_head)

    if args.load_model:
        state_dict = torch.load(args.model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        print("load pre-trained model")

    device = torch.device("cuda")
    model = model.to(device)
    model.train()

    ema = EMA(model, args.decay)

    # Adam with a logarithmic learning-rate warm-up via LambdaLR
    base_lr = 1
    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr, betas=(0.9, 0.999), eps=1e-7,
                           weight_decay=5e-8, params=parameters)
    cr = args.lr / math.log2(args.lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log2(ee + 1) if ee < args.lr_warm_up_num else args.lr)

    step = 0
    num_batches = len(data_loader)
    avg_loss = 0
    best_f1 = 0
    for epoch in range(1, args.num_epochs + 1):
        step += 1
        start = time.time()
        model.train()
        for i, batch in enumerate(data_loader, start=1):
            c_ids, q_ids, start_positions, end_positions = batch

            # trim context and question to the longest sequence in the batch
            c_len = torch.sum(torch.sign(c_ids), 1)
            max_c_len = torch.max(c_len)
            c_ids = c_ids[:, :max_c_len].to(device)

            q_len = torch.sum(torch.sign(q_ids), 1)
            max_q_len = torch.max(q_len)
            q_ids = q_ids[:, :max_q_len].to(device)

            start_positions = start_positions.to(device)
            end_positions = end_positions.to(device)

            optimizer.zero_grad()
            loss = model(c_ids, q_ids,
                         start_positions=start_positions,
                         end_positions=end_positions)
            loss.backward()
            avg_loss = cal_running_avg_loss(loss.item(), avg_loss)
            nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step(step)
            ema(model, step // args.batch_size)
            batch_size = c_ids.size(0)
            step += batch_size

            msg = "{}/{} {} - ETA : {} - qa_loss: {:.4f}".format(
                i, num_batches, progress_bar(i, num_batches),
                eta(start, i, num_batches), avg_loss)
            print(msg, end="\r")

        if not args.debug:
            metric_dict = eval_qa(args, model)
            f1 = metric_dict["f1"]
            em = metric_dict["exact_match"]
            print("epoch: {}, final loss: {:.4f}, F1:{:.2f}, EM:{:.2f}".format(
                epoch, avg_loss, f1, em))

            model_name = "bidaf" if args.bidaf else "qanet"

            # keep only checkpoints that improve dev F1
            if f1 > best_f1:
                best_f1 = f1
                state_dict = model.state_dict()
                save_file = "{}_{:.2f}_{:.2f}".format(model_name, f1, em)
                path = os.path.join(save_dir, save_file)
                torch.save(state_dict, path)
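# cal_running_avg_loss() above is a repo-local helper whose body is not shown
# in this section. A minimal sketch of a common definition is given below: an
# exponentially decayed running average of the per-batch loss, seeded with the
# first loss value. The decay of 0.99 is an assumption, not taken from the repo.
def cal_running_avg_loss(loss, running_avg_loss, decay=0.99):
    # first call: no history yet, use the raw loss
    if running_avg_loss == 0:
        return loss
    # otherwise blend the new loss into the running average
    return running_avg_loss * decay + (1 - decay) * loss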