        # compute the accuracy
        acc_t = compute_accuracy(y_pred, batch_dict['y_target'],
                                 output_type=args.output_type)
        running_acc += (acc_t - running_acc) / (batch_index + 1)

        val_bar.set_postfix(loss=running_loss, acc=running_acc,
                            epoch=epoch_index)
        val_bar.update()

    train_state['val_loss'].append(running_loss)
    train_state['val_acc'].append(running_acc)

    train_state = update_train_state(args=args, model=classifier,
                                     train_state=train_state)

    scheduler.step(train_state['val_loss'][-1])

    # reset the per-split bars and advance the epoch bar (once per epoch)
    train_bar.n = 0
    val_bar.n = 0
    epoch_bar.update()

    if train_state['stop_early']:
        break
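# The snippet above (and the training routines below) rely on an
# `update_train_state` helper defined elsewhere in the project, and its call
# signature varies across the scripts in this file. The following is only a
# minimal sketch of the contract it is expected to fulfil -- checkpoint on
# improvement, flag `stop_early` after too many non-improving epochs --
# assuming `args` carries `early_stopping_criteria` and `model_state_file`
# (both names are assumptions, not taken from the original code):
def update_train_state(args, model, train_state):
    """Save the model when the val loss improves; otherwise count toward early stopping."""
    loss_t = train_state['val_loss'][-1]
    if loss_t < train_state.get('early_stopping_best_val', float('inf')):
        # improvement: checkpoint the model and reset the patience counter
        torch.save(model.state_dict(), args.model_state_file)
        train_state['early_stopping_best_val'] = loss_t
        train_state['early_stopping_step'] = 0
    else:
        train_state['early_stopping_step'] += 1
    train_state['stop_early'] = \
        train_state['early_stopping_step'] >= args.early_stopping_criteria
    return train_state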
def main(args):
    set_envs(args)
    print("Using: {}".format(args.device))

    # count the regular files in the item folder to get the dataset size
    total_items = os.popen(f'ls -l {args.hotpotQA_item_folder} |grep "^-"| wc -l').readlines()[0].strip()
    total_items = int(total_items)
    print(f"total items: {total_items}")

    # instantiate the model
    classifier = GAT_HotpotQA(features=args.features, hidden=args.hidden,
                              nclass=args.nclass, dropout=args.dropout,
                              alpha=args.alpha, nheads=args.nheads,
                              nodes_num=args.pad_max_num)

    class_weights_sent, class_weights_para, class_weights_Qtype = \
        HotpotQA_GNN_Dataset.get_weights(device=args.device)
    loss_func_sent = nn.CrossEntropyLoss(class_weights_sent, ignore_index=-100)
    loss_func_para = nn.CrossEntropyLoss(class_weights_para, ignore_index=-100)
    loss_func_Qtype = nn.CrossEntropyLoss(class_weights_Qtype)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()),
                           lr=args.learning_rate)

    # apex AMP initialization ('O1' patches torch functions for mixed precision)
    opt_level = 'O1'
    if args.cuda:
        classifier = classifier.to(args.device)
        if args.fp16:
            classifier, optimizer = amp.initialize(classifier, optimizer,
                                                   opt_level=opt_level)
        classifier = nn.parallel.DistributedDataParallel(classifier,
                                                         device_ids=args.device_ids,
                                                         output_device=0,
                                                         find_unused_parameters=True)

    if args.reload_from_files:
        checkpoint = torch.load(args.model_state_file)
        classifier.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp'])

    train_state = make_train_state(args)

    try:
        writer = SummaryWriter(log_dir=args.log_dir, flush_secs=args.flush_secs)
        cursor_train = 0
        cursor_val = 0
        if args.reload_from_files and 'cursor_train' in checkpoint.keys():
            cursor_train = checkpoint['cursor_train'] + 1
            cursor_val = checkpoint['cursor_val'] + 1

        # a negative chunk_size means "load the whole dataset at once"
        if args.chunk_size < 0:
            args.chunk_size = total_items

        for chunk_i in range(0, total_items, args.chunk_size):
            dataset = HotpotQA_GNN_Dataset.build_dataset(
                hotpotQA_item_folder=args.hotpotQA_item_folder,
                i_from=chunk_i,
                i_to=chunk_i + args.chunk_size,
                seed=args.seed + chunk_i)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer, mode='min', factor=0.7,
                patience=dataset.get_num_batches(args.batch_size) / 10)
            dataset.set_parameters(args.pad_max_num, args.pad_value)

            epoch_bar = tqdm(desc='training routine', total=args.num_epochs,
                             position=0)
            dataset.set_split('train')
            train_bar = tqdm(desc='split=train',
                             total=dataset.get_num_batches(args.batch_size),
                             position=1)
            dataset.set_split('val')
            val_bar = tqdm(desc='split=val',
                           total=dataset.get_num_batches(args.batch_size),
                           position=1)

            for epoch_index in range(args.num_epochs):
                train_bar.n = 0
                val_bar.n = 0
                train_state['epoch_index'] = epoch_index

                # Iterate over the training dataset
                dataset.set_split('train')
                batch_generator = gen_GNN_batches(dataset,
                                                  batch_size=args.batch_size,
                                                  device=args.device)
                running_loss = 0.0
                running_acc_Qtype = 0.0
                running_acc_topN = 0.0
                classifier.train()

                for batch_index, batch_dict in enumerate(batch_generator):
                    optimizer.zero_grad()
                    logits_sent, logits_para, logits_Qtype = \
                        classifier(batch_dict['feature_matrix'], batch_dict['adj'])

                    # topN sents: max_index holds the predicted class
                    max_value, max_index = logits_sent.max(dim=-1)
                    topN_sent_index_batch = (max_value * batch_dict['sent_mask'].squeeze()) \
                        .topk(args.topN_sents, dim=-1)[1]
                    topN_sent_predict = torch.gather(max_index, -1, topN_sent_index_batch)
                    topN_sent_label = torch.gather(
                        (batch_dict['labels'] * batch_dict['sent_mask']).squeeze(),
                        -1, topN_sent_index_batch)

                    # mask out padding; padded positions get label -100 so that
                    # CrossEntropyLoss(ignore_index=-100) skips them
                    logits_sent = (batch_dict['sent_mask'] * logits_sent).view(-1, 2)
                    labels_sent = (batch_dict['sent_mask'] * batch_dict['labels'] +
                                   batch_dict['sent_mask'].eq(0) * -100).view(-1)
                    logits_para = (batch_dict['para_mask'] * logits_para).view(-1, 2)
                    labels_para = (batch_dict['para_mask'] * batch_dict['labels'] +
                                   batch_dict['para_mask'].eq(0) * -100).view(-1)

                    loss_sent = loss_func_sent(logits_sent, labels_sent)    # [B,2] [B]
                    loss_para = loss_func_para(logits_para, labels_para)    # [B,2] [B]
                    loss_Qtype = loss_func_Qtype(logits_Qtype.view(-1, 2),
                                                 batch_dict['answer_type'].view(-1))  # [B,2] [B]
                    loss = loss_sent + loss_para + loss_Qtype
                    running_loss += (loss.item() - running_loss) / (batch_index + 1)

                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    scheduler.step(running_loss)

                    # compute the recall
                    recall_t_sent = compute_recall(logits_sent.view(-1, 2),
                                                   batch_dict['labels'].view(-1),
                                                   batch_dict['sent_mask'].view(-1))
                    recall_t_para = compute_recall(logits_para.view(-1, 2),
                                                   batch_dict['labels'].view(-1),
                                                   batch_dict['para_mask'].view(-1))

                    # compute the accuracy
                    acc_t_Qtype = compute_accuracy(logits_Qtype.view(-1, 2),
                                                   batch_dict['answer_type'].view(-1))
                    running_acc_Qtype += (acc_t_Qtype - running_acc_Qtype) / (batch_index + 1)
                    acc_t_topN = compute_accuracy(predict=topN_sent_predict.view(-1),
                                                  labels=topN_sent_label.view(-1))
                    running_acc_topN += (acc_t_topN - running_acc_topN) / (batch_index + 1)

                    # update bar
                    train_bar.set_postfix(loss=running_loss, epoch=epoch_index)
                    train_bar.update()

                    writer.add_scalar('loss/train', loss.item(), cursor_train)
                    writer.add_scalar('recall_t_sent/train', recall_t_sent, cursor_train)
                    writer.add_scalar('recall_t_para/train', recall_t_para, cursor_train)
                    writer.add_scalar('running_acc_Qtype/train', running_acc_Qtype, cursor_train)
                    writer.add_scalar('running_acc_topN/train', running_acc_topN, cursor_train)
                    writer.add_scalar('running_loss/train', running_loss, cursor_train)
                    cursor_train += 1

                train_state['train_running_loss'].append(running_loss)

                # Iterate over the val dataset
                # setup: batch generator, reset loss and acc, set eval mode on
                dataset.set_split('val')
                batch_generator = gen_GNN_batches(dataset,
                                                  batch_size=args.batch_size,
                                                  device=args.device)
                running_loss = 0.0
                running_acc_Qtype = 0.0
                running_acc_topN = 0.0
                classifier.eval()

                for batch_index, batch_dict in enumerate(batch_generator):
                    # compute the output
                    with torch.no_grad():
                        logits_sent, logits_para, logits_Qtype = \
                            classifier(batch_dict['feature_matrix'], batch_dict['adj'])

                    max_value, max_index = logits_sent.max(dim=-1)  # max_index is the predicted class
                    topN_sent_index_batch = (max_value * batch_dict['sent_mask'].squeeze()) \
                        .topk(args.topN_sents, dim=-1)[1]
                    topN_sent_predict = torch.gather(max_index, -1, topN_sent_index_batch)
                    topN_sent_label = torch.gather(
                        (batch_dict['labels'] * batch_dict['sent_mask']).squeeze(),
                        -1, topN_sent_index_batch)

                    logits_sent = (batch_dict['sent_mask'] * logits_sent).view(-1, 2)
                    labels_sent = (batch_dict['sent_mask'] * batch_dict['labels'] +
                                   batch_dict['sent_mask'].eq(0) * -100).view(-1)
                    logits_para = (batch_dict['para_mask'] * logits_para).view(-1, 2)
                    labels_para = (batch_dict['para_mask'] * batch_dict['labels'] +
                                   batch_dict['para_mask'].eq(0) * -100).view(-1)

                    loss_sent = loss_func_sent(logits_sent, labels_sent)    # [B,2] [B]
                    loss_para = loss_func_para(logits_para, labels_para)    # [B,2] [B]
                    loss_Qtype = loss_func_Qtype(logits_Qtype.view(-1, 2),
                                                 batch_dict['answer_type'].view(-1))  # [B,2] [B]
                    loss = loss_sent + loss_para + loss_Qtype
                    try:
                        running_loss += (loss.item() - running_loss) / (batch_index + 1)
                    except RuntimeError:
                        print(f"err batch: {batch_index}")
                        print_exc()
                        exit()

                    # compute the recall
                    recall_t_sent = compute_recall(logits_sent.view(-1, 2),
                                                   batch_dict['labels'].view(-1),
                                                   batch_dict['sent_mask'].view(-1))
                    recall_t_para = compute_recall(logits_para.view(-1, 2),
                                                   batch_dict['labels'].view(-1),
                                                   batch_dict['para_mask'].view(-1))

                    # compute the accuracy
                    acc_t_Qtype = compute_accuracy(logits_Qtype.view(-1, 2),
                                                   batch_dict['answer_type'].view(-1))
                    running_acc_Qtype += (acc_t_Qtype - running_acc_Qtype) / (batch_index + 1)
                    acc_t_topN = compute_accuracy(predict=topN_sent_predict.view(-1),
                                                  labels=topN_sent_label.view(-1))
                    running_acc_topN += (acc_t_topN - running_acc_topN) / (batch_index + 1)

                    # update bar
                    val_bar.set_postfix(loss=running_loss, epoch=epoch_index)
                    val_bar.update()

                    writer.add_scalar('loss/val', loss.item(), cursor_val)
                    writer.add_scalar('recall_t_sent/val', recall_t_sent, cursor_val)
                    writer.add_scalar('recall_t_para/val', recall_t_para, cursor_val)
                    writer.add_scalar('running_acc_Qtype/val', running_acc_Qtype, cursor_val)
                    writer.add_scalar('running_acc_topN/val', running_acc_topN, cursor_val)
                    writer.add_scalar('running_loss/val', running_loss, cursor_val)
                    cursor_val += 1

                train_state['val_running_loss'].append(running_loss)
                train_state = update_train_state(args=args, model=classifier,
                                                 optimizer=optimizer,
                                                 train_state=train_state)
                epoch_bar.update()
                # if train_state['stop_early']:
                #     print('STOP EARLY!')
                #     exit()
            # epoch done.
        # chunk done.
        # all finished.
    except KeyboardInterrupt:
        print("Exiting loop")
    except Exception:
        print(f"err in chunk {chunk_i}, epoch_index {epoch_index}, "
              f"batch_index {batch_index}.")
        print_exc()
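# `compute_accuracy` and `compute_recall` above come from a project utility
# module that is not shown here. The following is a minimal sketch of masked
# versions that match how they are called in this loop; the exact signatures
# and the percentage scaling are assumptions, not the project's implementation:
def compute_accuracy(predict, labels):
    """Percentage of positions where the prediction matches the label."""
    if predict.dim() > 1:
        predict = predict.argmax(dim=-1)  # logits [N, C] -> class indices [N]
    n_correct = predict.eq(labels).sum().item()
    return n_correct / max(labels.numel(), 1) * 100


def compute_recall(logits, labels, mask):
    """Recall of the positive class, counted only at unmasked positions."""
    predict = logits.argmax(dim=-1)
    keep = mask.bool()
    true_pos = ((predict == 1) & (labels == 1) & keep).sum().item()
    actual_pos = ((labels == 1) & keep).sum().item()
    return true_pos / max(actual_pos, 1)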
def train(model, train_dataset, val_dataset, args: Namespace):
    """
    :param model: the PyTorch model to train
    :param train_dataset: dataset for the training split
    :param val_dataset: dataset for the validation split
    :param args: hyperparameters and paths
    :return: the trained model
    """
    checkpoint_path = args.model_dir + '/best_model.pth'
    if args.checkpoint and os.path.exists(checkpoint_path):
        print("load model from " + checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path))
    if not os.path.exists(args.model_dir):
        os.mkdir(args.model_dir)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("device: ", device)

    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                              shuffle=True, drop_last=True)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
    #                                                  factor=0.5, patience=1)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01,
                                              steps_per_epoch=len(train_loader),
                                              epochs=args.num_epochs)
    train_state = make_train_state(args)

    for epoch in range(args.num_epochs):
        start_time = time.time()
        print(32 * '_' + ' ' * 9 + 31 * '_')
        print(32 * '_' + 'EPOCH {:3}'.format(epoch + 1) + 31 * '_')
        train_state['epoch_index'] = epoch

        train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                  shuffle=True, drop_last=True)
        train_loss = 0.
        model.train()
        for i, (x_vector, y_vector, x_mask) in enumerate(train_loader):
            optimizer.zero_grad()
            out = model(x_vector.to(device), x_mask.to(device))
            loss = model.compute_loss(out, y_vector.to(device), x_mask.to(device))
            loss.backward()
            # gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            # OneCycleLR is stepped once per optimizer step, not once per epoch
            scheduler.step()
            train_loss += (loss.item() - train_loss) / (i + 1)
            if (i + 1) % 10 == 0:
                print("|{:4}| train_loss = {:.4f}".format(i + 1, loss.item()))
        # writer.add_scalar('Train/Loss', train_loss, epoch)
        train_state['train_loss'].append(train_loss)
        optimizer.zero_grad()

        val_loss, f1, acc = model.evaluate(val_dataset, thresh=args.thresh,
                                           batch_size=args.batch_size)
        # scheduler.step(val_loss)  # only needed for ReduceLROnPlateau
        print('Val loss: {:.4f}'.format(val_loss))

        # track -f1 as the "loss" so that early stopping maximizes F1
        train_state['val_loss'].append(-f1)
        # update train state
        train_state = update_train_state(model, train_state)
        if train_state['stop_early']:
            print('Stop early.......!')
            break
    return model
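# `make_train_state` is shared by both `train` variants in this file but is
# defined elsewhere. A minimal sketch of the dictionary it could return for
# these loops; the exact keys are an assumption (the GNN routine above also
# uses 'train_running_loss' / 'val_running_loss' and 'val_acc' variants):
def make_train_state(args):
    return {
        'stop_early': False,
        'early_stopping_step': 0,
        'early_stopping_best_val': float('inf'),
        'epoch_index': 0,
        'train_loss': [],
        'val_loss': [],
    }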
def train(model, train_dataset, val_dataset, args: Namespace):
    """
    :param model: the PyTorch model to train
    :param train_dataset: dataset for the training split
    :param val_dataset: dataset for the validation split
    :param args: hyperparameters and paths
    :return:
    """
    checkpoint_path = args.model_dir + '/last_checkpoint.pth'
    if args.checkpoint and os.path.exists(checkpoint_path):
        print("load model from " + checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path))
    if not os.path.exists(args.model_dir):
        os.mkdir(args.model_dir)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("device: ", device)

    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                              shuffle=True, drop_last=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size,
                            shuffle=False, drop_last=True)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.5, patience=1)
    # scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01,
    #                                           steps_per_epoch=len(train_loader),
    #                                           epochs=args.num_epochs)
    train_state = make_train_state(args)

    for epoch in range(args.num_epochs):
        start_time = time.time()
        print("Epoch: ", epoch + 1)
        train_state['epoch_index'] = epoch

        train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                  shuffle=True, drop_last=True)
        train_loss = 0.
        model.train()
        for i, (x_vector, y_vector, x_mask) in enumerate(train_loader):
            optimizer.zero_grad()
            loss, out = model.forward_loss(x_vector.to(device), y_vector.to(device),
                                           x_mask.to(device))
            loss.backward()
            # gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            train_loss += (loss.item() - train_loss) / (i + 1)
            if (i + 1) % 10 == 0:
                print("\tStep: {} train_loss: {}".format(i + 1, loss.item()))
        # writer.add_scalar('Train/Loss', train_loss, epoch)
        train_state['train_loss'].append(train_loss)

        model.eval()
        val_loss = 0.
        y_pred = []
        y_true = []
        with torch.no_grad():
            for i, (x_vector, y_vector, x_mask) in enumerate(val_loader):
                loss, out = model.forward_loss(x_vector.to(device), y_vector.to(device),
                                               x_mask.to(device))
                val_loss += (loss.item() - val_loss) / (i + 1)
                y_pred.append((out > args.thresh).long())
                y_true.append(y_vector.long())
        # step the plateau scheduler once per epoch, on the full validation loss
        scheduler.step(val_loss)

        train_state['val_loss'].append(val_loss)
        train_state = update_train_state(model, train_state)

        y_true = torch.cat(y_true, dim=-1).cpu().detach().numpy()
        y_pred = torch.cat(y_pred, dim=-1).cpu().detach().numpy()
        acc, sub_acc, f1, precision, recall, hamming_loss = get_multi_label_metrics(
            y_true=y_true, y_pred=y_pred)
        print('f1: {} precision: {} recall: {}'.format(f1, precision, recall))
        print('accuracy: {} sub accuracy : {} hamming loss: {}'.format(
            acc, sub_acc, hamming_loss))
        # save a checkpoint every epoch (not only the best model)
        torch.save(model.state_dict(), args.model_dir + '/' + args.model_name)
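# `get_multi_label_metrics` is imported from a project utility module. Below is
# a minimal sketch with scikit-learn of the six values it returns; micro
# averaging is an assumption, the project may use a different setting:
from sklearn.metrics import (accuracy_score, f1_score, hamming_loss,
                             precision_score, recall_score)

def get_multi_label_metrics(y_true, y_pred):
    """Multi-label metrics on {0, 1} arrays of shape [n_samples, n_labels]."""
    acc = accuracy_score(y_true.reshape(-1), y_pred.reshape(-1))  # per-label accuracy
    sub_acc = accuracy_score(y_true, y_pred)  # subset accuracy: every label correct
    f1 = f1_score(y_true, y_pred, average='micro')
    precision = precision_score(y_true, y_pred, average='micro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='micro', zero_division=0)
    h_loss = hamming_loss(y_true, y_pred)
    return acc, sub_acc, f1, precision, recall, h_loss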