def train_model(train_stage, save_model_path, model, optimizer, epochs, train_dataset, eval_dataset, batch_size=1, gradient_accumulation_steps=1, num_workers=1):
    """Drive the full training loop.

    For each epoch: run one training pass via ``train_epoch``, evaluate on
    ``eval_dataset`` via ``eval_model``, then checkpoint via ``save_model``.

    Args:
        train_stage: stage selector forwarded to train/eval routines.
        save_model_path: directory/prefix passed to ``save_model``.
        model: model being trained.
        optimizer: optimizer driving the parameter updates.
        epochs: number of epochs to run (1-based loop).
        train_dataset: dataset used to build the shuffled training loader.
        eval_dataset: dataset evaluated after every epoch.
        batch_size: batch size for both training and evaluation.
        gradient_accumulation_steps: forwarded to ``train_epoch``.
        num_workers: DataLoader worker count.
    """
    print('Start Training'.center(60, '*'))
    loader = data.DataLoader(dataset=train_dataset,
                             collate_fn=TextCollate(train_dataset),
                             batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=True)
    epoch = 1
    while epoch <= epochs:
        print('train epoch: ' + str(epoch))
        avg_loss = train_epoch(train_stage, model, optimizer, loader,
                               gradient_accumulation_steps, epoch)
        print('average_loss:{}'.format(avg_loss))
        eval_model(train_stage, model, eval_dataset,
                   batch_size=batch_size, num_workers=num_workers)
        save_model(save_model_path, model, epoch)
        epoch += 1
def eval_model(train_stage, model, dataset, batch_size=1, num_workers=1):
    """Evaluate ``model`` on ``dataset`` and print the average loss.

    When ``train_stage == 0`` the model output is treated as classification
    logits and top-1 accuracy is also printed.

    Args:
        train_stage: stage selector forwarded to the model call; stage 0
            additionally computes accuracy from ``logits.topk(1)``.
        model: model under evaluation (called with ``inference=False``).
        dataset: evaluation dataset, wrapped in a non-shuffled DataLoader.
        batch_size: evaluation batch size.
        num_workers: DataLoader worker count.
    """
    global debug_break
    model.eval()
    dataloader = data.DataLoader(dataset=dataset,
                                 collate_fn=TextCollate(dataset),
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False)
    total_loss = 0.0
    correct_sum = 0
    proc_sum = 0
    num_batch = len(dataloader)
    print('Evaluating Model...')
    for step, batch in enumerate(dataloader):
        tokens = batch['tokens'].to(device)
        segment_ids = batch['segment_ids'].to(device)
        attn_masks = batch['attn_masks'].to(device)
        labels = batch['labels'].to(device)
        with torch.no_grad():
            loss, logits = model(tokens,
                                 token_type_ids=segment_ids,
                                 attention_mask=attn_masks,
                                 labels=labels,
                                 training_stage=train_stage,
                                 inference=False)
        # .mean() collapses per-GPU losses when the model is data-parallel.
        total_loss += loss.mean().item()
        if train_stage == 0:
            _, top_index = logits.topk(1)
            correct_sum += (top_index.view(-1) == labels).sum().item()
            proc_sum += labels.shape[0]
        # Fix: break AFTER the accuracy accumulation so the truncated debug
        # run counts loss and accuracy over the same set of batches
        # (previously the last batch's loss was counted but its accuracy
        # was not).
        if debug_break and step > 50:
            break
    # max(..., 1) guards against ZeroDivisionError on an empty dataloader.
    print('eval total avg loss: {}'.format(total_loss / max(num_batch, 1)))
    if train_stage == 0:
        print("Correct Prediction: " + str(correct_sum))
        print("Accuracy Rate: " + format(correct_sum / max(proc_sum, 1), "0.4f"))
