def val_epoch(self, dataloader, device, step=0, plot=False):
    """Run one evaluation pass over ``dataloader`` and report the averages.

    Switches ``self.model`` and ``self.classifier`` to evaluation mode.
    Under ``torch.no_grad()`` it computes each batch's loss, accuracy,
    precision and recall, and feeds them to an ``Evaluator``. The averaged
    metrics are written to ``self.eval_logger`` and ``self.summary_writer``
    at ``step``. The average loss is then handed to ``self.early_stopping``,
    whose result decides whether a checkpoint is saved.

    The model and classifier are left in evaluation mode on return.

    Args:
        dataloader: iterable of batches that ``parse_data`` can unpack into
            ``(feature_1, feature_2, y)``.
        device: target device; ``'cuda'`` requires ``self.CUDA_AVAILABLE``.
        step: global step used for logging and in the checkpoint file name.
        plot: if true, collect predictions/targets from every batch and draw
            a precision-recall curve with ``pr`` and ``plot_pr_curve``.

    Returns:
        The ``'break'`` entry of the dict returned by ``self.early_stopping``.
    """
    if device == 'cuda':
        assert self.CUDA_AVAILABLE

    # Evaluation mode for both sub-modules (affects layers such as dropout /
    # batch norm, if the models contain any).
    self.model.eval()
    self.classifier.eval()

    evaluator = Evaluator()
    scores, targets = [], []

    with torch.no_grad():
        progress = tqdm(dataloader, mininterval=5,
                        desc=' - (Evaluation) ', leave=False)
        for batch in progress:
            first, second, target = parse_data(batch, device)
            n_rows = len(first)

            logits, _ = self.model(first, second)
            logits = self.classifier(logits.view(n_rows, -1))

            if self.d_output == 1:
                # Single output: sigmoid score compared against y with MSE.
                pred = logits.sigmoid()
                loss = mse_loss(pred, target)
            else:
                # Multiple outputs: raw logits, cross entropy without smoothing.
                pred = logits
                loss = cross_entropy_loss(pred, target, smoothing=False)

            acc = accuracy(pred, target, threshold=self.threshold)
            precision, recall, _, _ = precision_recall(
                pred, target, self.d_output, threshold=self.threshold)

            # Only class index 1 is tracked (presumably the positive class of
            # a binary task -- confirm for multi-class use).
            evaluator(loss.item(), acc.item(),
                      precision[1].item(), recall[1].item())

            # Keep every prediction/target so a PR curve can be drawn later.
            # NOTE(review): when d_output > 1 these are raw logits rather than
            # probabilities -- confirm that is what pr() expects.
            if plot:
                scores.extend(pred.tolist())
                targets.extend(target.tolist())

    if plot:
        area, precisions, recalls, _ = pr(scores, targets)
        plot_pr_curve(recalls, precisions, auc=area)

    loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()
    self.eval_logger.info(
        '[EVALUATION] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
        % (step, loss_avg, acc_avg, pre_avg, rec_avg))

    scalar_tags = (
        ('loss/eval', loss_avg),
        ('acc/eval', acc_avg),
        ('precision/eval', pre_avg),
        ('recall/eval', rec_avg),
    )
    for tag, value in scalar_tags:
        self.summary_writer.add_scalar(tag, value, step)

    # The early-stopping helper decides both "save a checkpoint" and "stop".
    decision = self.early_stopping(loss_avg)
    if decision['save']:
        self.save_model(
            self.checkpoint(step),
            self.save_path + '-step-%d_loss-%.5f' % (step, loss_avg))
    return decision['break']
def train_epoch(self, train_dataloader, eval_dataloader, device, smothing, earlystop):
    """Train over ``train_dataloader`` once, logging and evaluating periodically.

    Each batch runs a forward pass, computes the loss and takes one optimizer
    step. It then asks ``self.controller`` whether this step should be
    logged, evaluated, or is the last one.

    Args:
        train_dataloader: iterable of training batches for ``parse_data``.
        eval_dataloader: batches passed to ``self.val_epoch`` on evaluation
            steps.
        device: target device; ``'cuda'`` requires ``self.CUDA_AVAILABLE``.
        smothing: forwarded as ``smoothing`` to ``cross_entropy_loss`` when
            ``d_output != 1``. The misspelled name is kept for existing callers.
        earlystop: if truthy, leave the epoch as soon as an evaluation reports
            that early stopping should break.

    Returns:
        The dict returned by ``self.controller`` for the last processed batch.
        Its ``'step_to_stop'`` entry is updated from evaluation / max-step
        checks.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches (previously this
            surfaced as an ``UnboundLocalError``).
    """
    if device == 'cuda':
        assert self.CUDA_AVAILABLE

    # Training mode for both sub-modules.
    self.model.train()
    self.classifier.train()

    batch_counter = 0
    state_dict = None

    for batch in tqdm(train_dataloader, mininterval=1,
                      desc=' - (Training) ', leave=False):
        feature_1, feature_2, y = parse_data(batch, device)
        batch_size = len(feature_1)

        # Forward pass.
        self.optimizer.zero_grad()
        logits, _ = self.model(feature_1, feature_2)
        logits = self.classifier(logits.view(batch_size, -1))

        # A single output is treated as regression (sigmoid + MSE).
        if self.d_output == 1:
            pred = logits.sigmoid()
            loss = mse_loss(pred, y)
        else:
            pred = logits
            loss = cross_entropy_loss(pred, y, smoothing=smothing)

        # Backward pass and parameter update.
        loss.backward()
        self.optimizer.step()

        # Batch metrics, used for logging below.
        acc = accuracy(pred, y, threshold=self.threshold)
        precision, recall, _, _ = precision_recall(
            pred, y, self.d_output, threshold=self.threshold)

        batch_counter += 1
        state_dict = self.controller(batch_counter)

        if state_dict['step_to_print']:
            step = state_dict['step']
            # Convert to Python floats only when actually logging, to avoid a
            # device->host sync on every batch. Index 1 matches val_epoch.
            loss_val = loss.item()
            acc_val = acc.item()
            pre_val = precision[1].item()
            rec_val = recall[1].item()
            self.train_logger.info(
                '[TRAINING] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                % (step, loss_val, acc_val, pre_val, rec_val))
            self.summary_writer.add_scalar('loss/train', loss_val, step)
            self.summary_writer.add_scalar('acc/train', acc_val, step)
            self.summary_writer.add_scalar('precision/train', pre_val, step)
            self.summary_writer.add_scalar('recall/train', rec_val, step)

        if state_dict['step_to_evaluate']:
            stop = self.val_epoch(eval_dataloader, device, state_dict['step'])
            state_dict['step_to_stop'] = stop
            # val_epoch switches model/classifier to eval mode and does not
            # switch them back; restore training mode so the remaining
            # batches are not trained in eval mode.
            self.model.train()
            self.classifier.train()
            if earlystop and stop:
                break

        # ">=" rather than "==" so an overshooting step counter still stops.
        if self.controller.current_step >= self.controller.max_step:
            state_dict['step_to_stop'] = True
            break

    if state_dict is None:
        raise ValueError('train_dataloader yielded no batches')
    return state_dict