def train(self):
    """Run the full training loop for the audio-only model.

    For each epoch, iterates the train loader, updating running
    loss/accuracy meters, stepping the LR scheduler ('cycle' per batch,
    'plateau' per validation), printing periodic logs, validating every
    ``self.val_every`` steps, and saving the best checkpoint (by
    validation accuracy) to ``self.model_dir``.
    """
    # Setting the variables before starting the training
    avg_train_loss = AverageMeter()
    avg_train_acc = AverageMeter()
    best_val_acc = -np.inf  # any real accuracy beats this on the first check

    for epoch in range(self.num_epochs):
        avg_train_loss.reset()
        avg_train_acc.reset()

        # Mini batch loop
        for batch_idx, batch in enumerate(tqdm(self.train_loader)):
            step = epoch * len(self.train_loader) + batch_idx

            # Get the model output for the batch and update the loss
            # and accuracy meters
            train_loss, train_acc = self.train_step(batch)
            if self.args.scheduler == 'cycle':
                # Cyclic schedulers advance once per batch, not per epoch.
                self.scheduler.step()
            avg_train_loss.update([train_loss.item()])
            avg_train_acc.update([train_acc])

            # Periodic console logging
            if step % self.print_every == 0:
                print(
                    'Epoch {}, batch {}, step {}, '
                    'loss = {:.4f}, acc = {:.4f}, '
                    'running averages: loss = {:.4f}, acc = {:.4f}'.format(
                        epoch, batch_idx, step, train_loss.item(),
                        train_acc, avg_train_loss.get(),
                        avg_train_acc.get()))

            # Periodic validation and best-checkpoint saving
            if step % self.val_every == 0:
                val_loss, val_acc = self.val()
                print('Val acc = {:.4f}, Val loss = {:.4f}'.format(
                    val_acc, val_loss))
                if self.visualize:
                    self.writer.add_scalar('Val/loss', val_loss, step)
                    self.writer.add_scalar('Val/acc', val_acc, step)

                # Save the best validation checkpoint if needed
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_chkpt_path = os.path.join(self.model_dir,
                                                   'best_ckpt.pth')
                    torch.save(self.model.state_dict(), best_chkpt_path)

                if self.args.scheduler == 'plateau':
                    # Plateau scheduler keys off the monitored metric.
                    self.scheduler.step(val_acc)

            if self.visualize:
                # Log per-step training metrics to the summary writer.
                self.writer.add_scalar('Train/loss', train_loss.item(), step)
                self.writer.add_scalar('Train/acc', train_acc, step)
def train(self):
    """Run the training loop for the joint audio+text model.

    Tracks separate running accuracies for the audio and text heads,
    validates every ``self.val_every`` steps, and checkpoints on the
    mean of audio and text validation accuracy. One additional model
    layer is unfrozen at the end of every epoch (gradual unfreezing).
    """
    # Setting the variables before starting the training
    avg_train_loss = AverageMeter()
    avg_train_acc = AverageMeter()
    text_avg_train_acc = AverageMeter()
    best_val_acc = -np.inf

    for epoch in range(self.num_epochs):
        self.model.print_frozen()  # show which layers are still frozen
        avg_train_loss.reset()
        avg_train_acc.reset()
        text_avg_train_acc.reset()

        # Mini batch loop
        for batch_idx, batch in enumerate(tqdm(self.train_loader)):
            step = epoch * len(self.train_loader) + batch_idx

            # Get the model output for the batch and update the loss
            # and accuracy meters
            train_loss, train_acc, text_train_acc = self.train_step(batch)
            if self.args.scheduler == 'cycle':
                # Cyclic schedulers advance once per batch.
                self.scheduler.step()
            avg_train_loss.update([train_loss.item()])
            avg_train_acc.update([train_acc])
            text_avg_train_acc.update([text_train_acc])

            # Logging and validation check
            if step % self.print_every == 0:
                print(
                    'Epoch {}, batch {}, step {}, '
                    'loss = {:.4f}, acc_audio = {:.4f}, acc_text = {:.4f}, '
                    'running averages: loss = {:.4f}, acc_audio = {:.4f}, acc_text = {:.4f}'
                    .format(epoch, batch_idx, step, train_loss.item(),
                            train_acc, text_train_acc, avg_train_loss.get(),
                            avg_train_acc.get(), text_avg_train_acc.get()))

            if step % self.val_every == 0:
                val_loss, val_acc, text_val_acc = self.val()
                print(
                    'Val acc (audio) = {:.4f}, Val acc (text) = {:.4f}, Val loss = {:.4f}'
                    .format(val_acc, text_val_acc, val_loss))

                # Model selection metric: mean of the two head accuracies.
                audio_text_avg_acc = (val_acc + text_val_acc) / 2
                if audio_text_avg_acc > best_val_acc:
                    best_val_acc = audio_text_avg_acc
                    best_chkpt_path = os.path.join(self.model_dir,
                                                   'best_ckpt.pth')
                    torch.save(self.model.state_dict(), best_chkpt_path)

                if self.args.scheduler == 'plateau':
                    self.scheduler.step(audio_text_avg_acc)

        # Gradual unfreezing: expose one more layer to training per epoch.
        self.model.unfreeze_one_layer()
