def train_entry(config):
    """Train QANet in checkpoint-sized rounds with dev-set early stopping.

    Reads the word/char embedding matrices and the train/dev evaluation files
    named on ``config``, builds the model together with an EMA tracker and an
    Adam optimizer driven by a logarithmic warm-up schedule, then alternates
    ``train`` / ``valid`` / ``test`` every ``config.checkpoint`` steps.  The
    whole model object is written to ``config.save_dir`` after each round.

    Args:
        config: settings object; this function reads the embedding/eval/record
            file paths, ``num_steps``, ``test_num_batches``, ``batch_size``,
            ``learning_rate``, ``lr_warm_up_num``, ``ema_decay``, ``beta1``,
            ``beta2``, ``checkpoint``, ``early_stop`` and ``save_dir``.
    """
    from models import QANet

    def read_json(path):
        with open(path, "r") as fh:
            return json.load(fh)

    word_mat = np.array(read_json(config.word_emb_file), dtype=np.float32)
    char_mat = np.array(read_json(config.char_emb_file), dtype=np.float32)
    train_eval_file = read_json(config.train_eval_file)
    dev_eval_file = read_json(config.dev_eval_file)

    print("Building model...")
    train_dataset = SQuADDataset(
        config.train_record_file, config.num_steps, config.batch_size)
    dev_dataset = SQuADDataset(
        config.dev_record_file, config.test_num_batches, config.batch_size)

    target_lr = config.learning_rate
    warm_up_steps = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)

    # Register every trainable parameter with the moving-average tracker.
    ema = EMA(config.ema_decay)
    for name, weight in model.named_parameters():
        if weight.requires_grad:
            ema.set(name, weight)

    # The optimizer's base rate is 1.0, so the factor returned by the
    # schedule below is the learning rate that is actually applied.
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = optim.Adam(
        params=trainable,
        lr=1.0,
        betas=(config.beta1, config.beta2),
        eps=1e-7,
        weight_decay=3e-7,
    )
    warm_up_scale = target_lr / log2(warm_up_steps)

    def lr_factor(step):
        # Logarithmic ramp during warm-up, constant afterwards.
        if step < warm_up_steps:
            return warm_up_scale * log2(step + 1)
        return target_lr

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)

    round_size = config.checkpoint
    total_steps = config.num_steps
    best_f1 = 0
    best_em = 0
    patience = 0
    for start in range(0, total_steps, round_size):
        train(model, optimizer, scheduler, ema, train_dataset, start, round_size)
        valid(model, train_dataset, train_eval_file)
        metrics = test(model, dev_dataset, dev_eval_file)
        print("Learning rate: {}".format(scheduler.get_lr()))

        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        # A round only counts as a regression when BOTH metrics got worse.
        regressed = dev_f1 < best_f1 and dev_em < best_em
        if not regressed:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
        else:
            patience += 1
            if patience > config.early_stop:
                break

        fn = os.path.join(
            config.save_dir, f"model_iter={start}_best_f1={best_f1}.pt")
        torch.save(model, fn)
def train_entry(config):
    """Train QANet with a plain Adam optimizer and dev-set early stopping.

    Embedding matrices are unpickled from ``config.word_emb_file`` and
    ``config.char_emb_file``; the dev evaluation file is JSON.  Training runs
    in rounds of ``config.checkpoint`` steps, each followed by a dev
    evaluation, and the whole model is saved to ``<save_dir>/model.pt`` after
    every round that does not trigger early stopping.

    Args:
        config: settings object; this function reads the file paths,
            ``num_steps``, ``val_num_batches``, ``batch_size``,
            ``learning_rate``, ``lr_warm_up_num``, ``checkpoint``,
            ``print_weight``, ``early_stop`` and ``save_dir``.
    """
    from models import QANet

    # NOTE(review): pickle.load runs arbitrary code embedded in the file --
    # only point these paths at embedding files produced by a trusted step.
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(pickle.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(pickle.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")
    train_dataset = SQuADDataset(
        config.train_record_file, config.num_steps, config.batch_size)
    dev_dataset = SQuADDataset(
        config.dev_record_file, config.val_num_batches, config.batch_size)

    learning_rate = config.learning_rate
    warm_up_steps = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = optim.Adam(lr=learning_rate, params=trainable)

    # No scheduler is active at start; `train` receives this empty-string
    # placeholder in the scheduler slot.
    scheduler = ''

    round_size = config.checkpoint
    total_steps = config.num_steps
    best_f1 = 0
    best_em = 0
    patience = 0
    # Switch for the exponential-decay hand-over below.  It is initialised to
    # False and nothing sets it to True, so that branch is currently disabled.
    decay_pending = False
    for start in range(0, total_steps, round_size):
        train(model, optimizer, scheduler, train_dataset, start, round_size)
        metrics = test(model, dev_dataset, dev_eval_file)

        if start + round_size >= warm_up_steps - 1 and decay_pending:
            optimizer.param_groups[0]['initial_lr'] = learning_rate
            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.99997)
            decay_pending = False

        if config.print_weight:
            print_weight(model, 5, start + round_size)

        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        # A round only counts as a regression when BOTH metrics got worse.
        regressed = dev_f1 < best_f1 and dev_em < best_em
        if not regressed:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
        else:
            patience += 1
            if patience > config.early_stop:
                break

        fn = os.path.join(config.save_dir, "model.pt")
        torch.save(model, fn)
def train_entry(config):
    """Train QANet epoch by epoch with dev-set early stopping.

    Loads the JSON embedding matrices and the dev evaluation file, builds the
    data loaders, the model and an Adam optimizer with a logarithmic warm-up
    schedule, then runs ``config.num_epoch`` epochs.  After each epoch the
    model is evaluated on the dev loader and saved to
    ``<save_dir>/model.pt`` unless early stopping fires.

    Args:
        config: settings object; this function reads the file paths,
            ``batch_size``, ``learning_rate``, ``lr_warm_up_num``,
            ``num_epoch``, ``early_stop`` and ``save_dir``.
    """
    from models import QANet

    def read_matrix(path):
        with open(path, "rb") as fh:
            return np.array(json.load(fh), dtype=np.float32)

    word_mat = read_matrix(config.word_emb_file)
    char_mat = read_matrix(config.char_emb_file)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")
    train_dataset = get_loader(config.train_record_file, config.batch_size)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    target_lr = config.learning_rate
    warm_up_steps = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)

    # The optimizer's base rate is 1, so the factor returned by the schedule
    # below is the learning rate that is actually applied.
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = optim.Adam(
        params=trainable,
        lr=1,
        betas=(0.8, 0.999),
        eps=1e-6,
        weight_decay=3e-7,
    )
    warm_up_scale = target_lr / math.log2(warm_up_steps)

    def lr_factor(step):
        # Logarithmic ramp during warm-up, constant afterwards.
        if step < warm_up_steps:
            return warm_up_scale * math.log2(step + 1)
        return target_lr

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)

    best_f1 = 0
    best_em = 0
    patience = 0
    for epoch in range(config.num_epoch):
        train(model, optimizer, scheduler, train_dataset, dev_dataset,
              dev_eval_file, epoch)
        steps_so_far = (epoch + 1) * len(train_dataset)
        metrics = test(model, dev_dataset, dev_eval_file, steps_so_far)

        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        # An epoch only counts as a regression when BOTH metrics got worse.
        regressed = dev_f1 < best_f1 and dev_em < best_em
        if not regressed:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
        else:
            patience += 1
            if patience > config.early_stop:
                break

        fn = os.path.join(config.save_dir, "model.pt")
        torch.save(model, fn)
