示例#1
0
def validate(args, val_dataloader, model, auxiliarynet, epoch):
    """Evaluate the single-patch gaze regressor for one epoch.

    ``model(patch)`` returns a pair; only its first element (the gaze
    prediction) is used. The prediction is divided by 10 before scoring, so
    the model appears to be trained on gaze targets scaled by 10.

    Args:
        args: namespace providing ``device`` and ``print_freq``.
        val_dataloader: iterable of ``(patch, gaze_norm_g, head_norm,
            rot_vec_norm)`` batches; ``head_norm`` is not used here.
        model: gaze network, called as ``model(patch)``.
        auxiliarynet: unused; kept so existing callers keep working.
        epoch: epoch number, only used in the progress prefix.

    Returns:
        ``(losses.get_avg(), error.get_avg())``. No loss is computed here (no
        criterion is passed in), so the ``losses`` meter is never updated and
        its average is whatever an empty ``AverageMeter`` reports.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    error = AverageMeter('error', ':6.2f')

    progress = ProgressMeter(len(val_dataloader),
                             batch_time,
                             data_time,
                             losses,
                             error,
                             prefix="Val Epoch: [{}]".format(epoch))

    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (patch, gaze_norm_g, _head_norm,
                rot_vec_norm) in enumerate(val_dataloader):
            # measure data loading time
            data_time.update(time.time() - end)
            patch = patch.to(args.device)

            gaze_pred, _ = model(patch)

            # Undo the x10 output scaling; the ground truth is compared as-is
            # (the former ``10 * gt / 10`` round trip was a no-op).
            angle_error = mean_angle_error(
                gaze_pred.cpu().numpy() / 10,
                gaze_norm_g.cpu().numpy(),
                rot_vec_norm.cpu().numpy())
            # Weight by batch size so the epoch average is a per-sample mean,
            # consistent with the other train/validate loops.
            error.update(angle_error, patch.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % args.print_freq == 0:
                progress.print(i + 1)

    return losses.get_avg(), error.get_avg()
示例#2
0
def train(args, train_loader, model, auxiliarynet, criterion, optimizer,
          epoch):
    """Run one optimisation epoch over the eyes + face gaze dataset.

    Each batch yields ``(eyes, face, gaze_norm_g, head_norm, rot_vec_norm)``.
    The face crop is encoded by ``auxiliarynet`` and the result is passed to
    ``model`` together with the eye crop. Both targets are multiplied by 100
    before the loss. ``criterion`` receives the scaled gaze and head targets
    followed by model output columns ``0:2`` and ``2:4``. The angular error
    is computed on columns ``0:2`` after undoing the x100 scaling.

    Args:
        args: namespace providing ``device`` and ``print_freq``.
        train_loader: batch iterable as described above.
        model: called as ``model(eyes, face_feature)``.
        auxiliarynet: called as ``auxiliarynet(face)``.
        criterion: called as ``criterion(gaze_target, head_target,
            gaze_out, head_out)``.
        optimizer: stepped once per batch.
        epoch: only used in the progress-bar prefix.

    Returns:
        ``(losses.get_avg(), error.get_avg())``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    error = AverageMeter('error', ':6.2f')
    progress = ProgressMeter(
        len(train_loader), batch_time, data_time, losses, error,
        prefix="Train Epoch: [{}]".format(epoch),
    )

    def _stage(tensor):
        # Mark the input as non-trainable, then move it to the target device.
        tensor.requires_grad = False
        return tensor.to(args.device)

    model.train()
    auxiliarynet.train()
    tick = time.time()
    for step, (eyes, face, gaze_gt, head_gt, rot_vec) in enumerate(
            train_loader, start=1):
        data_time.update(time.time() - tick)
        eyes, face, gaze_gt, head_gt, rot_vec = (
            _stage(eyes), _stage(face), _stage(gaze_gt),
            _stage(head_gt), _stage(rot_vec))

        gaze_out = model(eyes, auxiliarynet(face))

        head_target = 100 * head_gt
        gaze_target = 100 * gaze_gt
        loss = criterion(gaze_target, head_target,
                         gaze_out[:, 0:2], gaze_out[:, 2:4])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        angle_error = mean_angle_error(
            gaze_out[:, 0:2].cpu().detach().numpy() / 100,
            gaze_target.cpu().detach().numpy() / 100,
            rot_vec.cpu().detach().numpy())

        n = eyes.size(0)
        losses.update(loss.item(), n)
        error.update(angle_error, n)

        if step % args.print_freq == 0:
            progress.print(step)
    return losses.get_avg(), error.get_avg()