def val(self):
    """Evaluate the model on the validation set.

    Returns:
        tuple: (average validation loss, average validation accuracy).
    """
    print('VALIDATING:')
    avg_val_loss = AverageMeter()
    avg_val_acc = AverageMeter()
    self.model.eval()
    # FIX: run under no_grad — validation previously built the autograd
    # graph for every batch, wasting memory for gradients never used.
    with torch.no_grad():
        for batch in tqdm(self.val_loader):
            metrics = self.compute_loss(batch)
            avg_val_acc.update(metrics['correct'].cpu().numpy())
            avg_val_loss.update([metrics['loss']])
    # FIX: restore training mode — this is called mid-training and the
    # model was previously left in eval mode for subsequent train steps.
    self.model.train()
    return avg_val_loss.get(), avg_val_acc.get()
def val(self):
    """Evaluate the joint audio+text model on the validation set.

    Returns:
        tuple: (avg loss, audio accuracy, text accuracy, combined accuracy).
    """
    print('VALIDATING:')
    avg_val_loss = AverageMeter()           # final loss
    avg_val_acc = AverageMeter()            # audio acc
    text_avg_val_acc = AverageMeter()       # text acc
    combined_avg_val_acc = AverageMeter()   # combined acc
    self.model.eval()
    # FIX: run under no_grad — validation previously built the autograd
    # graph for every batch, wasting memory for gradients never used.
    with torch.no_grad():
        for batch in tqdm(self.val_loader):
            metrics = self.compute_loss(batch)
            avg_val_acc.update(metrics['correct'].cpu().numpy())
            text_avg_val_acc.update(metrics['text_correct'].cpu().numpy())
            combined_avg_val_acc.update(
                metrics['combined_correct'].cpu().numpy())
            avg_val_loss.update([metrics['loss']])
    # FIX: restore training mode — this is called mid-training and the
    # model was previously left in eval mode for subsequent train steps.
    self.model.train()
    return (avg_val_loss.get(), avg_val_acc.get(),
            text_avg_val_acc.get(), combined_avg_val_acc.get())
def infer(self):
    """Evaluate the loaded checkpoint on the test set.

    Returns:
        tuple: (avg test loss, audio test accuracy, text test accuracy).
    """
    self.load_model_for_eval()
    avg_test_loss = AverageMeter()
    avg_test_acc = AverageMeter()
    text_avg_test_acc = AverageMeter()
    # FIX: inference previously ran with autograd enabled, building an
    # unused graph for every batch.
    with torch.no_grad():
        for batch in tqdm(self.test_loader):
            # Get the model output and update the meters
            output = self.compute_loss(batch)
            avg_test_acc.update(output['correct'].cpu().numpy())
            text_avg_test_acc.update(output['text_correct'].cpu().numpy())
            avg_test_loss.update([output['loss']])
    print(
        'Final test acc (audio) = {:.4f}, final test acc (text) = {:.4f}, test loss = {:.4f}'
        .format(avg_test_acc.get(), text_avg_test_acc.get(),
                avg_test_loss.get()))
    return avg_test_loss.get(), avg_test_acc.get(), text_avg_test_acc.get()
def infer(self):
    """Evaluate on the test set and dump embeddings to 'embeddings.npz'.

    Collects per-sample true labels, predictions, and audio/text
    embeddings across the whole test set, saves the embeddings and
    labels to a compressed npz file, and prints the final metrics.

    Returns:
        tuple: (avg test loss, avg test accuracy).
    """
    self.load_model_for_eval()
    avg_test_loss = AverageMeter()
    avg_test_acc = AverageMeter()
    all_true_labels = []
    all_pred_labels = []
    all_audio_embeddings = []
    all_text_embeddings = []
    # FIX: inference previously ran with autograd enabled, building an
    # unused graph for every batch.
    with torch.no_grad():
        for batch in tqdm(self.test_loader):
            # Get the model output and update the meters
            output = self.compute_loss(batch)
            avg_test_acc.update(output['correct'].cpu().numpy())
            avg_test_loss.update([output['loss']])
            # Store the predictions and embeddings for this batch.
            # (Predictions are kept for offline analysis, e.g. a
            # confusion matrix.)
            all_true_labels.append(batch['label'].cpu())
            all_pred_labels.append(output['predicted'].cpu())
            all_audio_embeddings.append(
                output['model_output']['audio_embed'].cpu())
            all_text_embeddings.append(
                output['model_output']['text_embed'].cpu())

    # Collect the predictions and embeddings for the full set
    all_true_labels = torch.cat(all_true_labels).numpy()
    all_pred_labels = torch.cat(all_pred_labels).numpy()
    all_audio_embeddings = torch.cat(all_audio_embeddings).numpy()
    all_text_embeddings = torch.cat(all_text_embeddings).numpy()

    # Save the embeddings alongside the ground-truth labels.
    np.savez_compressed('embeddings.npz',
                        audio=all_audio_embeddings,
                        text=all_text_embeddings,
                        labels=all_true_labels)

    print('Final test acc = {:.4f}, test loss = {:.4f}'.format(
        avg_test_acc.get(), avg_test_loss.get()))
    return avg_test_loss.get(), avg_test_acc.get()