def train_model(train_stage, save_model_path, master_gpu_id, model, optimizer, epochs, train_dataset, eval_dataset, batch_size=1, gradient_accumulation_steps=1, use_cuda=False, num_workers=1):
    """Run the full training loop with optional CUDA placement.

    Each epoch performs one pass of ``train_epoch``, evaluates on
    ``eval_dataset`` with ``eval_model``, and checkpoints with ``save_model``.

    Args:
        train_stage: stage selector forwarded to train/eval routines.
        save_model_path: destination passed to ``save_model``.
        master_gpu_id: primary GPU id forwarded to train/eval routines.
        model: model being trained.
        optimizer: optimizer driving the parameter updates.
        epochs: number of epochs to run (1-based loop).
        train_dataset: dataset used to build the shuffled training loader.
        eval_dataset: dataset evaluated after every epoch.
        batch_size: batch size for both training and evaluation.
        gradient_accumulation_steps: forwarded to ``train_epoch``.
        use_cuda: when True, pin DataLoader memory and run on GPU.
        num_workers: DataLoader worker count.
    """
    logging.info("Start Training".center(60, "="))
    loader = data.DataLoader(dataset=train_dataset,
                             collate_fn=TextCollate(train_dataset),
                             pin_memory=use_cuda,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=True)
    epoch = 1
    while epoch <= epochs:
        logging.info("Training Epoch: " + str(epoch))
        avg_loss = train_epoch(train_stage, master_gpu_id, model, optimizer,
                               loader, gradient_accumulation_steps, use_cuda)
        logging.info("Average Loss: " + format(avg_loss, "0.4f"))
        eval_model(train_stage, master_gpu_id, model, eval_dataset,
                   batch_size=batch_size, use_cuda=use_cuda,
                   num_workers=num_workers)
        save_model(save_model_path, model, epoch)
        epoch += 1
def eval_model(train_stage, master_gpu_id, model, dataset, batch_size=1, use_cuda=False, num_workers=1):
    """Evaluate ``model`` on ``dataset`` and log the average loss.

    When ``train_stage == 0`` the model output is treated as classification
    logits and top-1 accuracy is also logged.

    Args:
        train_stage: stage selector forwarded to the model call; stage 0
            additionally computes accuracy from ``logits.topk(1)``.
        master_gpu_id: GPU id used for ``.cuda(...)`` placement when
            ``use_cuda`` is True.
        model: model under evaluation (called with ``inference=False``).
        dataset: evaluation dataset, wrapped in a non-shuffled DataLoader.
        batch_size: evaluation batch size.
        use_cuda: when True, pin DataLoader memory and move batches to GPU.
        num_workers: DataLoader worker count.
    """
    global debug_break
    model.eval()
    dataloader = data.DataLoader(dataset=dataset,
                                 collate_fn=TextCollate(dataset),
                                 pin_memory=use_cuda,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False)
    total_loss = 0.0
    correct_sum = 0
    proc_sum = 0
    num_batch = len(dataloader)
    logging.info("Evaluating Model...")
    for step, batch in enumerate(tqdm(dataloader, unit="batch", ncols=100,
                                      desc="Evaluating process: ")):
        tokens = batch["tokens"].cuda(master_gpu_id) if use_cuda else batch["tokens"]
        segment_ids = batch["segment_ids"].cuda(master_gpu_id) if use_cuda else batch["segment_ids"]
        attn_masks = batch["attn_masks"].cuda(master_gpu_id) if use_cuda else batch["attn_masks"]
        labels = batch["labels"].cuda(master_gpu_id) if use_cuda else batch["labels"]
        with torch.no_grad():
            loss, logits = model(tokens,
                                 token_type_ids=segment_ids,
                                 attention_mask=attn_masks,
                                 labels=labels,
                                 training_stage=train_stage,
                                 inference=False)
        # .mean() collapses per-GPU losses when the model is data-parallel.
        total_loss += loss.mean().item()
        if train_stage == 0:
            _, top_index = logits.topk(1)
            correct_sum += (top_index.view(-1) == labels).sum().item()
            proc_sum += labels.shape[0]
        # Fix: break AFTER the accuracy accumulation so the truncated debug
        # run counts loss and accuracy over the same set of batches
        # (previously the last batch's loss was counted but its accuracy
        # was not). Dead locals (num_sample, predicted_probs, true_labels,
        # infos, texts) and commented-out writer code were removed.
        if debug_break and step > 50:
            break
    # max(..., 1) guards against ZeroDivisionError on an empty dataloader.
    logging.info('eval total avg loss:%s', format(total_loss / max(num_batch, 1), "0.4f"))
    if train_stage == 0:
        logging.info("Correct Prediction: " + str(correct_sum))
        logging.info("Accuracy Rate: " + format(correct_sum / max(proc_sum, 1), "0.4f"))