示例#3
0
def train(device, train_loader, model, criterion, optimizer, epoch, args):
    """Train the binned (classification) gaze model for one epoch.

    ``model(patch)`` returns two logit tensors with classes along dim 1, one
    per gaze component. A normalized angle ``g`` is assigned to class
    ``k = int(100 * (g + 0.75))``. A predicted class ``k`` is decoded back to
    the angle ``k / 100 - 0.75``.

    Args:
        device: torch device the model and batches are moved to.
        train_loader: iterable of ``(patch, gaze_norm_g, head_norm,
            rot_vec_norm)`` batches; column 0 of ``gaze_norm_g`` is scored
            against the first head and column 1 against the second.
        model: network returning ``(gaze_p, gaze_y)`` logits.
        criterion: classification loss called as ``criterion(logits, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, only used in the progress prefix.
        args: namespace providing ``print_freq``.

    Returns:
        ``(losses.get_avg(), error.get_avg())`` of the meters.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    error = AverageMeter('error', ':6.2f')

    progress = ProgressMeter(len(train_loader),
                             batch_time,
                             data_time,
                             losses,
                             error,
                             prefix="Train Epoch: [{}]".format(epoch))
    # switch to train mode
    model.train()
    # nn.Module.to() moves parameters in place, so once per epoch is enough
    # (this used to be repeated for every batch).
    model = model.to(device)

    end = time.time()
    for batch_idx, (patch, gaze_norm_g, head_norm,
                    rot_vec_norm) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        patch.requires_grad = False
        patch = patch.to(device)

        gaze_norm_g.requires_grad = False
        gaze_norm_g = gaze_norm_g.to(device)

        head_norm.requires_grad = False
        head_norm = head_norm.to(device)

        rot_vec_norm.requires_grad = False
        rot_vec_norm = rot_vec_norm.to(device)

        gaze_p, gaze_y = model(patch)

        # Bin the ground truth into class indices. A classification loss
        # needs targets in [0, num_classes - 1]. The former clamp to
        # [-74, 75] left angles below -0.75 with negative (invalid) targets
        # and collapsed every angle >= 0 into class 75.
        gaze_bins = (100 * (gaze_norm_g + 0.75)).long()
        target_p = torch.clamp(gaze_bins[:, 0], 0, gaze_p.size(1) - 1)
        target_y = torch.clamp(gaze_bins[:, 1], 0, gaze_y.size(1) - 1)

        loss = criterion(gaze_p, target_p) + criterion(gaze_y, target_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Arg-max class of each head, stacked as (batch, 2), decoded to angles.
        gaze_pred = torch.stack(
            (gaze_p.argmax(dim=1), gaze_y.argmax(dim=1)), dim=1)
        angle_error = mean_angle_error(
            gaze_pred.cpu().numpy().astype(float) / 100 - 0.75,
            gaze_norm_g.cpu().detach().numpy(),
            rot_vec_norm.cpu().detach().numpy())

        losses.update(loss.item(), patch.size(0))
        error.update(angle_error, patch.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            progress.print(batch_idx + 1)
    return losses.get_avg(), error.get_avg()
示例#4
0
def validate(device, val_dataloader, model, criterion, epoch, args):
    """Evaluate the binned (classification) gaze model for one epoch.

    Uses the same binning as the matching training loop: a normalized angle
    ``g`` maps to class ``k = int(100 * (g + 0.75))``, and a predicted class
    ``k`` is decoded back to the angle ``k / 100 - 0.75``.

    Args:
        device: torch device the model and batches are moved to.
        val_dataloader: iterable of ``(patch, gaze_norm_g, head_norm,
            rot_vec_norm)`` batches.
        model: network returning ``(gaze_p, gaze_y)`` logits with classes
            along dim 1.
        criterion: classification loss called as ``criterion(logits, target)``.
        epoch: epoch number, only used in the progress prefix.
        args: namespace providing ``print_freq``.

    Returns:
        ``(losses.get_avg(), error.get_avg())`` of the meters.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    error = AverageMeter('error', ':6.2f')

    progress = ProgressMeter(len(val_dataloader),
                             batch_time,
                             data_time,
                             losses,
                             error,
                             prefix="Val Epoch: [{}]".format(epoch))

    model.eval()
    # nn.Module.to() moves parameters in place; no need to repeat per batch.
    model = model.to(device)

    with torch.no_grad():
        end = time.time()
        for i, (patch, gaze_norm_g, head_norm,
                rot_vec_norm) in enumerate(val_dataloader):
            # measure data loading time
            data_time.update(time.time() - end)
            patch = patch.to(device)
            gaze_norm_g = gaze_norm_g.to(device)

            head_norm = head_norm.to(device)

            rot_vec_norm = rot_vec_norm.to(device)

            gaze_p, gaze_y = model(patch)

            # Targets must be valid class indices in [0, num_classes - 1].
            # The former clamp to [-74, 75] allowed negative targets and
            # merged every angle >= 0 into class 75.
            gaze_bins = (100 * (gaze_norm_g + 0.75)).long()
            target_p = torch.clamp(gaze_bins[:, 0], 0, gaze_p.size(1) - 1)
            target_y = torch.clamp(gaze_bins[:, 1], 0, gaze_y.size(1) - 1)

            loss = criterion(gaze_p, target_p) + criterion(gaze_y, target_y)

            # Arg-max class of each head, stacked as (batch, 2), decoded to
            # angles.
            gaze_pred = torch.stack(
                (gaze_p.argmax(dim=1), gaze_y.argmax(dim=1)), dim=1)
            angle_error = mean_angle_error(
                gaze_pred.cpu().numpy().astype(float) / 100 - 0.75,
                gaze_norm_g.cpu().numpy(),
                rot_vec_norm.cpu().numpy())
            losses.update(loss.item(), patch.size(0))
            error.update(angle_error, patch.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % args.print_freq == 0:
                progress.print(i + 1)

    return losses.get_avg(), error.get_avg()