def _model_selection_metric(self, val_acc, text_val_acc, combined_val_acc):
    """Return the metric used to decide whether a checkpoint is the new best.

    Either the mean of audio and text accuracy ('audio_text') or the
    combined-head accuracy (any other value, i.e. 'combined').
    """
    if self.args.model_save_criteria == 'audio_text':
        return (val_acc + text_val_acc) / 2
    return combined_val_acc  # 'combined'

def train(self):
    """Training loop with gradual unfreezing and early stopping.

    Resumes from a checkpoint if one is configured, validates both
    every ``self.val_every`` steps and at the end of each epoch, saves
    the best checkpoint according to ``self.args.model_save_criteria``,
    and stops early once ``self.max_patience`` epochs pass without
    improvement. One model layer is unfrozen at the end of each epoch.
    """
    # Setting the variables before starting the training
    print('Loading checkpoint if checkpoint_dir is given...')
    self.load_checkpoint()
    avg_train_loss = AverageMeter()
    avg_train_acc = AverageMeter()
    text_avg_train_acc = AverageMeter()
    best_val_acc = -np.inf
    patience_counter = 0
    best_epoch = self.num_epochs  # epoch at which the best score was seen

    for epoch in range(self.num_epochs):
        self.model.print_frozen()
        avg_train_loss.reset()
        avg_train_acc.reset()
        text_avg_train_acc.reset()

        # Mini batch loop
        for batch_idx, batch in enumerate(tqdm(self.train_loader)):
            step = epoch * len(self.train_loader) + batch_idx

            # Get the model output for the batch and update the loss
            # and accuracy meters
            train_loss, train_acc, text_train_acc = self.train_step(batch)
            if self.args.scheduler == 'cycle':
                self.scheduler.step()
            avg_train_loss.update([train_loss.item()])
            avg_train_acc.update([train_acc])
            text_avg_train_acc.update([text_train_acc])

            # Logging and validation check
            if step % self.print_every == 0:
                print(
                    'Epoch {}, batch {}, step {}, '
                    'loss = {:.4f}, acc_audio = {:.4f}, acc_text = {:.4f}, '
                    'running averages: loss = {:.4f}, acc_audio = {:.4f}, acc_text = {:.4f}'
                    .format(epoch, batch_idx, step, train_loss.item(),
                            train_acc, text_train_acc, avg_train_loss.get(),
                            avg_train_acc.get(), text_avg_train_acc.get()))

            if step % self.val_every == 0:
                val_loss, val_acc, text_val_acc, combined_val_acc = self.val(
                )
                print(
                    'Val acc (audio) = {:.4f}, Val acc (text) = {:.4f}, Val acc (combined) = {:.4f}, Val loss = {:.4f}'
                    .format(val_acc, text_val_acc, combined_val_acc,
                            val_loss))

                # Update and save the best validation checkpoint if needed
                cur_avg_acc = self._model_selection_metric(
                    val_acc, text_val_acc, combined_val_acc)
                if cur_avg_acc > best_val_acc:
                    best_val_acc = cur_avg_acc
                    best_chkpt_path = os.path.join(self.model_dir,
                                                   'best_ckpt.pth')
                    torch.save(self.model.state_dict(), best_chkpt_path)
                    print('Done saving best check point!')

                if self.args.scheduler == 'plateau':
                    # BUG FIX: was `audio_text_avg_acc`, a name that does
                    # not exist in this method -> NameError at runtime.
                    self.scheduler.step(cur_avg_acc)

        print('------ End of epoch validation ------')
        val_loss, val_acc, text_val_acc, combined_val_acc = self.val()

        # Update and save the best validation checkpoint if needed
        cur_avg_acc = self._model_selection_metric(
            val_acc, text_val_acc, combined_val_acc)
        if cur_avg_acc > best_val_acc:
            best_val_acc = cur_avg_acc
            best_chkpt_path = os.path.join(self.model_dir, 'best_ckpt.pth')
            torch.save(self.model.state_dict(), best_chkpt_path)
            patience_counter = 0
            best_epoch = epoch
            print('Done saving best check point! Patience counter reset!')
        else:
            patience_counter += 1
            if patience_counter > self.max_patience:
                # BUG FIX: report the epoch the best score was achieved
                # (`best_epoch`), not the epoch patience ran out (`epoch`).
                print(
                    'Reach max patience limit. Training stops! Best val acc achieved at epoch: {}.'
                    .format(best_epoch))
                break

        self.model.unfreeze_one_layer()