def train_epoch(self, train_dataloader, optimizer, device):
    """Run one training epoch of ``self.model`` over ``train_dataloader``.

    Each batch is indexed as ``(input_ids, attention_mask, labels)``. The
    model is called with ``labels`` so it returns both ``loss`` and
    ``logits``; gradients are clipped to a max norm of 1.0 before the
    optimizer step.

    Args:
        train_dataloader: Loader yielding the batches described above.
        optimizer: Optimizer whose ``step()`` is called once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        ``EpochResult(avg_train_loss, avg_train_accuracy)`` where the loss
        is averaged over ``len(train_dataloader)`` batches and the accuracy
        is a percentage of ``len(train_dataloader.dataset)``.
    """
    total_train_accuracy = 0
    total_train_loss = 0
    tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0

    self.model.train()

    pbar_file = sys.stdout
    pbar_name = "train_batch"
    num_batches = len(train_dataloader.batch_sampler)

    with tqdm.tqdm(desc=pbar_name, total=num_batches, file=pbar_file) as pbar:
        for batch in train_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)

            self.model.zero_grad()

            # Forward pass
            output = self.model(b_input_ids,
                                token_type_ids=None,
                                attention_mask=b_input_mask,
                                labels=b_labels)
            loss = output.loss
            logits = output.logits

            # Read the scalar once; it is used for both the running total
            # and the progress-bar description.
            loss_value = loss.item()
            total_train_loss += loss_value

            # Backward pass, with gradient clipping before the update
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            # Weight updates
            optimizer.step()

            # Metrics are computed on CPU copies
            logits = logits.detach().cpu()
            y = b_labels.to('cpu').squeeze(1)
            y_pred = torch.argmax(logits, dim=1)

            tp, fp, tn, fn = calculate_acc(y_pred, y)
            tp_tot += tp
            fp_tot += fp
            tn_tot += tn
            fn_tot += fn

            total_train_accuracy += torch.sum(y_pred == y).float().item()

            pbar.set_description(f'{pbar_name} ({loss_value:.3f})')
            pbar.update()

    avg_train_accuracy = (total_train_accuracy / len(train_dataloader.dataset)) * 100
    print(
        f"accuracy={avg_train_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
    )

    if tp_tot + fn_tot > 0:
        pos_acc = tp_tot / (tp_tot + fn_tot)
        # Guard the negative-class ratio separately: an epoch with positives
        # but no negatives would otherwise raise ZeroDivisionError here.
        neg_total = tn_tot + fp_tot
        neg_acc = f"{tn_tot / neg_total:.3f}" if neg_total > 0 else "n/a"
        print(f"Pos acc: {pos_acc:.3f}, Neg acc: {neg_acc}")

    # Log the average train loss
    avg_train_loss = total_train_loss / len(train_dataloader)
    print(" Training loss: {0:.4f}".format(avg_train_loss))

    return EpochResult(avg_train_loss, avg_train_accuracy)
def train_epoch(self, train_dataloader, optimizer, loss_fn, device):
    """Run one training epoch using ``self.forward`` and ``loss_fn``.

    Each batch is indexed as ``(X, y)``. Batches whose ``y`` has a first
    dimension of 1 are skipped. Averages are computed over the batches and
    samples that were actually processed, so skipped batches do not deflate
    the reported loss or accuracy.

    Args:
        train_dataloader: Loader yielding ``(X, y)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        loss_fn: Callable invoked as ``loss_fn(predictions, y)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``EpochResult(avg_train_loss, avg_train_accuracy)``; accuracy is a
        percentage. Both are 0.0 if no batch was processed.
    """
    # NOTE(review): this method does not put the module into training mode;
    # presumably the caller does so - confirm.
    total_train_accuracy = 0
    total_train_loss = 0
    tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0
    samples_seen = 0
    batches_seen = 0

    pbar_file = sys.stdout
    pbar_name = "train_batch"
    num_batches = len(train_dataloader.batch_sampler)

    with tqdm.tqdm(desc=pbar_name, total=num_batches, file=pbar_file) as pbar:
        for batch in train_dataloader:
            X, y = batch[0], batch[1]
            batch_size = y.shape[0]
            if batch_size == 1:
                # torch.squeeze below would collapse a single-sample label
                # tensor to a scalar, so such batches are skipped. Still
                # advance the bar so it can reach its total.
                pbar.update()
                continue

            # Forward pass
            X = X.to(device)
            y = y.to(device)
            y_pred_log_proba = self.forward(X)
            y = torch.squeeze(y).long()  # should be of size (N,)

            # Backward pass
            optimizer.zero_grad()
            loss = loss_fn(y_pred_log_proba, y)
            loss.backward()

            # Weight updates
            optimizer.step()

            # Accumulate loss and accuracy statistics
            loss_value = loss.item()
            total_train_loss += loss_value
            y_pred = torch.argmax(y_pred_log_proba, dim=1)

            tp, fp, tn, fn = calculate_acc(y_pred, y)
            tp_tot += tp
            fp_tot += fp
            tn_tot += tn
            fn_tot += fn

            total_train_accuracy += torch.sum(y_pred == y).float().item()
            samples_seen += batch_size
            batches_seen += 1

            pbar.set_description(f'{pbar_name} ({loss_value:.3f})')
            pbar.update()

    # Divide by what was actually processed rather than the loader/dataset
    # sizes, which would count skipped batches as all-wrong / zero-loss.
    if samples_seen:
        avg_train_accuracy = (total_train_accuracy / samples_seen) * 100
    else:
        avg_train_accuracy = 0.0
    print(
        f" accuracy={avg_train_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
    )

    # Log the average train loss
    avg_train_loss = total_train_loss / batches_seen if batches_seen else 0.0
    print(" Training loss: {0:.4f}".format(avg_train_loss))

    return EpochResult(avg_train_loss, avg_train_accuracy)
