def do_validation_pass(epoch, model, tel, loader):
    """Evaluate ``model`` over every batch of ``loader`` and record results in ``tel``.

    The model is switched to eval mode and run under ``torch.no_grad()``. For each
    batch, the summed loss from ``forward_loss`` is added to ``tel['val_loss']``, and
    ``calculate_performance_metrics`` is given the predictions along with
    ``tel['val_mpjpe']`` and ``tel['val_pck']``. Predictions from the first batch are
    passed to ``visualise_predictions`` and the first eight resulting images are
    stored via ``tel['val_examples'].set_value``.

    Args:
        epoch: Current epoch number (not used by this function).
        model: Network to evaluate; it is left in eval mode on return.
        tel: Mapping of telemetry sinks keyed by name ('val_loss', 'val_mpjpe',
            'val_pck', 'val_examples' are read here).
        loader: Iterable of batch dicts containing 'input', 'target',
            'joint_mask' and 'valid_depth'; ``loader.dataset`` is also read.
    """
    device = global_opts['device']
    vis_images = None

    model.eval()
    with torch.no_grad():
        for batch in progress_iter(loader, 'Validation'):
            in_var = batch['input'].to(device, torch.float32)
            target_var = batch['target'].to(device, torch.float32)
            mask_var = batch['joint_mask'].to(device, torch.float32)

            # Calculate predictions and loss
            out_var = model(in_var)
            loss = forward_loss(model, out_var, target_var, mask_var, batch['valid_depth'])
            tel['val_loss'].add(loss.sum().item())

            # Metrics are computed on a float64 CPU copy of the predictions.
            calculate_performance_metrics(
                batch,
                loader.dataset,
                ensure_homogeneous(out_var.detach().to(CPU, torch.float64), d=3),
                tel['val_mpjpe'],
                tel['val_pck'])

            # Only the first batch is visualised.
            if vis_images is None:
                preds = out_var.detach().to(CPU, torch.float64)
                vis_images = visualise_predictions(preds, batch, loader.dataset)

    # An empty loader produces no images; slicing None would raise TypeError.
    if vis_images is not None:
        tel['val_examples'].set_value(vis_images[:8])
def do_training_pass(epoch, model, tel, loader, scheduler, on_progress):
    """Train ``model`` for one pass over ``loader``, recording telemetry in ``tel``.

    If ``scheduler`` has a ``step`` method it is stepped once with ``epoch`` before
    the pass, and ``batch_step`` (if present) is called before every batch. The
    optimiser used for updates is ``scheduler.optimizer``. Each phase (data
    loading, device transfer, forward, metric evaluation, backward, optimiser
    step) is timed into the matching ``tel`` entry. The summed loss is added to
    ``tel['train_loss']``, and ``calculate_performance_metrics`` is given the
    predictions along with ``tel['train_mpjpe']`` and ``tel['train_pck']``. After
    every batch, ``on_progress`` is called with the cumulative number of samples
    processed so far. Predictions from the first batch are passed to
    ``visualise_predictions`` and the first eight images are stored via
    ``tel['train_examples'].set_value``.

    Args:
        epoch: Epoch number forwarded to ``scheduler.step``.
        model: Network to train; it is left in train mode on return.
        tel: Mapping of telemetry sinks/timers keyed by name.
        loader: Iterable of batch dicts containing 'input', 'target',
            'joint_mask' and 'valid_depth'; ``loader.dataset`` is also read.
        scheduler: Object exposing an ``optimizer`` attribute and optionally
            ``step``/``batch_step`` methods.
        on_progress: Callable receiving the running sample count.
    """
    if hasattr(scheduler, 'step'):
        # NOTE(review): for built-in torch.optim.lr_scheduler classes, passing
        # ``epoch`` is deprecated and stepping before optimiser.step() emits a
        # warning; kept unchanged because ``scheduler`` may be a custom class.
        scheduler.step(epoch)
    optimiser = scheduler.optimizer

    device = global_opts['device']
    vis_images = None
    samples_processed = 0

    model.train()
    for batch in generator_timer(progress_iter(loader, 'Training'), tel['data_load_time']):
        if hasattr(scheduler, 'batch_step'):
            scheduler.batch_step()

        with timer(tel['data_transfer_time']):
            in_var = batch['input'].to(device, torch.float32)
            target_var = batch['target'].to(device, torch.float32)
            mask_var = batch['joint_mask'].to(device, torch.float32)

        # Calculate predictions and loss
        with timer(tel['forward_time']):
            out_var = model(in_var)
            loss = forward_loss(model, out_var, target_var, mask_var, batch['valid_depth'])
            tel['train_loss'].add(loss.sum().item())

        # Calculate accuracy metrics. Detach before copying so the CPU/float64
        # conversion is not recorded in the autograd graph.
        with timer(tel['eval_time']):
            calculate_performance_metrics(
                batch,
                loader.dataset,
                ensure_homogeneous(out_var.detach().to(CPU, torch.float64), d=3),
                tel['train_mpjpe'],
                tel['train_pck'])

        # Calculate gradients
        with timer(tel['backward_time']):
            optimiser.zero_grad()
            loss.backward()

        # Update parameters
        with timer(tel['optim_time']):
            optimiser.step()

        # Update progress
        samples_processed += len(batch['input'])
        on_progress(samples_processed)

        # Only the first batch is visualised.
        if vis_images is None:
            preds = out_var.detach().to(CPU, torch.float64)
            vis_images = visualise_predictions(preds, batch, loader.dataset)

    # An empty loader produces no images; slicing None would raise TypeError.
    if vis_images is not None:
        tel['train_examples'].set_value(vis_images[:8])