def train(epoch, data, model=None):
    """Train a QANet model for ``epoch`` passes over ``data.train``.

    Args:
        epoch: number of passes over the training batches.
        data: dataset object; ``data.train.packs`` is batched with ``trunk``
            and each batch is turned into tensors by ``to_batch``.
        model: model to continue training; a fresh ``QANet(data)`` is built on
            ``device`` when this is ``None``.

    Returns:
        The trained model.  On ``KeyboardInterrupt`` the partially trained
        model is saved and returned; on any other exception it is saved and
        the exception is re-raised.

    Relies on the module-level ``lr``, ``batch_size``, ``checkpoint``,
    ``model_dir``, ``model_fn``, ``writer`` and ``device``.
    """
    if model is None:
        model = QANet(data).to(device)

    parameters = filter(lambda param: param.requires_grad, model.parameters())
    # NOTE(review): no explicit `lr` is given here, so the LambdaLR factor
    # below multiplies Adam's default base rate -- confirm that is intended.
    optimizer = optim.Adam(betas=(0.8, 0.999), eps=1e-7, weight_decay=3e-7,
                           params=parameters)
    crit = lr / math.log2(1000)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: crit * math.log2(ee + 1) if ee + 1 <= 1000 else lr)

    packs = trunk(data.train.packs, batch_size)

    # Bound up front so the emergency-save handlers below can always format a
    # file name, even if the failure happens before the first batch.
    ep, i = 0, -1

    # The log file is created/truncated as before; using a context manager
    # guarantees the handle is closed.  Nothing in this function writes to it.
    with open("log/model.log", "w") as f_log:
        try:
            for ep in range(epoch):
                print("EPOCH {:02d}: ".format(ep))
                num_batches = len(packs)
                for i in tqdm(range(num_batches)):
                    pack = packs[i]
                    Cw, Cc, Qw, Qc, a = to_batch(pack, data, data.train)
                    optimizer.zero_grad()
                    out1, out2 = model(Cw, Cc, Qw, Qc)
                    loss1 = F.cross_entropy(out1, a[:, 0])
                    loss2 = F.cross_entropy(out2, a[:, 1])
                    loss = loss1 + loss2
                    writer.add_scalar("data/loss", float(loss),
                                      ep * num_batches + i)
                    loss.backward()
                    # Apply the gradients.  scheduler.step() only adjusts the
                    # learning rate; without optimizer.step() the weights
                    # would never change.
                    optimizer.step()
                    scheduler.step()
                    if (i + 1) % checkpoint == 0:
                        torch.save(
                            model,
                            os.path.join(
                                model_dir,
                                "model-tmp-{:02d}-{}.pt".format(ep, i + 1)))
                random.shuffle(packs)
            torch.save(model, os.path.join(model_dir, model_fn))
            writer.close()
        except KeyboardInterrupt:
            # Manual stop: keep what has been learned and return normally.
            torch.save(
                model,
                os.path.join(model_dir,
                             "model-{:02d}-{}.pt".format(ep, i + 1)))
        except Exception:
            # Save a snapshot for post-mortem, then propagate with the
            # original traceback.
            torch.save(
                model,
                os.path.join(model_dir,
                             "model-{:02d}-{}.pt".format(ep, i + 1)))
            raise
    return model
def train(args, config):
    """Load data and embeddings, build the configured model and train it.

    For ``config.model_type == "model6"`` a QANet is first restored from
    ``config.pretrained_model`` and handed to the model constructor; every
    other model type is built directly from the embedding matrices.  The model
    is wrapped in ``DataParallel`` and passed to ``Trainer``.

    Args:
        args: command-line arguments; ``args.model_file`` optionally names a
            trainer checkpoint to load before training.
        config: settings object providing file paths, batch sizes, ``version``,
            ``model_type``, ``pretrained_model`` and ``device``.
    """
    # load data
    print("Loading data")
    train_dataset = data.get_loader(
        config.train_file, config.batch_size, config.version)
    dev_dataset = data.get_loader(
        config.dev_file, config.val_batch_size, config.version)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("loading embeddings")
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "rb") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)

    # create model
    print("creating model")
    needs_pretrained = config.model_type == "model6"
    Model = model_utils.get_model_func(config.model_type, config.version)
    if needs_pretrained:
        # The checkpoint's "model" entry holds the pretrained QANet weights.
        saved_weights = torch.load(
            config.pretrained_model, pickle_module=dill)["model"]
        # Drop the first 7 characters of every key.
        # NOTE(review): presumably the "module." prefix added by DataParallel
        # -- confirm against the checkpoint.
        renamed_weights = {
            key[7:]: tensor for key, tensor in saved_weights.items()}

        from models import QANet
        pretrained_model = QANet(word_mat, char_mat, config)
        # Overlay the saved weights on the freshly initialised state so keys
        # missing from the checkpoint keep their initial values.
        merged_state = pretrained_model.state_dict()
        merged_state.update(renamed_weights)
        pretrained_model.load_state_dict(merged_state)

        model = Model(pretrained_model, config)
    else:
        model = Model(word_mat, char_mat, config)
    model = torch.nn.DataParallel(model.to(config.device))

    print("Training Model")
    trainer = Trainer(model, train_dataset, dev_dataset, dev_eval_file, config)
    resume_from = args.model_file
    if resume_from is not None:
        trainer.load(resume_from)
        trainer.ema.resume(trainer.model)
    trainer.train()
def test_build(self):
    """Build a small QANet and check the static shapes of its tensors."""
    vocab_size = 3000
    embed_size = filters = 96
    num_heads = 1
    context_limit, query_limit = 40, 5

    network = QANet(
        vocab_size, embed_size, filters, num_heads, context_limit, query_limit,
        dropout=.1,
        encoder_layer_size=1,
        encoder_conv_blocks=3,
        output_layer_size=7,
        output_conv_blocks=2,
    )
    model = network.build()

    # Unpacking also checks the arity: two inputs and four outputs.
    query_input, context_input = model.inputs
    start_prob, end_prob, _s_bar, _s_t = model.outputs

    expected_shapes = (
        (query_input, (None, query_limit)),
        (context_input, (None, context_limit)),
        (start_prob, (None, context_limit)),
        (end_prob, (None, context_limit)),
    )
    for tensor, shape in expected_shapes:
        self.assertTupleEqual(K.int_shape(tensor), shape)
def test_entry(config):
    """Evaluate the saved model at ``<save_dir>/model.pt`` on the dev set.

    The model is restored as a whole object with ``torch.load`` and passed to
    ``test`` together with the dev loader and the dev evaluation file.

    The previous version also loaded both embedding files and built a fresh
    ``QANet`` on the GPU, only to overwrite it with the loaded model on the
    next statement.  That discarded work (two large file reads plus a GPU
    allocation) has been removed; the evaluated model is unchanged.

    Args:
        config: settings object providing ``dev_eval_file``,
            ``dev_record_file``, ``batch_size`` and ``save_dir``.
    """
    # NOTE(review): kept from the original; presumably ensures the model class
    # is importable when the pickled model is restored -- confirm.
    from models import QANet  # noqa: F401

    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    fn = os.path.join(config.save_dir, "model.pt")
    model = torch.load(fn)
    test(model, dev_dataset, dev_eval_file, 0)
def main(args):
    """Build, compile and train the Keras QANet from command-line arguments.

    An embedding matrix is loaded from ``embedding_<vocab-name>.npy`` next to
    the vocabulary file when that file exists; otherwise ``None`` is passed to
    the model.  The training history is plotted to ``loss_graph.png``.

    Args:
        args: parsed arguments providing ``vocab_file``, ``batch``, ``epoch``,
            the model hyper-parameters, ``train_path``, ``dev_path``,
            ``lower`` and ``use_tensorboard``.
    """
    token_to_index, _ = Vocabulary.load(args.vocab_file)

    # Derive the embedding path from the vocabulary file's location and stem.
    vocab_root = os.path.splitext(args.vocab_file)[0]
    basepath, basename = os.path.split(vocab_root)
    embed_path = f'{basepath}/embedding_{basename}.npy'
    if os.path.exists(embed_path):
        embeddings = np.load(embed_path)
    else:
        embeddings = None

    batch_size = args.batch  # Batch size for training.
    epochs = args.epoch  # Number of epochs to train for.

    network = QANet(
        len(token_to_index), args.embed, args.hidden, args.num_heads,
        encoder_num_blocks=args.encoder_layer,
        encoder_num_convs=args.encoder_conv,
        output_num_blocks=args.output_layer,
        output_num_convs=args.output_conv,
        dropout=args.dropout,
        embeddings=embeddings,
    )
    model = network.build()

    # Only the first two outputs carry a loss; the remaining two are given
    # no loss function and zero weight.
    span_loss = 'sparse_categorical_crossentropy'
    opt = Adam(lr=0.001, beta_1=0.8, beta_2=0.999, epsilon=1e-7, clipnorm=5.)
    model.compile(
        optimizer=opt,
        loss=[span_loss, span_loss, None, None],
        loss_weights=[1, 1, 0, 0],
    )

    train_dataset = SquadReader(args.train_path)
    dev_dataset = SquadReader(args.dev_path)
    converter = SquadConverter(
        token_to_index, PAD_TOKEN, UNK_TOKEN, lower=args.lower)
    train_generator = Iterator(train_dataset, batch_size, converter)
    dev_generator = Iterator(dev_dataset, batch_size, converter)

    trainer = SquadTrainer(
        model, train_generator, epochs, dev_generator,
        './model/qanet.{epoch:02d}-{val_loss:.2f}.h5')
    trainer.add_callback(BatchLearningRateScheduler())
    if args.use_tensorboard:
        board = TensorBoard(log_dir='./graph', batch_size=batch_size)
        trainer.add_callback(board)

    history = trainer.run()
    dump_graph(history, 'loss_graph.png')