def test_epoch(self, test_dataloader, device):
    """Evaluate ``self.model`` for one pass over ``test_dataloader``.

    The model is switched to eval mode and each forward pass runs under
    ``torch.no_grad()``. Each batch is indexed as
    ``(input_ids, attention_mask, labels)``.

    Args:
        test_dataloader: Loader yielding the batches described above.
        device: Device the batch tensors are moved to.

    Returns:
        ``EpochResult(avg_val_loss, avg_val_accuracy)`` where the loss is
        averaged over ``len(test_dataloader)`` batches and the accuracy is
        a percentage of ``len(test_dataloader.dataset)``.
    """
    self.model.eval()

    total_eval_accuracy = 0
    total_eval_loss = 0
    tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0

    # Evaluate data for one epoch
    pbar_file = sys.stdout
    pbar_name = "test_batch"
    num_batches = len(test_dataloader.batch_sampler)

    with tqdm.tqdm(desc=pbar_name, total=num_batches, file=pbar_file) as pbar:
        for batch in test_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)

            # No gradients are needed for evaluation
            with torch.no_grad():
                output = self.model(b_input_ids,
                                    token_type_ids=None,
                                    attention_mask=b_input_mask,
                                    labels=b_labels)
            loss = output.loss
            logits = output.logits

            loss_value = loss.item()
            total_eval_loss += loss_value

            # Metrics are computed on CPU copies
            logits = logits.detach().cpu()
            y = b_labels.to('cpu').squeeze(1)
            y_pred = torch.argmax(logits, dim=1)

            tp, fp, tn, fn = calculate_acc(y_pred, y)
            tp_tot += tp
            fp_tot += fp
            tn_tot += tn
            fn_tot += fn

            total_eval_accuracy += torch.sum(y_pred == y).float().item()

            pbar.set_description(f'{pbar_name} ({loss_value:.3f})')
            pbar.update()

    avg_val_accuracy = (total_eval_accuracy / len(test_dataloader.dataset)) * 100
    print(
        f"accuracy={avg_val_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
    )

    if tp_tot + fn_tot > 0:
        pos_acc = tp_tot / (tp_tot + fn_tot)
        # Guard the negative-class ratio separately: a split with positives
        # but no negatives would otherwise raise ZeroDivisionError here.
        neg_total = tn_tot + fp_tot
        neg_acc = f"{tn_tot / neg_total:.3f}" if neg_total > 0 else "n/a"
        print(f"Pos acc: {pos_acc:.3f}, Neg acc: {neg_acc}")

    # Log the average validation loss
    avg_val_loss = total_eval_loss / len(test_dataloader)
    print(" Validation Loss: {0:.4f}".format(avg_val_loss))

    return EpochResult(avg_val_loss, avg_val_accuracy)
