import time

import torch
import torch.nn.functional as F


def train_classifier_simple_v1(num_epochs, model, optimizer, device,
                               train_loader, valid_loader=None,
                               loss_fn=None, logging_interval=100,
                               skip_epoch_stats=False):

    if loss_fn is None:
        loss_fn = F.cross_entropy

    start_time = time.time()
    for epoch in range(num_epochs):

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            features = features.to(device)
            targets = targets.to(device)

            # FORWARD AND BACK PROP
            logits = model(features)
            cost = loss_fn(logits, targets)
            optimizer.zero_grad()
            cost.backward()

            # UPDATE MODEL PARAMETERS
            optimizer.step()

            # LOGGING
            if not batch_idx % logging_interval:
                print('Epoch: %03d/%03d | Batch %04d/%04d | Loss: %.4f'
                      % (epoch + 1, num_epochs, batch_idx,
                         len(train_loader), cost))

        if not skip_epoch_stats:
            model.eval()
            with torch.set_grad_enabled(False):  # save memory during inference

                print('Epoch: %03d/%03d | Train Acc.: %.3f%% | Loss: %.3f'
                      % (epoch + 1, num_epochs,
                         compute_accuracy(model, train_loader, device),
                         compute_epoch_loss(model, train_loader, device)))

                if valid_loader is not None:
                    print('Epoch: %03d/%03d | Validation Acc.: %.3f%% | Loss: %.3f'
                          % (epoch + 1, num_epochs,
                             compute_accuracy(model, valid_loader, device),
                             compute_epoch_loss(model, valid_loader, device)))

        print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60))

    print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))

def train_classifier_simple_v1(num_epochs, model, optimizer, device,
                               train_loader, valid_loader=None,
                               loss_fn=None, logging_interval=100,
                               skip_epoch_stats=False):

    log_dict = {'train_loss_per_batch': [],
                'train_acc_per_epoch': [],
                'train_loss_per_epoch': [],
                'valid_acc_per_epoch': [],
                'valid_loss_per_epoch': []}

    if loss_fn is None:
        loss_fn = F.cross_entropy

    start_time = time.time()
    for epoch in range(num_epochs):

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            features = features.to(device)
            targets = targets.to(device)

            # FORWARD AND BACK PROP
            logits = model(features)
            loss = loss_fn(logits, targets)
            optimizer.zero_grad()
            loss.backward()

            # UPDATE MODEL PARAMETERS
            optimizer.step()

            # LOGGING
            log_dict['train_loss_per_batch'].append(loss.item())
            if not batch_idx % logging_interval:
                print('Epoch: %03d/%03d | Batch %04d/%04d | Loss: %.4f'
                      % (epoch + 1, num_epochs, batch_idx,
                         len(train_loader), loss))

        if not skip_epoch_stats:
            model.eval()
            with torch.set_grad_enabled(False):  # save memory during inference

                train_acc = compute_accuracy(model, train_loader, device)
                train_loss = compute_epoch_loss_classifier(
                    model, train_loader, loss_fn, device)
                print('***Epoch: %03d/%03d | Train. Acc.: %.3f%% | Loss: %.3f'
                      % (epoch + 1, num_epochs, train_acc, train_loss))
                log_dict['train_loss_per_epoch'].append(train_loss.item())
                log_dict['train_acc_per_epoch'].append(train_acc.item())

                if valid_loader is not None:
                    valid_acc = compute_accuracy(model, valid_loader, device)
                    valid_loss = compute_epoch_loss_classifier(
                        model, valid_loader, loss_fn, device)
                    print('***Epoch: %03d/%03d | Valid. Acc.: %.3f%% | Loss: %.3f'
                          % (epoch + 1, num_epochs, valid_acc, valid_loss))
                    log_dict['valid_loss_per_epoch'].append(valid_loss.item())
                    log_dict['valid_acc_per_epoch'].append(valid_acc.item())

        print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60))

    print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))

    return log_dict

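# ----------------------------------------------------------------------
# Example call for `train_classifier_simple_v1` (a minimal sketch only;
# the model, DataLoaders, and optimizer settings shown are hypothetical
# placeholders, and `compute_accuracy` / `compute_epoch_loss_classifier`
# are assumed to be defined elsewhere in this module):
#
#     model = model.to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#
#     log_dict = train_classifier_simple_v1(
#         num_epochs=10, model=model, optimizer=optimizer, device=device,
#         train_loader=train_loader, valid_loader=valid_loader,
#         logging_interval=100)
#
#     # log_dict['train_loss_per_batch'] holds the minibatch losses for
#     # plotting the training curve afterwards.
# ----------------------------------------------------------------------
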
def train_classifier_simple_v2(model, num_epochs, train_loader, valid_loader,
                               test_loader, optimizer, device,
                               logging_interval=50, best_model_save_path=None,
                               scheduler=None, skip_train_acc=False,
                               scheduler_on='valid_acc'):

    start_time = time.time()
    minibatch_loss_list, train_acc_list, valid_acc_list = [], [], []
    best_valid_acc, best_epoch = -float('inf'), 0

    for epoch in range(num_epochs):

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            features = features.to(device)
            targets = targets.to(device)

            # ## FORWARD AND BACK PROP
            logits = model(features)
            loss = torch.nn.functional.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()

            # ## UPDATE MODEL PARAMETERS
            optimizer.step()

            # ## LOGGING
            minibatch_loss_list.append(loss.item())
            if not batch_idx % logging_interval:
                print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                      f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
                      f'| Loss: {loss:.4f}')

        model.eval()
        with torch.no_grad():  # save memory during inference
            if not skip_train_acc:
                train_acc = compute_accuracy(model, train_loader,
                                             device=device).item()
            else:
                train_acc = float('nan')
            valid_acc = compute_accuracy(model, valid_loader,
                                         device=device).item()
            train_acc_list.append(train_acc)
            valid_acc_list.append(valid_acc)

            if valid_acc > best_valid_acc:
                best_valid_acc, best_epoch = valid_acc, epoch + 1
                if best_model_save_path:
                    torch.save(model.state_dict(), best_model_save_path)

            print(f'Epoch: {epoch+1:03d}/{num_epochs:03d} '
                  f'| Train: {train_acc:.2f}% '
                  f'| Validation: {valid_acc:.2f}% '
                  f'| Best Validation '
                  f'(Ep. {best_epoch:03d}): {best_valid_acc:.2f}%')

        elapsed = (time.time() - start_time) / 60
        print(f'Time elapsed: {elapsed:.2f} min')

        if scheduler is not None:
            if scheduler_on == 'valid_acc':
                scheduler.step(valid_acc_list[-1])
            elif scheduler_on == 'minibatch_loss':
                scheduler.step(minibatch_loss_list[-1])
            else:
                raise ValueError('Invalid `scheduler_on` choice.')

    elapsed = (time.time() - start_time) / 60
    print(f'Total Training Time: {elapsed:.2f} min')

    test_acc = compute_accuracy(model, test_loader, device=device)
    print(f'Test accuracy {test_acc:.2f}%')

    return minibatch_loss_list, train_acc_list, valid_acc_list

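# ----------------------------------------------------------------------
# Minimal usage sketch for `train_classifier_simple_v2`. Everything below
# is illustrative only: the synthetic data, the tiny stand-in model, and
# the hyperparameter values are placeholders, and `compute_accuracy` is
# assumed to be defined elsewhere in this module (it is called by the
# training functions above).
# ----------------------------------------------------------------------
if __name__ == '__main__':

    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(123)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Synthetic 2-class dataset with 20-dimensional feature vectors
    features = torch.randn(600, 20)
    targets = (features[:, 0] > 0).long()
    dataset = TensorDataset(features, targets)

    train_set, valid_set, test_set = torch.utils.data.random_split(
        dataset, [400, 100, 100])
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=32)
    test_loader = DataLoader(test_set, batch_size=32)

    # Tiny multilayer perceptron as a stand-in classifier
    model = torch.nn.Sequential(
        torch.nn.Linear(20, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 2)).to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # ReduceLROnPlateau in 'max' mode, since it is stepped with the
    # validation accuracy (scheduler_on='valid_acc')
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=2)

    minibatch_losses, train_accs, valid_accs = train_classifier_simple_v2(
        model=model, num_epochs=3,
        train_loader=train_loader, valid_loader=valid_loader,
        test_loader=test_loader, optimizer=optimizer, device=device,
        logging_interval=5, scheduler=scheduler, scheduler_on='valid_acc')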