def test_model():
    """Smoke-test QANet: run one forward pass on random indices and print it."""
    args = get_setup_args()
    word_vectors = util.torch_from_json(args.word_emb_file)
    with open(args.char2idx_file, "r") as f:
        char2idx = json_load(f)
    model = QANet(word_vectors, char2idx)

    batch, context_len, query_len, word_len = 64, 374, 70, 200

    # Random index tensors, drawn in the same order as before so a seeded run
    # produces the same inputs.
    cw_idxs = torch.randint(2, 1000, (batch, context_len))
    cc_idxs = torch.randint(2, 50, (batch, context_len, word_len))
    qw_idxs = torch.randint(2, 1000, (batch, query_len))
    qc_idxs = torch.randint(2, 50, (batch, query_len, word_len))

    # Force index 1 into the first position of every row and index 0 into the
    # last position of row 3, for both word-index tensors.
    for word_idxs in (cw_idxs, qw_idxs):
        word_idxs[:, 0] = 1
        word_idxs[3, -1] = 0

    out = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
    print(out)
def get_model(word_vectors, char_vectors, log, args):
    """Instantiate the model named by ``args.model_name``.

    The model is wrapped in ``nn.DataParallel`` over ``args.gpu_ids`` and, when
    ``args.load_path`` is set, restored from that checkpoint.

    Args:
        word_vectors: word embeddings passed to the model constructor.
        char_vectors: character embeddings (unused for ``"BiDAF_nochar"``).
        log: logger used to report checkpoint loading.
        args: parsed arguments providing ``model_name``, the model
            hyper-parameters, ``gpu_ids`` and ``load_path``.

    Returns:
        ``(model, step)`` where ``step`` is the step returned by
        ``util.load_model``, or ``0`` when no checkpoint was loaded.

    Raises:
        ValueError: if ``args.model_name`` is not a known model.
    """
    name = args.model_name
    if name == "BiDAF":
        model = BiDAF_char(
            word_vectors=word_vectors,
            char_vectors=char_vectors,
            char_len=args.char_len,
            hidden_size=args.hidden_size,
            drop_prob=args.drop_prob,
        )
    elif name == "BiDAF_nochar":
        model = BiDAF(
            word_vectors=word_vectors,
            hidden_size=args.hidden_size,
            drop_prob=args.drop_prob,
        )
    elif name == "KnowGQA" or name == "QANet":
        # Both models take the same keyword arguments.
        shared_kwargs = dict(
            word_vectors=word_vectors,
            char_vectors=char_vectors,
            char_len=args.char_len,
            hidden_size=args.hidden_size,
            h=args.h,
            drop_prob=args.drop_prob,
        )
        if name == "KnowGQA":
            model = KnowGQA(**shared_kwargs)
        else:
            model = QANet(**shared_kwargs)
    else:
        raise ValueError("Model name doesn't exist.")

    model = nn.DataParallel(model, args.gpu_ids)

    # Xavier re-initialisation of all multi-dimensional parameters was
    # disabled in the original at this point and stays disabled.

    if not args.load_path:
        return model, 0

    log.info(f'Loading checkpoint from {args.load_path}...')
    model, step = util.load_model(model, args.load_path, args.gpu_ids)
    return model, step
def train(config):
    """Run the TensorFlow training loop with periodic evaluation.

    Loads the embedding matrices, the evaluation files and the dev metadata,
    builds the datasets and the model, then trains up to ``config.num_steps``
    steps inside a ``tf.Session``.  The loss is logged every
    ``config.period`` steps; every ``config.checkpoint`` steps the model is
    evaluated on train and dev batches, early stopping is checked and a
    checkpoint is written to ``config.save_dir``.

    NOTE(review): ``sess_config``, ``handle``, ``train_iterator`` and
    ``dev_iterator`` are not defined inside this function -- presumably
    module-level names; confirm before relying on this entry point.

    Args:
        config: settings object providing the file paths, ``log_dir``,
            ``save_dir``, ``num_steps``, ``dropout``, ``period``,
            ``checkpoint``, ``val_num_batches``, ``batch_size`` and
            ``early_stop``.
    """
    def read_json(path):
        with open(path, "r") as fh:
            return json.load(fh)

    word_mat = np.array(read_json(config.word_emb_file), dtype=np.float32)
    char_mat = np.array(read_json(config.char_emb_file), dtype=np.float32)
    train_eval_file = read_json(config.train_eval_file)
    dev_eval_file = read_json(config.dev_eval_file)
    meta = read_json(config.dev_meta)
    dev_total = meta["total"]

    print("Building model...")
    parser = get_record_parser(config)
    train_dataset = get_batch_dataset(config.train_record_file, parser, config)
    dev_dataset = get_dataset(config.dev_record_file, parser, config)
    model = QANet(word_mat, char_mat)

    patience = 0
    best_f1 = 0.
    best_em = 0.

    with tf.Session(config=sess_config) as sess:
        writer = tf.summary.FileWriter(config.log_dir)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        train_handle = sess.run(train_iterator.string_handle())
        dev_handle = sess.run(dev_iterator.string_handle())

        # Resume from the latest checkpoint when one exists.
        checkpoint_index = os.path.join(config.save_dir, "checkpoint")
        if os.path.exists(checkpoint_index):
            saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))

        first_step = max(sess.run(model.global_step), 1)
        dev_batches = dev_total // config.batch_size + 1
        for _ in tqdm(range(first_step, config.num_steps + 1)):
            global_step = sess.run(model.global_step) + 1
            loss, train_op = sess.run(
                [model.loss, model.train_op],
                feed_dict={handle: train_handle,
                           model.dropout: config.dropout})

            if global_step % config.period == 0:
                loss_value = tf.Summary.Value(
                    tag="model/loss", simple_value=loss)
                loss_sum = tf.Summary(value=[loss_value, ])
                writer.add_summary(loss_sum, global_step)

            if global_step % config.checkpoint != 0:
                continue

            _, summ = evaluate_batch(
                model, config.val_num_batches, train_eval_file, sess,
                "train", handle, train_handle)
            for s in summ:
                writer.add_summary(s, global_step)

            metrics, summ = evaluate_batch(
                model, dev_batches, dev_eval_file, sess,
                "dev", handle, dev_handle)
            dev_f1 = metrics["f1"]
            dev_em = metrics["exact_match"]

            # An evaluation only counts as a regression when BOTH metrics
            # got worse.
            regressed = dev_f1 < best_f1 and dev_em < best_em
            if not regressed:
                patience = 0
                best_em = max(best_em, dev_em)
                best_f1 = max(best_f1, dev_f1)
            else:
                patience += 1
                if patience > config.early_stop:
                    break

            for s in summ:
                writer.add_summary(s, global_step)
            writer.flush()
            filename = os.path.join(
                config.save_dir, "model_{}.ckpt".format(global_step))
            saver.save(sess, filename)
