Beispiel #1
0
def validate(eyenet: EyeNet, val_loader: DataLoader) -> float:
    """Return the mean combined loss of ``eyenet`` over ``val_loader``.

    Every batch is read through the keys ``'img'``, ``'heatmaps'``,
    ``'landmarks'`` and ``'gaze'``.  The per-batch loss is
    ``1000 * heatmaps_loss + landmarks_loss + gaze_loss``, with the three
    terms taken from ``eyenet.calc_loss``.

    The model is put into evaluation mode for the duration of the pass and
    its previous train/eval mode is restored afterwards, so this function can
    be called from the middle of a training loop.  Gradient tracking is
    disabled throughout.

    Args:
        eyenet: Model to evaluate.
        val_loader: Loader yielding the validation batches.

    Returns:
        The arithmetic mean of the per-batch losses as a plain ``float``.
        If the loader yields no batches the result is NaN (the NumPy mean of
        an empty list).
    """
    # Remember the caller's mode: train-time behaviour of layers such as
    # dropout or batch norm must not leak into the validation loss, but the
    # caller may still be in the middle of training.
    was_training = eyenet.training
    eyenet.eval()
    try:
        with torch.no_grad():
            val_losses = []
            for val_batch in val_loader:
                val_imgs = val_batch['img'].float().to(device)
                heatmaps = val_batch['heatmaps'].to(device)
                # Cast to float like the training loop does for landmarks.
                landmarks = val_batch['landmarks'].float().to(device)
                gaze = val_batch['gaze'].float().to(device)
                # Call the module rather than .forward() so hooks still run.
                heatmaps_pred, landmarks_pred, gaze_pred = eyenet(val_imgs)
                heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
                    heatmaps_pred, heatmaps, landmarks_pred, landmarks,
                    gaze_pred, gaze)
                loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss
                val_losses.append(loss.item())
    finally:
        eyenet.train(was_training)
    # Convert the NumPy scalar so the return matches the ``float`` annotation.
    return float(np.mean(val_losses))
Beispiel #2
0
def _heatmap_preview(channels):
    """Average ``channels`` over axis 0 and min-max scale the result to [0, 1]."""
    mean_map = np.mean(channels.cpu().detach().numpy(), axis=0)
    return cv2.normalize(mean_map,
                         None,
                         alpha=0,
                         beta=1,
                         norm_type=cv2.NORM_MINMAX,
                         dtype=cv2.CV_32F)


def train_epoch(epoch: int, eyenet: EyeNet, optimizer,
                train_loader: DataLoader, val_loader: DataLoader,
                best_val_loss: float, checkpoint_fn: str,
                writer: SummaryWriter) -> float:
    """Run one pass over ``train_loader``, validating and checkpointing.

    For every batch the combined loss
    ``1000 * heatmaps_loss + landmarks_loss + gaze_loss`` is back-propagated
    and the optimizer is stepped.  The three loss terms and their combination
    are logged to ``writer`` under a global step that keeps counting across
    epochs (``len(train_loader) * epoch + batch index``).

    Whenever the global step is a multiple of 20:

    * preview images ``true.jpg``, ``pred.jpg`` and ``eye.jpg`` are written
      to the current working directory;
    * except at step 0, ``validate`` is run on ``val_loader``; if its result
      is lower than ``best_val_loss``, a checkpoint holding the model
      hyper-parameters, the new best loss and the model and optimizer state
      dicts is saved to ``checkpoint_fn``.

    Args:
        epoch: Zero-based epoch index, used to offset the global step.
        eyenet: Model being trained.
        optimizer: Optimizer updating ``eyenet``'s parameters.
        train_loader: Loader yielding training batches with the keys
            ``'img'``, ``'heatmaps'``, ``'landmarks'`` and ``'gaze'``.
        val_loader: Loader passed to ``validate``.
        best_val_loss: Lowest validation loss seen so far.
        checkpoint_fn: Path the best checkpoint is written to.
        writer: Summary writer receiving the scalar logs.

    Returns:
        The lowest validation loss seen so far, including this epoch.
    """
    # Offset the step so curves from successive epochs line up end to end.
    first_step = len(train_loader) * epoch
    for step, sample_batched in enumerate(train_loader, start=first_step):
        imgs = sample_batched['img'].float().to(device)
        # Call the module rather than .forward() so hooks still run.
        heatmaps_pred, landmarks_pred, gaze_pred = eyenet(imgs)

        heatmaps = sample_batched['heatmaps'].to(device)
        landmarks = sample_batched['landmarks'].float().to(device)
        gaze = sample_batched['gaze'].float().to(device)

        heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
            heatmaps_pred, heatmaps, landmarks_pred, landmarks, gaze_pred,
            gaze)

        loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        on_interval = step % 20 == 0

        if on_interval:
            # Build the previews only when they are written: this avoids a
            # device-to-host copy and two normalizations on every batch.
            # Channels 8..15 of the last entry are averaged; presumably
            # heatmaps_pred carries an extra leading axis (one per hourglass
            # stack) - TODO confirm against EyeNet.forward.
            norm_hm = _heatmap_preview(heatmaps[-1, 8:16])
            norm_hm_pred = _heatmap_preview(heatmaps_pred[-1, -1, 8:16])
            cv2.imwrite('true.jpg', norm_hm * 255)
            cv2.imwrite('pred.jpg', norm_hm_pred * 255)
            cv2.imwrite('eye.jpg', sample_batched['img'].numpy()[-1] * 255)

        writer.add_scalar("Training heatmaps loss", heatmaps_loss.item(),
                          step)
        writer.add_scalar("Training landmarks loss", landmarks_loss.item(),
                          step)
        writer.add_scalar("Training gaze loss", gaze_loss.item(), step)
        writer.add_scalar("Training loss", loss.item(), step)

        # Step 0 is skipped: nothing has been learned yet.
        if on_interval and step > 0:
            val_loss = validate(eyenet=eyenet, val_loader=val_loader)
            writer.add_scalar("validation loss", val_loss, step)
            print('Epoch', epoch, 'Validation loss', val_loss)
            if val_loss < best_val_loss:
                # Keep a plain Python float so the checkpoint does not embed
                # a NumPy scalar, which restricted (weights-only) loading
                # may refuse to unpickle.
                best_val_loss = float(val_loss)
                torch.save(
                    {
                        'nstack': eyenet.nstack,
                        'nfeatures': eyenet.nfeatures,
                        'nlandmarks': eyenet.nlandmarks,
                        'best_val_loss': best_val_loss,
                        'model_state_dict': eyenet.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, checkpoint_fn)

    return best_val_loss