import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from models.eyenet import EyeNet  # adjust the import path to the project layout
import util.gaze

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def validate(eyenet: EyeNet, val_loader: DataLoader) -> float:
    """Compute the mean combined loss over the validation set."""
    with torch.no_grad():
        val_losses = []
        for val_batch in val_loader:
            val_imgs = val_batch['img'].float().to(device)
            heatmaps = val_batch['heatmaps'].to(device)
            landmarks = val_batch['landmarks'].to(device)
            gaze = val_batch['gaze'].float().to(device)

            heatmaps_pred, landmarks_pred, gaze_pred = eyenet.forward(val_imgs)
            heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
                heatmaps_pred, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze)

            # The heatmap term is weighted heavily relative to the other terms.
            loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss
            val_losses.append(loss.item())

    return float(np.mean(val_losses))
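
# Usage sketch (illustrative only): how validate() is typically called around
# training. `make_val_loader` stands in for whatever builds the project's
# validation DataLoader; it is an assumption, not part of the original code.
def _validate_example(eyenet: EyeNet, make_val_loader) -> float:
    val_loader = make_val_loader()
    eyenet.eval()    # switch off dropout / batch-norm updates for evaluation
    val_loss = validate(eyenet=eyenet, val_loader=val_loader)
    eyenet.train()   # restore training mode afterwards
    return val_loss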
# Evaluation: load a trained checkpoint and report the mean angular error over
# a dataset. `checkpoint_fn` and `dataset` are assumed to be defined elsewhere
# in the script; the checkpoint keys match those saved in train_epoch below.
checkpoint = torch.load(checkpoint_fn, map_location=device)
nstack = checkpoint['nstack']
nfeatures = checkpoint['nfeatures']
nlandmarks = checkpoint['nlandmarks']
eyenet = EyeNet(nstack=nstack, nfeatures=nfeatures, nlandmarks=nlandmarks).to(device)
eyenet.load_state_dict(checkpoint['model_state_dict'])
eyenet.eval()  # disable dropout / freeze batch-norm statistics for inference

with torch.no_grad():
    errors = []
    print('N', len(dataset))
    for i, sample in enumerate(dataset):
        print(i)
        x = torch.tensor([sample['img']]).float().to(device)
        heatmaps_pred, landmarks_pred, gaze_pred = eyenet.forward(x)

        gaze = sample['gaze'].reshape((1, 2))
        gaze_pred = gaze_pred.cpu().numpy()
        # Right-eye predictions get their yaw sign flipped, presumably because
        # right-eye crops are mirrored during preprocessing.
        if sample['side'] == 'right':
            gaze_pred[0, 1] = -gaze_pred[0, 1]

        angular_error = util.gaze.angular_error(gaze, gaze_pred)
        errors.append(angular_error)

        print('---')
        print('error', angular_error)
        print('mean error', np.mean(errors))
        print('side', sample['side'])
        print('gaze', gaze)
        print('gaze pred', gaze_pred)
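
# util.gaze.angular_error is not shown here; below is a sketch of the standard
# computation it presumably performs, assuming gaze is encoded as (pitch, yaw)
# in radians. The exact vector convention inside util.gaze may differ.
def _pitchyaw_to_vector(pitchyaw: np.ndarray) -> np.ndarray:
    """Convert an (N, 2) array of (pitch, yaw) angles to 3D unit vectors."""
    pitch, yaw = pitchyaw[:, 0], pitchyaw[:, 1]
    return np.stack([
        np.cos(pitch) * np.sin(yaw),   # x: horizontal component
        np.sin(pitch),                 # y: vertical component
        np.cos(pitch) * np.cos(yaw),   # z: depth component
    ], axis=1)


def _angular_error_example(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Angle in degrees between gaze directions given as (pitch, yaw) pairs."""
    va, vb = _pitchyaw_to_vector(a), _pitchyaw_to_vector(b)
    cos_sim = np.sum(va * vb, axis=1) / (
        np.linalg.norm(va, axis=1) * np.linalg.norm(vb, axis=1))
    return np.degrees(np.arccos(np.clip(cos_sim, -1.0, 1.0)))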
def train_epoch(epoch: int,
                eyenet: EyeNet,
                optimizer,
                train_loader: DataLoader,
                val_loader: DataLoader,
                best_val_loss: float,
                checkpoint_fn: str,
                writer: SummaryWriter):
    N = len(train_loader)
    for i_batch, sample_batched in enumerate(train_loader):
        i_batch += N * epoch  # global step across epochs for TensorBoard

        imgs = sample_batched['img'].float().to(device)
        heatmaps_pred, landmarks_pred, gaze_pred = eyenet.forward(imgs)

        heatmaps = sample_batched['heatmaps'].to(device)
        landmarks = sample_batched['landmarks'].float().to(device)
        gaze = sample_batched['gaze'].float().to(device)

        heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
            heatmaps_pred, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze)
        loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i_batch % 20 == 0:
            # Dump a visual sanity check: the mean over a band of heatmap
            # channels (8:16) for the last sample in the batch, true vs.
            # predicted (last hourglass stack).
            hm = np.mean(heatmaps[-1, 8:16].cpu().detach().numpy(), axis=0)
            hm_pred = np.mean(heatmaps_pred[-1, -1, 8:16].cpu().detach().numpy(), axis=0)
            norm_hm = cv2.normalize(hm, None, alpha=0, beta=1,
                                    norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            norm_hm_pred = cv2.normalize(hm_pred, None, alpha=0, beta=1,
                                         norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            cv2.imwrite('true.jpg', norm_hm * 255)
            cv2.imwrite('pred.jpg', norm_hm_pred * 255)
            cv2.imwrite('eye.jpg', sample_batched['img'].numpy()[-1] * 255)

        writer.add_scalar("Training heatmaps loss", heatmaps_loss.item(), i_batch)
        writer.add_scalar("Training landmarks loss", landmarks_loss.item(), i_batch)
        writer.add_scalar("Training gaze loss", gaze_loss.item(), i_batch)
        writer.add_scalar("Training loss", loss.item(), i_batch)

        if i_batch > 0 and i_batch % 20 == 0:
            val_loss = validate(eyenet=eyenet, val_loader=val_loader)
            writer.add_scalar("validation loss", val_loss, i_batch)
            print('Epoch', epoch, 'Validation loss', val_loss)

            # Checkpoint whenever the validation loss improves.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'nstack': eyenet.nstack,
                    'nfeatures': eyenet.nfeatures,
                    'nlandmarks': eyenet.nlandmarks,
                    'best_val_loss': best_val_loss,
                    'model_state_dict': eyenet.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, checkpoint_fn)

    return best_val_loss
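
# Driver sketch (illustrative): how train_epoch() might be wired up. The
# loader factory, hyperparameters, epoch count, learning rate, and checkpoint
# filename are assumptions for illustration, not values from the original.
def _train_example(make_loader):
    train_loader = make_loader('train')   # hypothetical DataLoader factory
    val_loader = make_loader('val')

    eyenet = EyeNet(nstack=3, nfeatures=32, nlandmarks=34).to(device)
    optimizer = torch.optim.Adam(eyenet.parameters(), lr=1e-4)
    writer = SummaryWriter()

    best_val_loss = float('inf')
    for epoch in range(100):
        best_val_loss = train_epoch(
            epoch=epoch, eyenet=eyenet, optimizer=optimizer,
            train_loader=train_loader, val_loader=val_loader,
            best_val_loss=best_val_loss, checkpoint_fn='checkpoint.pt',
            writer=writer)
    writer.close()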