def test_epoch(self, test_dataloader, loss_fn, device):
    """Evaluate one pass over ``test_dataloader`` with ``self.forward``.

    Each batch is indexed as ``(X, y)``. Batches whose ``y`` has a first
    dimension of 1 are skipped. Averages are computed over the batches and
    samples that were actually processed, so skipped batches do not deflate
    the reported loss or accuracy. The forward pass and loss run under
    ``torch.no_grad()``.

    Args:
        test_dataloader: Loader yielding ``(X, y)`` batches.
        loss_fn: Callable invoked as ``loss_fn(predictions, y)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``EpochResult(avg_val_loss, avg_val_accuracy)``; accuracy is a
        percentage. Both are 0.0 if no batch was processed.
    """
    # NOTE(review): this method does not put the module into eval mode;
    # presumably the caller does so - confirm.
    total_eval_accuracy = 0
    total_eval_loss = 0
    tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0
    samples_seen = 0
    batches_seen = 0

    # Evaluate data for one epoch
    pbar_file = sys.stdout
    pbar_name = "test_batch"
    num_batches = len(test_dataloader.batch_sampler)

    with tqdm.tqdm(desc=pbar_name, total=num_batches, file=pbar_file) as pbar:
        for batch in test_dataloader:
            X, y = batch[0], batch[1]
            batch_size = y.shape[0]
            if batch_size == 1:
                # torch.squeeze below would collapse a single-sample label
                # tensor to a scalar, so such batches are skipped. Still
                # advance the bar so it can reach its total.
                pbar.update()
                continue

            # Forward pass
            with torch.no_grad():
                X = X.to(device)
                y = y.to(device)
                y_pred_log_proba = self.forward(X)
                y = torch.squeeze(y).long()
                loss = loss_fn(y_pred_log_proba, y)

            loss_value = loss.item()
            total_eval_loss += loss_value
            y_pred = torch.argmax(y_pred_log_proba, dim=1)

            tp, fp, tn, fn = calculate_acc(y_pred, y)
            tp_tot += tp
            fp_tot += fp
            tn_tot += tn
            fn_tot += fn

            total_eval_accuracy += torch.sum(y_pred == y).float().item()
            samples_seen += batch_size
            batches_seen += 1

            pbar.set_description(f'{pbar_name} ({loss_value:.3f})')
            pbar.update()

    # Divide by what was actually processed rather than the loader/dataset
    # sizes, which would count skipped batches as all-wrong / zero-loss.
    if samples_seen:
        avg_val_accuracy = (total_eval_accuracy / samples_seen) * 100
    else:
        avg_val_accuracy = 0.0
    print(
        f" accuracy={avg_val_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
    )

    # Log the average validation loss
    avg_val_loss = total_eval_loss / batches_seen if batches_seen else 0.0
    print(" Validation Loss: {0:.4f}".format(avg_val_loss))

    return EpochResult(avg_val_loss, avg_val_accuracy)
def _foreach_batch(dl: DataLoader,
                   forward_fn: Callable[[Any], BatchResult],
                   verbose=True,
                   max_batches=None) -> EpochResult:
    """
    Evaluates the given forward-function on batches from the given
    dataloader, and prints progress along the way.

    Args:
        dl: Dataloader to draw batches from.
        forward_fn: Called once per batch; its result must expose ``loss``
            and ``num_correct``. Its ``__name__`` labels the progress bar.
        verbose: When true the progress bar is written to stdout; otherwise
            it is written to ``os.devnull``.
        max_batches: Optional cap on the number of batches processed. When
            it is smaller than the loader's batch count, the sample count
            is taken as ``max_batches * dl.batch_size``.

    Returns:
        ``EpochResult`` with the per-batch losses and the accuracy as a
        percentage. Accuracy (and the displayed average loss) are 0.0 when
        there is nothing to process.
    """
    losses = []
    num_correct = 0
    num_samples = len(dl.sampler)
    num_batches = len(dl.batch_sampler)

    if max_batches is not None and max_batches < num_batches:
        num_batches = max_batches
        num_samples = num_batches * dl.batch_size

    # Resolve the name before opening anything so a failure here cannot
    # leak a file handle.
    pbar_name = forward_fn.__name__
    pbar_file = sys.stdout if verbose else open(os.devnull, 'w')

    try:
        with tqdm.tqdm(desc=pbar_name, total=num_batches,
                       file=pbar_file) as pbar:
            dl_iter = iter(dl)
            for _ in range(num_batches):
                data = next(dl_iter)
                batch_res = forward_fn(data)

                pbar.set_description(f'{pbar_name} ({batch_res.loss:.3f})')
                pbar.update()

                losses.append(batch_res.loss)
                num_correct += batch_res.num_correct

            # Guard the averages so an empty loader or max_batches=0 does
            # not raise ZeroDivisionError.
            avg_loss = sum(losses) / num_batches if num_batches else 0.0
            accuracy = 100. * num_correct / num_samples if num_samples else 0.0
            pbar.set_description(f'{pbar_name} '
                                 f'(Avg. Loss {avg_loss:.3f}, '
                                 f'Accuracy {accuracy:.1f})')
    finally:
        # Only the devnull handle is ours to close; never close sys.stdout.
        if not verbose:
            pbar_file.close()

    return EpochResult(losses=losses, accuracy=accuracy)