def infer_model(master_gpu_id, model, dataset, use_cuda=False, num_workers=1, inference_speed=None, dump_info_file=None):
    """Run single-record inference, logging accuracy and per-record latency.

    Iterates the dataset with batch_size=1, calls the model with
    ``inference=True``, and logs top-1 accuracy plus timing. Optionally dumps
    per-record (label, prediction, exit-layer, text) rows to a TSV file and,
    for binary classifiers (2 output columns), logs precision/recall via
    ``eval_pr``.

    Args:
        master_gpu_id: GPU id for ``.cuda(...)`` placement when ``use_cuda``.
        model: model under test; returns (probs, layer_idxes, uncertain_infos).
        dataset: dataset to run inference over.
        use_cuda: when True, pin DataLoader memory and move batches to GPU.
        num_workers: DataLoader worker count.
        inference_speed: early-exit speed knob forwarded to the model.
        dump_info_file: optional path for the per-record TSV dump.
    """
    global debug_break
    model.eval()
    infer_dataloader = data.DataLoader(dataset=dataset,
                                       collate_fn=TextCollate(dataset),
                                       pin_memory=use_cuda,
                                       batch_size=1,
                                       num_workers=num_workers,
                                       shuffle=False)
    correct_sum = 0
    num_sample = len(infer_dataloader.dataset)
    infos = []
    logging.info("Inference Model...")
    cnt = 0
    probs = None  # guards the post-loop shape check when the dataset is empty
    stime_all = time.time()
    # Fix: initialize stime up front so an empty dataset no longer raises
    # NameError at the timing computation below.
    stime = stime_all
    for step, batch in enumerate(
            tqdm(infer_dataloader, unit="batch", ncols=100,
                 desc="Inference process: ")):
        texts = batch["texts"]
        tokens = batch["tokens"].cuda(master_gpu_id) if use_cuda else batch["tokens"]
        segment_ids = batch["segment_ids"].cuda(master_gpu_id) if use_cuda else batch["segment_ids"]
        attn_masks = batch["attn_masks"].cuda(master_gpu_id) if use_cuda else batch["attn_masks"]
        labels = batch["labels"].cuda(master_gpu_id) if use_cuda else batch["labels"]
        with torch.no_grad():
            probs, layer_idxes, uncertain_infos = model(
                tokens,
                token_type_ids=segment_ids,
                attention_mask=attn_masks,
                inference=True,
                inference_speed=inference_speed)
        _, top_index = probs.topk(1)
        correct_sum += (top_index.view(-1) == labels).sum().item()
        cnt += 1
        if cnt == 1:
            # Restart the per-record clock after the first (warm-up) record;
            # time_per below averages over the remaining cnt - 1 records.
            stime = time.time()
        if dump_info_file is not None:
            # batch_size is 1, so each zip yields at most one row;
            # layer_idxes is wrapped in a list to zip as a single element.
            for label, pred, prob, layer_i, text in zip(
                    labels, top_index.view(-1), probs, [layer_idxes], texts):
                infos.append((label.item(), pred.item(),
                              prob.cpu().numpy(), layer_i, text))
        if debug_break and step > 50:
            break
    # Fix: guard the cnt - 1 divisor (previously ZeroDivisionError for
    # datasets with 0 or 1 records).
    time_per = (time.time() - stime) / (cnt - 1) if cnt > 1 else 0.0
    time_all = time.time() - stime_all
    acc = format(correct_sum / num_sample, "0.4f") if num_sample else "0.0000"
    logging.info("speed_arg:%s, time_per_record:%s, acc:%s, total_time:%s",
                 inference_speed, format(time_per, '0.4f'), acc,
                 format(time_all, '0.4f'))
    if dump_info_file is not None and len(dump_info_file) != 0:
        with open(dump_info_file, 'w') as fw:
            # NOTE(review): prob is collected but not written to the dump —
            # preserved as-is; confirm whether the probability column was
            # meant to be included.
            for label, pred, prob, layer_i, text in infos:
                fw.write('\t'.join([str(label), str(pred),
                                    str(layer_i), text]) + '\n')
        if probs is not None and probs.shape[1] == 2:
            labels_pr = [info[0] for info in infos]
            preds_pr = [info[1] for info in infos]
            precise, recall = eval_pr(labels_pr, preds_pr)
            logging.info("precise:%s, recall:%s",
                         format(precise, '0.4f'), format(recall, '0.4f'))