def main(args):
    """Train QANet on a truncated copy of the training records.

    Sets up logging, seeding, the model (optionally restored from
    ``args.load_path``), EMA, checkpoint saver, optimizer and data loaders,
    then trains until ``args.num_epochs`` is reached, evaluating on the dev
    set roughly every ``args.eval_steps`` examples.

    The first ``n_row`` examples of ``args.train_record_file`` are copied to a
    lighter ``.npz`` file which is then used for training.  The source archive
    is now opened with a context manager so its file handle is released once
    the subset has been written (previously it was never closed).

    Args:
        args: parsed command-line arguments; ``save_dir``, ``gpu_ids`` and
            ``batch_size`` are updated in place.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info('Using random seed {}...'.format(args.seed))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    with open(args.char2idx_file, "r") as f:
        char2idx = json_load(f)

    # Get model
    log.info('Building model...')
    model = QANet(word_vectors=word_vectors, char2idx=char2idx)
    model = nn.DataParallel(model, args.gpu_ids)
    if args.load_path:
        log.info('Loading checkpoint from {}...'.format(args.load_path))
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    optimizer = optim.Adadelta(model.parameters(), args.lr,
                               weight_decay=args.l2_wd)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR

    # Get data loader
    log.info('Building dataset...')
    # NOTE(review): the output path and row count are hard-coded; the target
    # directory must already exist.
    outfile = '/content/dataset/train_light.npz'
    n_row = 75000
    record_keys = ('context_idxs', 'context_char_idxs', 'ques_idxs',
                   'ques_char_idxs', 'y1s', 'y2s', 'ids')
    # np.load on an .npz returns a lazily-read archive that holds the file
    # open; the context manager closes it after the slices are materialised.
    with np.load(args.train_record_file) as dataset1:
        subset = {key: dataset1[key][:n_row] for key in record_keys}
    np.savez(outfile, **subset)

    train_dataset = SQuAD(outfile, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    # `step` counts examples (it grows by the batch size), so dividing by the
    # dataset length gives the number of completed epochs.
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info('Starting epoch {}...'.format(epoch))
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:
                # Setup for forward
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()
                if step % 10000 == 0:
                    print('loss val: {}'.format(loss_val))

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                optimizer.step()
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint, using the EMA weights for
                    # both and restoring the raw weights afterwards.
                    log.info('Evaluating at step {}...'.format(step))
                    ema.assign(model)
                    results, pred_dict = evaluate(model, dev_loader, device,
                                                  args.dev_eval_file,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                            for k, v in results.items())
                    log.info('Dev {}'.format(results_str))

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar('dev/{}'.format(k), v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)
def main(args):
    """Evaluate a trained checkpoint on one split and write a submission CSV.

    Builds the model selected by ``args.name``, restores it from
    ``args.load_path``, runs it over the ``args.split`` records and writes the
    predictions to ``<save_dir>/<split>_<sub_file>``.  For every split except
    ``'test'`` the NLL/F1/EM (and AvNA for SQuAD v2) results are logged and
    sample predictions are sent to TensorBoard.

    The TensorBoard writer is now closed after use so its events are flushed
    to disk; previously it was left open.

    Args:
        args: parsed command-line arguments; ``save_dir`` and ``batch_size``
            are updated in place.

    Raises:
        ValueError: if ``args.name`` is not a known model name.
    """
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    device, gpu_ids = util.get_available_devices()
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vec = util.torch_from_json(args.char_emb_file)

    # Get model
    log.info('Building model...')
    if args.name == 'baseline':
        model = BiDAF(word_vectors=word_vectors,
                      hidden_size=args.hidden_size)
    elif args.name == 'charembeddings':
        model = BiDAFChar(word_vectors=word_vectors,
                          char_vec=char_vec,
                          word_len=16,
                          hidden_size=args.hidden_size)
    elif args.name == 'charembeddings2':
        model = BiDAFChar2(word_vectors=word_vectors,
                           char_vec=char_vec,
                           word_len=16,
                           hidden_size=args.hidden_size)
    elif args.name in ('qanet', 'qanet2', 'qanet3'):
        # The three QANet variants take identical constructor arguments; only
        # the class differs.  The class is looked up lazily so that only the
        # selected one has to be importable.
        if args.name == 'qanet':
            qanet_cls = QANet
        elif args.name == 'qanet2':
            qanet_cls = QANet2
        else:
            qanet_cls = QANet3
        model = qanet_cls(word_vectors=word_vectors,
                          char_vec=char_vec,
                          word_len=16,
                          emb_size=args.hidden_size,
                          enc_size=args.enc_size,
                          n_head=args.n_head,
                          LN_train=args.ln_train,
                          DP_residual=args.dp_res,
                          mask_pos=args.mask_pos,
                          two_pos=args.two_pos,
                          final_prob=1)
    else:
        raise ValueError('Wrong model name')

    model = nn.DataParallel(model, gpu_ids)
    log.info(f'Loading checkpoint from {args.load_path}...')
    model = util.load_model(model, args.load_path, gpu_ids, return_step=False)
    model = model.to(device)
    model.eval()

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)[f'{args.split}_record_file']
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Evaluate
    log.info(f'Evaluating on {args.split} split...')
    nll_meter = util.AverageMeter()
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}   # Predictions for submission
    eval_file = vars(args)[f'{args.split}_eval_file']
    with open(eval_file, 'r') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
            # Setup for forward
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            batch_size = cw_idxs.size(0)

            # Forward
            log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
            y1, y2 = y1.to(device), y2.to(device)
            loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
            nll_meter.update(loss.item(), batch_size)

            # Get F1 and EM scores
            p1, p2 = log_p1.exp(), log_p2.exp()
            starts, ends = util.discretize(p1, p2, args.max_ans_len,
                                           args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if args.split != 'test':
                # No labels for the test set, so NLL would be invalid
                progress_bar.set_postfix(NLL=nll_meter.avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict,
                                                      ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    # Log results (except for test set, since it does not come with labels)
    if args.split != 'test':
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        results_list = [('NLL', nll_meter.avg),
                        ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in results.items())
        log.info(f'{args.split.title()} {results_str}')

        # Log to TensorBoard; close the writer so the events reach disk even
        # if visualisation fails part-way.
        tbx = SummaryWriter(args.save_dir)
        try:
            util.visualize(tbx,
                           pred_dict=pred_dict,
                           eval_path=eval_file,
                           step=0,
                           split=args.split,
                           num_visuals=args.num_visuals)
        finally:
            tbx.close()

    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info(f'Writing submission file to {sub_path}...')
    with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])
def main(args):
    """Evaluate one checkpoint, or an ensemble of checkpoints, on ``args.split``.

    Every ``args.load_path_*`` that is set loads the matching architecture.
    For each batch the start/end probabilities of all loaded models are
    averaged before ``util.discretize`` selects the answer span. Unless the
    split is ``'test'``, metrics are logged and a sample of predictions is
    sent to TensorBoard. A CSV submission file is always written to
    ``args.save_dir``.

    Args:
        args: Parsed command-line namespace. ``args.save_dir`` and
            ``args.batch_size`` are updated in place.

    Raises:
        ValueError: If none of the ``load_path_*`` arguments is set, since
            there would be nothing to evaluate.
    """
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    device, gpu_ids = util.get_available_devices()
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # The keyword sets are built lazily so that hyper-parameters of an
    # architecture are only read from ``args`` when that model is requested.
    def bidaf_kwargs():
        return dict(word_vectors=word_vectors,
                    char_vectors=char_vectors,
                    char_emb_dim=args.char_emb_dim,
                    hidden_size=args.hidden_size)

    def qanet_kwargs(with_divisor=True):
        kwargs = dict(word_vectors=word_vectors,
                      char_vectors=char_vectors,
                      char_emb_dim=args.char_emb_dim,
                      hidden_size=args.hidden_size,
                      n_heads=args.n_heads,
                      n_conv_emb_enc=args.n_conv_emb,
                      n_conv_mod_enc=args.n_conv_mod,
                      n_emb_enc_blocks=args.n_emb_blocks,
                      n_mod_enc_blocks=args.n_mod_blocks)
        if with_divisor:
            kwargs['divisor_dim_kqv'] = args.divisor_dim_kqv
        return kwargs

    # (tag, checkpoint path, builder), listed in checkpoint-loading order.
    candidates = [
        ('baseline', args.load_path_baseline,
         lambda: Baseline(word_vectors=word_vectors, hidden_size=100)),
        ('bidaf', args.load_path_bidaf,
         lambda: BiDAF(**bidaf_kwargs())),
        ('bidaf_fu', args.load_path_bidaf_fusion,
         lambda: BiDAF_fus(**bidaf_kwargs())),
        ('qanet', args.load_path_qanet,
         lambda: QANet(**qanet_kwargs())),
        ('qanet_old', args.load_path_qanet_old,
         lambda: QANet_old(device=device,
                           **qanet_kwargs(with_divisor=False))),
        ('qanet_inde', args.load_path_qanet_inde,
         lambda: QANet_independant_encoder(**qanet_kwargs())),
        ('qanet_s_e', args.load_path_qanet_s_e,
         lambda: QANet_S_E(**qanet_kwargs())),
    ]
    # Order in which the models are run (and their probabilities summed).
    forward_order = ('baseline', 'qanet', 'qanet_old', 'qanet_inde',
                     'qanet_s_e', 'bidaf', 'bidaf_fu')
    # Only one model's NLL is reported; this is the order of preference.
    nll_priority = ('qanet', 'bidaf', 'bidaf_fu', 'qanet_inde', 'qanet_s_e',
                    'qanet_old', 'baseline')

    # Get models
    log.info('Building model...')
    models = {}
    nll_meters = {}
    saved_probs = {}  # tag -> (list of start probs, list of end probs)
    loaded_tags = []
    for tag, load_path, build in candidates:
        if not load_path:
            continue
        model = nn.DataParallel(build(), gpu_ids)
        log.info(f'Loading checkpoint from {load_path}...')
        model = util.load_model(model, load_path, gpu_ids, return_step=False)
        model = model.to(device)
        model.eval()
        models[tag] = model
        nll_meters[tag] = util.AverageMeter()
        saved_probs[tag] = ([], [])
        loaded_tags.append(tag)

    if not models:
        # Fail early with a clear message instead of an IndexError on the
        # first batch when averaging an empty list of predictions.
        raise ValueError(
            'No checkpoint given: set at least one load_path_* argument.')
    nbr_model = len(models)
    reported_tag = next(tag for tag in nll_priority if tag in models)

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)[f'{args.split}_record_file']
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Evaluate
    log.info(f'Evaluating on {args.split} split...')
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}  # Predictions for submission
    eval_file = vars(args)[f'{args.split}_eval_file']
    with open(eval_file, 'r') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
            # Setup for forward
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            cc_idxs = cc_idxs.to(device)
            qc_idxs = qc_idxs.to(device)
            batch_size = cw_idxs.size(0)
            y1, y2 = y1.to(device), y2.to(device)

            # Forward every loaded model
            l_p1, l_p2 = [], []
            for tag in forward_order:
                model = models.get(tag)
                if model is None:
                    continue
                if tag == 'baseline':
                    # NOTE(review): the baseline has always been called with
                    # (cw_idxs, cc_idxs); confirm this is intended rather
                    # than (cw_idxs, qw_idxs).
                    log_p1, log_p2 = model(cw_idxs, cc_idxs)
                else:
                    log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                nll_meters[tag].update(loss.item(), batch_size)
                l_p1.append(log_p1.exp())
                l_p2.append(log_p2.exp())
                if args.save_probabilities:
                    # Use fresh exp() tensors: the ones in l_p1/l_p2 are
                    # modified in place by the averaging below.
                    saved_probs[tag][0].append(
                        log_p1.exp().detach().cpu().numpy())
                    saved_probs[tag][1].append(
                        log_p2.exp().detach().cpu().numpy())

            # Average the probabilities over the ensemble
            p1, p2 = l_p1[0], l_p2[0]
            for extra_p1, extra_p2 in zip(l_p1[1:], l_p2[1:]):
                p1 += extra_p1
                p2 += extra_p2
            p1 /= nbr_model
            p2 /= nbr_model
            starts, ends = util.discretize(p1, p2, args.max_ans_len,
                                           args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if args.split != 'test':
                # No labels for the test set, so NLL would be invalid
                progress_bar.set_postfix(NLL=nll_meters[reported_tag].avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict, ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    if args.save_probabilities:
        # All models share the same two output files, so only the model
        # loaded last is kept (unchanged file contract). Say so explicitly
        # instead of silently dropping the others.
        saved_tag = loaded_tags[-1]
        if nbr_model > 1:
            log.info(f'Several models loaded; saving probabilities of '
                     f'{saved_tag} only.')
        start_probs, end_probs = saved_probs[saved_tag]
        with open(args.save_dir + "/probs_start", "wb") as fp:
            pickle.dump(start_probs, fp)
        with open(args.save_dir + "/probs_end", "wb") as fp:
            pickle.dump(end_probs, fp)

    # Log results (except for test set, since it does not come with labels)
    if args.split != 'test':
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        results_list = [('NLL', nll_meters[reported_tag].avg),
                        ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in results.items())
        log.info(f'{args.split.title()} {results_str}')

        # Log to TensorBoard
        tbx = SummaryWriter(args.save_dir)
        try:
            util.visualize(tbx,
                           pred_dict=pred_dict,
                           eval_path=eval_file,
                           step=0,
                           split=args.split,
                           num_visuals=args.num_visuals)
        finally:
            # Close the writer so pending events are flushed to disk.
            tbx.close()

    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info(f'Writing submission file to {sub_path}...')
    with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])
def main(args):
    """Train QANet with Adam and a logarithmic learning-rate warm-up.

    Prepares logging, seeding, the model (optionally restored from
    ``args.load_path``), an EMA helper, a checkpoint saver and the data
    loaders, then trains until the epoch counter reaches
    ``args.num_epochs``. Each time ``args.eval_steps`` examples have been
    processed the dev set is evaluated, a checkpoint is saved and the
    metrics are written to the console and to TensorBoard.

    Args:
        args: Parsed command-line namespace. ``save_dir``, ``gpu_ids`` and
            ``batch_size`` are overwritten in place.
    """
    # Logging, TensorBoard writer and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    args.batch_size *= max(1, len(args.gpu_ids))

    # Seed every random number generator in use
    log.info(f'Using random seed {args.seed}...')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Pre-trained word and character embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Model, optionally resumed from a checkpoint
    log.info('Building Model...')
    model = nn.DataParallel(
        QANet(word_vectors,
              char_vectors,
              args.para_limit,
              args.ques_limit,
              args.f_model,
              num_head=args.num_head,
              train_cemb=(not args.pretrained_char)),
        args.gpu_ids)
    step = 0
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Checkpoint saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Optimizer over the trainable parameters only
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = optim.Adam(params=trainable,
                           lr=args.lr,
                           betas=(args.beta1, args.beta2),
                           eps=1e-8,
                           weight_decay=3e-7)

    # LR multiplier: grows with log(iteration + 1) during warm-up, then 1
    warmup_scale = 1.0 / math.log(args.lr_warm_up_num)

    def lr_factor(iteration):
        if iteration < args.lr_warm_up_num:
            return warmup_scale * math.log(iteration + 1)
        return 1

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
    criterion = torch.nn.CrossEntropyLoss()

    # Data loaders
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for batch in train_loader:
                cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids = batch

                # Move the batch to the training device
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                cc_idxs = cc_idxs.to(device)
                qc_idxs = qc_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = torch.mean(criterion(log_p1, y1) +
                                  criterion(log_p2, y2))
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                optimizer.step()
                # ``step`` counts examples; convert it to an iteration index
                iteration = step // batch_size
                scheduler.step(iteration)
                ema(model, iteration)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval > 0:
                    continue
                steps_till_eval = args.eval_steps

                # Evaluate and save checkpoint; ema.assign / ema.resume
                # bracket the evaluation (presumably swapping the averaged
                # weights in and out -- see util.EMA).
                log.info(f'Evaluating at step {step}...')
                ema.assign(model)
                results, pred_dict = evaluate(model, dev_loader, device,
                                              args.dev_eval_file,
                                              args.max_ans_len,
                                              args.use_squad_v2)
                saver.save(step, model, results[args.metric_name], device)
                ema.resume(model)

                # Log to console
                results_str = ', '.join(f'{k}: {v:05.2f}'
                                        for k, v in results.items())
                log.info(f'Dev {results_str}')

                # Log to TensorBoard
                log.info('Visualizing in TensorBoard...')
                for metric, value in results.items():
                    tbx.add_scalar(f'dev/{metric}', value, step)
                util.visualize(tbx,
                               pred_dict=pred_dict,
                               eval_path=args.dev_eval_file,
                               step=step,
                               split='dev',
                               num_visuals=args.num_visuals)
def main(args):
    """Train QANet with Adam, applying ``LR_warmup`` during early steps.

    Prepares logging, seeding, the model (optionally restored from
    ``args.load_path``), an EMA helper, a checkpoint saver and the data
    loaders, then trains until the epoch counter reaches
    ``args.num_epochs``. Each time ``args.eval_steps`` examples have been
    processed the dev set is evaluated, a checkpoint is saved and the
    metrics are written to the console and to TensorBoard.

    Args:
        args: Parsed command-line namespace. ``save_dir``, ``gpu_ids`` and
            ``batch_size`` are overwritten in place.
    """
    # Logging, TensorBoard writer and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    args.batch_size *= max(1, len(args.gpu_ids))

    # Seed every random number generator in use
    log.info(f'Using random seed {args.seed}...')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Word embeddings, then character embeddings
    log.info('Loading word embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    log.info('Loading char-embeddings...')
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Model (QANet replaces the earlier BiDAF configuration)
    log.info('Building model...')
    model = nn.DataParallel(
        QANet(word_vectors=word_vectors,
              char_vectors=char_vectors,
              hidden_size=args.hidden_size,
              kernel_size=7,
              filters=128,
              drop_prob=args.drop_prob),
        args.gpu_ids)
    step = 0
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Checkpoint saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Adam with fixed hyper-parameters (earlier experiments used Adadelta
    # and AdaBound). The scheduler multiplier is always 1, i.e. constant LR.
    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3,
                           betas=(0.8, 0.999),
                           eps=1e-7,
                           weight_decay=3e-7,
                           amsgrad=False)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)

    # Data loaders
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for batch in train_loader:
                cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids = batch

                # Move the batch to the training device
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                cc_idxs = cc_idxs.to(device)
                qc_idxs = qc_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward (this model takes word indices first, then chars)
                log_p1, log_p2 = model(cw_idxs, qw_idxs, cc_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                # ``step`` advances by batch_size, so this branch covers the
                # first 1000 training examples.
                if step < 1000:
                    optimizer = LR_warmup(optimizer, step, lr=0.)
                optimizer.step()
                iteration = step // batch_size
                scheduler.step(iteration)
                ema(model, iteration)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval > 0:
                    continue
                steps_till_eval = args.eval_steps

                # Evaluate and save checkpoint; ema.assign / ema.resume
                # bracket the evaluation (presumably swapping the averaged
                # weights in and out -- see util.EMA).
                log.info(f'Evaluating at step {step}...')
                ema.assign(model)
                results, pred_dict = evaluate(model, dev_loader, device,
                                              args.dev_eval_file,
                                              args.max_ans_len,
                                              args.use_squad_v2)
                saver.save(step, model, results[args.metric_name], device)
                ema.resume(model)

                # Log to console
                results_str = ', '.join(f'{k}: {v:05.2f}'
                                        for k, v in results.items())
                log.info(f'Dev {results_str}')

                # Log to TensorBoard
                log.info('Visualizing in TensorBoard...')
                for metric, value in results.items():
                    tbx.add_scalar(f'dev/{metric}', value, step)
                util.visualize(tbx,
                               pred_dict=pred_dict,
                               eval_path=args.dev_eval_file,
                               step=step,
                               split='dev',
                               num_visuals=args.num_visuals)
def main(args):
    """Train QANet with Adadelta at a constant learning rate.

    Prepares logging, seeding, the model (optionally restored from
    ``args.load_path``), an EMA helper, a checkpoint saver and the data
    loaders, then trains until the epoch counter reaches
    ``args.num_epochs``. Each time ``args.eval_steps`` examples have been
    processed the dev set is evaluated, a checkpoint is saved and the
    metrics are written to the console and to TensorBoard.

    Args:
        args: Parsed command-line namespace. ``save_dir``, ``gpu_ids`` and
            ``batch_size`` are overwritten in place.
    """
    # Logging, TensorBoard writer and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    args.batch_size *= max(1, len(args.gpu_ids))

    # Seed every random number generator in use
    log.info(f'Using random seed {args.seed}...')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Pre-trained word and character embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Model: QANet (BiDAF and charBiDAF were earlier alternatives here)
    log.info('Building model...')
    qanet = QANet(word_vectors=word_vectors,
                  char_vectors=char_vectors,
                  emb_size=char_vectors.size(1),
                  hidden_size=args.hidden_size,
                  drop_prob=args.drop_prob)
    print(
        "YOU ARE ENTERING THE QA NET MODEL TRAINING_____________________________________________________________________"
    )
    model = nn.DataParallel(qanet, args.gpu_ids)
    step = 0
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Checkpoint saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Optimizer; the scheduler multiplier is always 1, i.e. constant LR
    optimizer = optim.Adadelta(model.parameters(),
                               args.lr,
                               weight_decay=args.l2_wd)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)

    # Data loaders (the dev loader is built with num_workers=0)
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for batch in train_loader:
                cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids = batch

                # Move the batch to the training device
                cw_idxs = cw_idxs.to(device)
                cc_idxs = cc_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                qc_idxs = qc_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                optimizer.step()
                # ``step`` counts examples; convert it to an iteration index
                iteration = step // batch_size
                scheduler.step(iteration)
                ema(model, iteration)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval > 0:
                    continue
                steps_till_eval = args.eval_steps

                # Evaluate and save checkpoint; ema.assign / ema.resume
                # bracket the evaluation (presumably swapping the averaged
                # weights in and out -- see util.EMA).
                log.info(f'Evaluating at step {step}...')
                ema.assign(model)
                results, pred_dict = evaluate(model, dev_loader, device,
                                              args.dev_eval_file,
                                              args.max_ans_len,
                                              args.use_squad_v2)
                saver.save(step, model, results[args.metric_name], device)
                ema.resume(model)

                # Log to console
                results_str = ', '.join(f'{k}: {v:05.2f}'
                                        for k, v in results.items())
                log.info(f'Dev {results_str}')

                # Log to TensorBoard
                log.info('Visualizing in TensorBoard...')
                for metric, value in results.items():
                    tbx.add_scalar(f'dev/{metric}', value, step)
                util.visualize(tbx,
                               pred_dict=pred_dict,
                               eval_path=args.dev_eval_file,
                               step=step,
                               split='dev',
                               num_visuals=args.num_visuals)
def train(config):
    """Train QANet step by step with periodic dev evaluation and early stop.

    Loads the word/character embedding matrices and the dev evaluation file,
    builds the model, then runs ``config.num_steps`` optimisation steps.
    Every ``config.checkpoint`` steps the dev set is evaluated with
    ``evaluate_batch``, a line is appended to ``config.train_log`` and the
    model is saved to ``config.save_dir``. Training stops early once both
    F1 and EM have failed to reach their best values more than
    ``config.early_stop`` consecutive times.

    Args:
        config: Configuration object providing the file paths and
            hyper-parameters read below.
    """
    from models import QANet

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")
    train_dataset = SQuADDataset(config.train_record_file, config.num_steps,
                                 config.batch_size)
    dev_dataset = SQuADDataset(config.dev_record_file, config.val_num_batches,
                               config.batch_size)

    lr = config.learning_rate
    warm_up_steps = 1000
    model = QANet(word_mat, char_mat).to(device)
    parameters = filter(lambda param: param.requires_grad, model.parameters())
    # LambdaLR multiplies the optimizer's base LR by the lambda's value, and
    # the lambda below already yields absolute learning rates (up to ``lr``),
    # so the base LR must be 1.0 -- as in the other trainers of this file.
    optimizer = optim.Adam(lr=1.0,
                           betas=(0.8, 0.999),
                           eps=1e-7,
                           weight_decay=3e-7,
                           params=parameters)
    crit = lr / math.log2(warm_up_steps)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: crit * math.log2(ee + 1)
        if ee + 1 <= warm_up_steps else lr)

    # Early-stopping state (previously read before ever being assigned).
    best_f1 = 0.0
    best_em = 0.0
    patience = 0

    # The context manager guarantees the log file is closed on any exit.
    with open(config.train_log, "w") as train_log:
        for ep in tqdm(range(config.num_steps), total=config.num_steps):
            (Cwid, Ccid, Qwid, Qcid, y1, y2, ids) = train_dataset[ep]
            Cwid, Ccid = Cwid.to(device), Ccid.to(device)
            Qwid, Qcid = Qwid.to(device), Qcid.to(device)
            y1, y2 = y1.to(device), y2.to(device)

            p1, p2 = model(Cwid, Ccid, Qwid, Qcid)
            loss1 = F.cross_entropy(p1, y1)
            # The end-position logits are scored against the end labels.
            loss2 = F.cross_entropy(p2, y2)
            loss = loss1 + loss2
            # Only one backward pass is run per graph, so it need not be
            # retained.
            loss.backward()
            # Apply the gradients before advancing the LR schedule.
            optimizer.step()
            scheduler.step()
            model.zero_grad()

            if (ep + 1) % config.checkpoint != 0:
                continue

            # Free the training tensors before evaluating on the dev set.
            del Cwid, Ccid, Qwid, Qcid, y1, y2, p1, p2, loss
            torch.cuda.empty_cache()
            # NOTE(review): if evaluate_batch switches the model to eval
            # mode, model.train() must be restored afterwards -- confirm.
            metric = evaluate_batch(model, dev_eval_file, dev_dataset)
            log_ = "EPOCH {:8d} loss {:8f} F1 {:8f} EM {:8f}\n".format(
                ep, metric["loss"], metric["f1"], metric["exact_match"])
            train_log.write(log_)
            train_log.flush()

            dev_f1 = metric["f1"]
            dev_em = metric["exact_match"]
            if dev_f1 < best_f1 and dev_em < best_em:
                patience += 1
                if patience > config.early_stop:
                    break
            else:
                patience = 0
                best_em = max(best_em, dev_em)
                best_f1 = max(best_f1, dev_f1)

            fn = os.path.join(config.save_dir, "model_{}.ckpt".format(ep))
            # Use torch.save's default pickle protocol; the former
            # ``pickle_protocol=False`` evaluated to protocol 0.
            torch.save(model, fn)
def train(config):
    """Run the per-step QANet training loop with periodic validation.

    Every ``config.checkpoint`` steps the model is validated on 1000 examples
    from both the training and dev datasets. When ``config.use_tensorboard``
    is truthy the losses, EM and F1 are written to the module-level
    ``writer``. One call to ``update`` is made per step.

    Args:
        config: Settings object. The attributes read here are
            ``word_emb_file``, ``char_emb_file``, ``train_eval_file``,
            ``dev_eval_file``, ``train_record_file``, ``dev_record_file``,
            ``num_steps``, ``batch_size``, ``learning_rate``,
            ``lr_warm_up_num``, ``beta1``, ``beta2``, ``checkpoint``,
            ``decay`` and ``use_tensorboard``.
    """
    from models import QANet

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")
    model = QANet(word_mat, char_mat).to(device)
    train_dataset = SQuADDataset(config.train_record_file, config.num_steps,
                                 config.batch_size)
    dev_dataset = SQuADDataset(config.dev_record_file, -1, config.batch_size,
                               name='dev')

    lr = config.learning_rate
    base_lr = 1.0  # LambdaLR scales this, so the lambda yields the real rate
    lr_warm_up_num = config.lr_warm_up_num
    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr, betas=(config.beta1, config.beta2),
                           eps=1e-7, weight_decay=3e-7, params=parameters)
    cr = lr / math.log2(lr_warm_up_num)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda ee: cr * math.log2(ee + 1)
        if ee < lr_warm_up_num else lr)

    L = config.checkpoint
    N = config.num_steps
    in_warm_up = True  # flips once the scheduler has been swapped

    for step in tqdm(range(0, N)):
        if step % L == 0:
            valid_train_loss, valid_train_metrics = valid(
                model, train_dataset, train_eval_file, num_ex=1000)
            valid_dev_loss, valid_dev_metrics = valid(
                model, dev_dataset, dev_eval_file, num_ex=1000)
            if config.use_tensorboard:
                # Integer division keeps the TensorBoard step an int.
                eval_idx = step // L
                writer.add_scalar('data/valid_train_loss', valid_train_loss,
                                  eval_idx)
                writer.add_scalar('data/valid_dev_loss', valid_dev_loss,
                                  eval_idx)
                writer.add_scalar('data/valid_train_em',
                                  valid_train_metrics['exact_match'], eval_idx)
                writer.add_scalar('data/valid_dev_em',
                                  valid_dev_metrics['exact_match'], eval_idx)
                writer.add_scalar('data/valid_train_f1',
                                  valid_train_metrics['f1'], eval_idx)
                writer.add_scalar('data/valid_dev_f1',
                                  valid_dev_metrics['f1'], eval_idx)

        train_loss = update(model, optimizer, scheduler, train_dataset[step])
        if config.use_tensorboard:
            writer.add_scalar('data/train_loss', train_loss, step)

        # Swap the warm-up schedule for exponential decay, exactly once.
        # NOTE(review): ``step + L`` makes the swap happen L steps before
        # ``lr_warm_up_num`` is reached -- confirm this is intended for a
        # per-step loop.
        if step + L >= lr_warm_up_num - 1 and in_warm_up:
            optimizer.param_groups[0]['initial_lr'] = lr
            scheduler = optim.lr_scheduler.ExponentialLR(
                optimizer, config.decay)
            in_warm_up = False

    # NOTE(review): closed unconditionally, as in the original; this assumes
    # ``writer`` exists even when config.use_tensorboard is false.
    writer.close()
def main(args):
    """Evaluate a saved QANet checkpoint on ``args.split`` and write a CSV.

    Loads the embeddings and the checkpoint at ``args.load_path``, runs the
    model over the split's record file without gradients, and writes one
    ``Id,Predicted`` row per example to
    ``<save_dir>/<split>_<args.sub_file>``. For every split except ``'test'``
    it also logs NLL/F1/EM (plus AvNA when ``args.use_squad_v2`` is set) and
    sends visualisations to TensorBoard.

    Args:
        args: Parsed command-line namespace. ``args.save_dir`` and
            ``args.batch_size`` are modified in place.
    """
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    device, gpu_ids = util.get_available_devices()
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Get model
    log.info('Building model...')
    model = QANet(word_vectors, char_vectors,
                  args.test_para_limit, args.test_ques_limit,
                  args.f_model,
                  num_head=args.num_head,
                  train_cemb=(not args.pretrained_char))
    model = nn.DataParallel(model, gpu_ids)
    log.info('Loading checkpoint from {}...'.format(args.load_path))
    model = util.load_model(model, args.load_path, gpu_ids, return_step=False)
    model = model.to(device)
    model.eval()

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)['{}_record_file'.format(args.split)]
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Evaluate
    log.info('Evaluating on {} split...'.format(args.split))
    loss_f = torch.nn.CrossEntropyLoss()
    nll_meter = util.AverageMeter()
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}   # Predictions for submission
    eval_file = vars(args)['{}_eval_file'.format(args.split)]
    with open(eval_file, 'r') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
            # Setup for forward
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            cc_idxs = cc_idxs.to(device)
            qc_idxs = qc_idxs.to(device)
            batch_size = cw_idxs.size(0)

            # Forward
            log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
            y1, y2 = y1.to(device), y2.to(device)
            loss = torch.mean(loss_f(log_p1, y1) + loss_f(log_p2, y2))
            loss_val = loss.item()
            # Feed the running average; without this the NLL shown and logged
            # below never reflects the data.
            # NOTE(review): assumes AverageMeter.update(value, num_samples).
            nll_meter.update(loss_val, batch_size)

            # Get F1 and EM scores
            p1, p2 = log_p1.exp(), log_p2.exp()
            starts, ends = util.discretize(p1, p2, args.max_ans_len,
                                           args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if args.split != 'test':
                # No labels for the test set, so NLL would be invalid
                progress_bar.set_postfix(NLL=nll_meter.avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict,
                                                      ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    # Log results (except for test set, since it does not come with labels)
    if args.split != 'test':
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        results_list = [('NLL', nll_meter.avg),
                        ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                for k, v in results.items())
        log.info('{} {}'.format(args.split.title(), results_str))

        # Log to TensorBoard
        tbx = SummaryWriter(args.save_dir)
        util.visualize(tbx,
                       pred_dict=pred_dict,
                       eval_path=eval_file,
                       step=0,
                       split=args.split,
                       num_visuals=args.num_visuals)
        # Close the writer so pending events reach disk.
        tbx.close()

    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info('Writing submission file to {}...'.format(sub_path))
    # newline='' lets the csv module control line endings itself.
    with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])
def main(args):
    """Train the model selected by ``args.name`` on SQuAD with periodic dev evaluation.

    Gradients are accumulated over ``args.acc_step`` batches before each
    optimizer step. After at least ``args.eval_steps`` examples have been
    seen, and on an accumulation boundary, the EMA weights are swapped in,
    the dev set is evaluated, a checkpoint is saved and the results are
    logged to the console and TensorBoard.

    Args:
        args: Parsed command-line namespace. ``args.save_dir``,
            ``args.gpu_ids`` and ``args.batch_size`` are modified in place.

    Raises:
        ValueError: If ``args.name`` is not a known model, or ``args.opt`` is
            not a known optimizer when ``args.name`` is not ``'qanet'``.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info(f'Using random seed {args.seed}...')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vec = util.torch_from_json(args.char_emb_file)

    # Get model
    log.info('Building model...')
    model = _build_model(args, word_vectors, char_vec)
    model = nn.DataParallel(model, args.gpu_ids)
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    optimizer, scheduler = _build_optimizer(args, model)

    # Get data loader
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    i = 0  # number of batches processed in this run
    # ``<`` rather than ``!=`` so that resuming from a checkpoint already past
    # num_epochs terminates instead of looping forever.
    while epoch < args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:
                # Setup for forward
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                cc_idxs = cc_idxs.to(device)
                qc_idxs = qc_idxs.to(device)
                batch_size = cw_idxs.size(0)

                # Forward
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()
                i += 1
                # Scale so the accumulated gradient averages over acc_step
                # batches.
                loss /= args.acc_step

                # Backward
                loss.backward()
                at_update_boundary = i % args.acc_step == 0
                if at_update_boundary:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                    optimizer.step()
                    scheduler.step(i // (args.acc_step))
                    ema(model, i // (args.acc_step))
                    optimizer.zero_grad()

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0 and at_update_boundary:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint
                    log.info(f'Evaluating at step {step}...')
                    ema.assign(model)
                    results, pred_dict = evaluate(model, dev_loader, device,
                                                  args.dev_eval_file,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join(f'{k}: {v:05.2f}'
                                            for k, v in results.items())
                    log.info(f'Dev {results_str}')

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar(f'dev/{k}', v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)

    # Close the writer so pending events reach disk.
    tbx.close()


def _build_model(args, word_vectors, char_vec):
    """Instantiate the (unwrapped) network selected by ``args.name``."""
    name = args.name
    if name == 'baseline':
        return BiDAF(word_vectors=word_vectors,
                     hidden_size=args.hidden_size,
                     drop_prob=args.drop_prob)
    if name == 'charembeddings':
        return BiDAFChar(word_vectors=word_vectors,
                         char_vec=char_vec,
                         word_len=16,
                         hidden_size=args.hidden_size,
                         drop_prob=args.drop_prob)
    if name == 'charembeddings2':
        return BiDAFChar2(word_vectors=word_vectors,
                          char_vec=char_vec,
                          word_len=16,
                          hidden_size=args.hidden_size,
                          drop_prob=args.drop_prob)
    if name in ('qanet', 'qanet2', 'qanet3', 'qanet4'):
        # Keyword arguments common to every QANet variant.
        kwargs = dict(word_vectors=word_vectors,
                      char_vec=char_vec,
                      word_len=16,
                      emb_size=args.hidden_size,
                      drop_prob=args.drop_prob,
                      enc_size=args.enc_size,
                      n_head=args.n_head,
                      LN_train=args.ln_train,
                      DP_residual=args.dp_res,
                      mask_pos=args.mask_pos,
                      two_pos=args.two_pos,
                      total_prob=args.total_drop,
                      final_prob=args.final_prob)
        if name == 'qanet':
            return QANet(**kwargs)
        # The numbered variants take two extra options.
        kwargs.update(rel=args.rel_att, freeze=args.freeze_emb)
        if name == 'qanet2':
            return QANet2(**kwargs)
        if name == 'qanet3':
            return QANet3(**kwargs)
        return QANet4(**kwargs)
    raise ValueError('Wrong model name')


def _build_optimizer(args, model):
    """Return the ``(optimizer, scheduler)`` pair chosen by ``args.name`` / ``args.opt``."""
    if args.name == 'qanet':
        optimizer = optim.Adam(model.parameters(), args.lr,
                               betas=(0.8, 0.999),
                               weight_decay=3 * 1e-7,
                               eps=1e-7)
        return optimizer, warmup(optimizer, 1, 2000)
    if args.opt == 'adam':
        if args.grad_cent:
            optimizer = AdamWGC(model.parameters(), args.lr,
                                betas=(0.9, 0.999),
                                weight_decay=3 * 1e-7,
                                eps=1e-7,
                                use_gc=True)
        else:
            optimizer = AdamW(model.parameters(), args.lr,
                              betas=(0.8, 0.999),
                              weight_decay=3 * 1e-7,
                              eps=1e-7)
        return optimizer, warmup(optimizer, 1, 2000)
    if args.opt == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(), args.lr,
                                   weight_decay=3 * 1e-7)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(), args.lr,
                              weight_decay=3 * 1e-7)
    else:
        # Fail here rather than with a NameError deep inside the train loop.
        raise ValueError(f'Unsupported optimizer: {args.opt!r}')
    return optimizer, sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR
def train_entry(config):
    """Resume QANet from a saved state dict and train for ``config.num_epoch`` epochs.

    After each epoch the EMA weights are assigned to the model, the dev set
    is evaluated, the model and its state dict are saved, and the EMA swap is
    undone. The loop stops early once dev F1 and EM have both regressed for
    more than ``config.early_stop`` consecutive epochs.

    Args:
        config: Settings object. The attributes read here are
            ``word_emb_file``, ``char_emb_file``, ``dev_eval_file``,
            ``train_record_file``, ``dev_record_file``, ``batch_size``,
            ``learning_rate``, ``lr_warm_up_num``, ``decay``, ``num_epoch``,
            ``early_stop`` and ``save_dir``.
    """
    from models import QANet

    def _read_json(path, mode):
        # Parse one JSON document from ``path`` opened with ``mode``.
        with open(path, mode) as fh:
            return json.load(fh)

    word_mat = np.array(_read_json(config.word_emb_file, "rb"),
                        dtype=np.float32)
    char_mat = np.array(_read_json(config.char_emb_file, "rb"),
                        dtype=np.float32)
    dev_eval_file = _read_json(config.dev_eval_file, "r")

    print("Building model...")
    train_dataset = get_loader(config.train_record_file, config.batch_size)
    dev_dataset = get_loader(config.dev_record_file, config.batch_size)

    lr = config.learning_rate
    base_lr = 1
    lr_warm_up_num = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)
    if torch.cuda.device_count() > 1:
        print('i can use gpu')
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
    # NOTE(review): the checkpoint path is hard-coded and machine-specific;
    # consider reading it from config.
    saved_state = torch.load('/home/cn/AI/QANet-pytorch-/model_state_dict.pt')
    model.load_state_dict(saved_state)

    # Register every trainable parameter with the moving average.
    ema = EMA(config.decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda param: param.requires_grad, model.parameters())
    optimizer = optim.Adam(lr=base_lr, betas=(0.9, 0.999), eps=1e-7,
                           weight_decay=5e-8, params=parameters)
    cr = lr / math.log2(lr_warm_up_num)

    def _lr_factor(ee):
        # Logarithmic warm-up, then the constant target rate.
        if ee < lr_warm_up_num:
            return cr * math.log2(ee + 1)
        return lr

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_factor)

    best_f1 = 0
    best_em = 0
    patience = 0
    model_path = os.path.join(config.save_dir, "model.pt")

    for epoch in range(config.num_epoch):
        train(model, optimizer, scheduler, train_dataset, dev_dataset,
              dev_eval_file, epoch, ema)
        print(epoch)

        # Evaluate and save with the averaged weights in place.
        ema.assign(model)
        metrics = test(model, dev_dataset, dev_eval_file,
                       (epoch + 1) * len(train_dataset))
        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]

        both_regressed = dev_f1 < best_f1 and dev_em < best_em
        if not both_regressed:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
        else:
            patience += 1
            if patience > config.early_stop:
                break

        torch.save(model, model_path)
        torch.save(model.state_dict(), 'model_state_dict.pt')
        ema.resume(model)