Example #1
class JitterLoss(nn.Module):
    """
    Compute loss by computing distances between
    (1) grid points transformed by ground-truth theta
    (2) grid points transformed by predicted theta_tr and theta_st
    """
    def __init__(self, use_cuda=True, grid_size=20):
        super(JitterLoss, self).__init__()
        # define virtual grid of points to be transformed
        axis_coords = np.linspace(-1, 1, grid_size)
        self.N = grid_size * grid_size
        # X and Y.shape: (20, 20)
        X, Y = np.meshgrid(axis_coords, axis_coords)
        # X and Y.shape: (1, 1, 400), P.shape: (1, 2, 400)
        X = np.reshape(X, (1, 1, self.N))
        Y = np.reshape(Y, (1, 1, self.N))
        P = np.concatenate((X, Y), 1)
        # self.P = Variable(torch.FloatTensor(P), requires_grad=False)
        self.P = torch.Tensor(P.astype(np.float32))
        self.P.requires_grad = False
        self.pointTnf = PointTnf(use_cuda=use_cuda)
        # self.theta_norm = torch.FloatTensor([-1, -1, -1, 0, 0, 0, 1, 1, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1]).view(1, -1)
        if use_cuda:
            self.P = self.P.cuda()
            # self.theta_norm = self.theta_norm.cuda()

    def forward(self, theta_st, theta_tr):
        # expand grid according to batch size
        # theta_st.shape & theta_tr.shape: (batch_size, 36) for tps
        batch_size = theta_st.size()[0]
        # P.shape: (batch_size, 2, 400)
        P = self.P.expand(batch_size, 2, self.N)
        # theta_norm = self.theta_norm.expand(batch_size, 18)

        P_prime_st = self.pointTnf.tpsPointTnf(
            theta_st[:, 18:].unsqueeze(2).unsqueeze(3), P)
        P_prime_tr = self.pointTnf.tpsPointTnf(
            theta_tr[:, 18:].unsqueeze(2).unsqueeze(3), P)

        loss_st = torch.sum(torch.pow(P_prime_st - P, 2), 1)
        loss_st = torch.mean(loss_st)

        loss_tr = torch.sum(torch.pow(P_prime_tr - P, 2), 1)
        loss_tr = torch.mean(loss_tr)

        loss = (loss_st + loss_tr) / 2
        return loss
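
A minimal usage sketch, not from the original repo: it assumes PointTnf is importable (e.g. from a geotnf.point_tnf module) and that theta_st / theta_tr are stacked (batch_size, 36) parameter vectors of which only the trailing 18 TPS values are consumed.

import torch

criterion = JitterLoss(use_cuda=False, grid_size=20)
theta_st = torch.randn(8, 36, requires_grad=True)  # hypothetical predictions
theta_tr = torch.randn(8, 36, requires_grad=True)
loss = criterion(theta_st, theta_tr)
loss.backward()  # gradients flow back into both parameter tensors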
Example #2
    def __init__(self, use_cuda=True, grid_size=20):
        super(TransformedGridLoss, self).__init__()
        # define virtual grid of points to be transformed
        axis_coords = np.linspace(-1, 1, grid_size)
        self.N = grid_size * grid_size
        # X and Y.shape: (20, 20)
        X, Y = np.meshgrid(axis_coords, axis_coords)
        # X and Y.shape: (1, 1, 400), P.shape: (1, 2, 400)
        X = np.reshape(X, (1, 1, self.N))
        Y = np.reshape(Y, (1, 1, self.N))
        P = np.concatenate((X, Y), 1)
        # self.P = Variable(torch.FloatTensor(P), requires_grad=False)
        self.P = torch.Tensor(P.astype(np.float32))
        self.P.requires_grad = False
        self.pointTnf = PointTnf(use_cuda=use_cuda)
        if use_cuda:
            self.P = self.P.cuda()
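
Both loss classes build the same fixed evaluation grid; a standalone shape check of that construction (grid_size=20 gives N=400):

import numpy as np

ax = np.linspace(-1, 1, 20)
X, Y = np.meshgrid(ax, ax)                # (20, 20) each
P = np.concatenate((X.reshape(1, 1, -1),  # (1, 1, 400)
                    Y.reshape(1, 1, -1)), 1)
print(P.shape)                            # (1, 2, 400)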
Example #3
class TransformedGridLoss(nn.Module):
    """
    Compute loss by computing distances between
    (1) grid points transformed by ground-truth theta
    (2) grid points transformed by predicted theta_aff and theta_aff_tps
    """
    def __init__(self, use_cuda=True, grid_size=20):
        super(TransformedGridLoss, self).__init__()
        # define virtual grid of points to be transformed
        axis_coords = np.linspace(-1, 1, grid_size)
        self.N = grid_size * grid_size
        # X and Y.shape: (20, 20)
        X, Y = np.meshgrid(axis_coords, axis_coords)
        # X and Y.shape: (1, 1, 400), P.shape: (1, 2, 400)
        X = np.reshape(X, (1, 1, self.N))
        Y = np.reshape(Y, (1, 1, self.N))
        P = np.concatenate((X, Y), 1)
        # self.P = Variable(torch.FloatTensor(P), requires_grad=False)
        self.P = torch.Tensor(P.astype(np.float32))
        self.P.requires_grad = False
        self.pointTnf = PointTnf(use_cuda=use_cuda)
        if use_cuda:
            self.P = self.P.cuda()

    def forward(self, theta_aff_tps, theta_aff, theta_GT):
        # expand grid according to batch size
        # theta.shape: (batch_size, 18) for tps, (batch_size, 6) for affine
        # theta_GT.shape: (batch_size, 18, 1, 1) for tps, (batch_size, 6) for affine
        batch_size = theta_aff_tps.size()[0]
        # P.shape: (batch_size, 2, 400)
        P = self.P.expand(batch_size, 2, self.N)
        # compute transformed grid points using estimated and GT tnfs
        # P_prime and P_prime_GT.shape: (batch_size, 2, 400)
        P_prime = self.pointTnf.tpsPointTnf(
            theta_aff_tps.unsqueeze(2).unsqueeze(3), P)
        P_prime = self.pointTnf.affPointTnf(theta_aff, P_prime)
        P_prime_GT = self.pointTnf.tpsPointTnf(theta_GT, P)
        # compute MSE loss on transformed grid points
        loss = torch.sum(torch.pow(P_prime - P_prime_GT, 2), 1)
        loss = torch.mean(loss)
        return loss
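
The reduction above is a per-point squared L2 distance (sum over the coordinate dimension) followed by a mean over batch and points; a self-contained numeric check:

import torch

P_prime = torch.full((4, 2, 400), 0.1)  # every point shifted by (0.1, 0.1)
P_prime_GT = torch.zeros(4, 2, 400)
loss = torch.mean(torch.sum(torch.pow(P_prime - P_prime_GT, 2), 1))
print(loss)  # tensor(0.0200) = 0.1 ** 2 * 2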
Example #4
def flow_metrics(batch, batch_start_idx, theta_det, theta_aff, theta_tps,
                 theta_afftps, results, args):
    result_path = args.flow_output_dir

    do_det = theta_det is not None
    do_aff = theta_aff is not None
    do_tps = theta_tps is not None
    do_aff_tps = theta_afftps is not None

    pt = PointTnf(use_cuda=args.cuda)

    batch_size = batch['source_im_info'].size(0)
    for b in range(batch_size):
        # Get H, W of source and target image
        h_src = int(batch['source_im_info'][b, 0].cpu().numpy())
        w_src = int(batch['source_im_info'][b, 1].cpu().numpy())
        h_tgt = int(batch['target_im_info'][b, 0].cpu().numpy())
        w_tgt = int(batch['target_im_info'][b, 1].cpu().numpy())

        # Generate grid for warping
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, w_tgt),
                                     np.linspace(-1, 1, h_tgt))
        # grid_X, grid_Y.shape: (1, h_tgt, w_tgt, 1)
        grid_X = torch.Tensor(grid_X).unsqueeze(0).unsqueeze(3)
        grid_Y = torch.Tensor(grid_Y).unsqueeze(0).unsqueeze(3)
        grid_X.requires_grad = False
        grid_Y.requires_grad = False
        if args.cuda:
            grid_X = grid_X.cuda()
            grid_Y = grid_Y.cuda()
        # Reshape to vector, grid_X_vec, grid_Y_vec.shape: (1, 1, h_tgt * w_tgt)
        grid_X_vec = grid_X.view(1, 1, -1)
        grid_Y_vec = grid_Y.view(1, 1, -1)
        # grid_XY_vec.shape: (1, 2, h_tgt * w_tgt)
        grid_XY_vec = torch.cat((grid_X_vec, grid_Y_vec), 1)

        # Transform vector of points to grid
        def pointsToGrid(x, h_tgt=h_tgt, w_tgt=w_tgt):
            return x.contiguous().view(1, 2, h_tgt, w_tgt).permute(0, 2, 3, 1)

        idx = batch_start_idx + b

        if do_det:
            grid_det = pointsToGrid(
                pt.affPointTnf(theta=theta_det[b, :].unsqueeze(0),
                               points=grid_XY_vec))
            flow_det = th_sampling_grid_to_np_flow(source_grid=grid_det,
                                                   h_src=h_src,
                                                   w_src=w_src)
            flow_det_path = os.path.join(result_path, 'det',
                                         batch['flow_path'][b])
            # create_file_path(flow_det_path)
            # write_flo_file(flow_det, flow_det_path)

        if do_aff:
            if do_det:
                key = 'det_aff'
                grid_aff = pointsToGrid(
                    pt.affPointTnf(theta=theta_det[b, :].unsqueeze(0),
                                   points=pt.affPointTnf(
                                       theta=theta_aff[b, :].unsqueeze(0),
                                       points=grid_XY_vec)))
            else:
                key = 'aff'
                grid_aff = pointsToGrid(
                    pt.affPointTnf(theta=theta_aff[b, :].unsqueeze(0),
                                   points=grid_XY_vec))
            flow_aff = th_sampling_grid_to_np_flow(source_grid=grid_aff,
                                                   h_src=h_src,
                                                   w_src=w_src)
            flow_aff_path = os.path.join(result_path, key,
                                         batch['flow_path'][b])
            # create_file_path(flow_aff_path)
            # write_flo_file(flow_aff, flow_aff_path)

        if do_tps:
            # vis = Visdom()
            # flow_gt = batch['flow_gt'][b]
            # vis.heatmap(flow_gt.cpu().numpy()[:, ::-1, 0])
            # vis.heatmap(flow_gt.cpu().numpy()[:, ::-1, 1])
            # flow_gt_img = visualize_flow(flow_gt.cpu().numpy())
            # vis.image(flow_gt_img.transpose((2, 0, 1)))
            # vis.mesh(grid_XY_vec.squeeze().transpose(0, 1), opts=dict(opacity=0.3))
            # vis.mesh(pt.tpsPointTnf(theta=theta_tps[b, :].unsqueeze(0), points=grid_XY_vec).squeeze().transpose(0, 1), opts=dict(opacity=0.3))
            # grid_XY = pointsToGrid(grid_XY_vec).squeeze(0)
            # in_bound_mask = (grid_XY[:, :, 0] > -1) & (grid_XY[:, :, 0] < 1) & (grid_XY[:, :, 1] > -1) & (grid_XY[:, :, 1] < 1)
            # vis.heatmap(in_bound_mask)

            # Get sampling grid with predicted TPS parameters, grid_tps.shape: (1, h_tgt, w_tgt, 2)
            grid_tps = pointsToGrid(
                pt.tpsPointTnf(theta=theta_tps[b, :].unsqueeze(0),
                               points=grid_XY_vec))
            # Transform sampling grid to flow
            flow_tps = th_sampling_grid_to_np_flow(source_grid=grid_tps,
                                                   h_src=h_src,
                                                   w_src=w_src)
            flow_tps_path = os.path.join(result_path, 'tps',
                                         batch['flow_path'][b])
            # create_file_path(flow_tps_path)
            # write_flo_file(flow_tps, flow_tps_path)

        if do_aff_tps:
            if do_det:
                key = 'det_aff_tps'
                grid_aff_tps = pointsToGrid(
                    pt.affPointTnf(
                        theta=theta_det[b, :].unsqueeze(0),
                        points=pt.affPointTnf(
                            theta=theta_aff[b, :].unsqueeze(0),
                            points=pt.tpsPointTnf(
                                theta=theta_afftps[b, :].unsqueeze(0),
                                points=grid_XY_vec))))
            else:
                key = 'afftps'
                grid_aff_tps = pointsToGrid(
                    pt.affPointTnf(theta=theta_aff[b, :].unsqueeze(0),
                                   points=pt.tpsPointTnf(
                                       theta=theta_afftps[b, :].unsqueeze(0),
                                       points=grid_XY_vec)))
            flow_aff_tps = th_sampling_grid_to_np_flow(
                source_grid=grid_aff_tps, h_src=h_src, w_src=w_src)
            flow_aff_tps_path = os.path.join(result_path, key,
                                             batch['flow_path'][b])
            # create_file_path(flow_aff_tps_path)
            # write_flo_file(flow_aff_tps, flow_aff_tps_path)

        idx = batch_start_idx + b
    return results
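
pointsToGrid only rearranges memory: a (1, 2, h_tgt * w_tgt) point vector becomes the (1, h_tgt, w_tgt, 2) layout that F.grid_sample and th_sampling_grid_to_np_flow expect. A quick shape check:

import torch

h_tgt, w_tgt = 4, 5
pts = torch.randn(1, 2, h_tgt * w_tgt)
grid = pts.contiguous().view(1, 2, h_tgt, w_tgt).permute(0, 2, 3, 1)
print(grid.shape)  # torch.Size([1, 4, 5, 2])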
Example #5
def area_metrics(batch, batch_start_idx, theta_det, theta_aff, theta_tps,
                 theta_afftps, results, args):
    do_det = theta_det is not None
    do_aff = theta_aff is not None
    do_tps = theta_tps is not None
    do_aff_tps = theta_afftps is not None

    batch_size = batch['source_im_info'].size(0)

    pt = PointTnf(use_cuda=args.cuda)

    for b in range(batch_size):
        # Get H, W of source and target image
        h_src = int(batch['source_im_info'][b, 0].cpu().numpy())
        w_src = int(batch['source_im_info'][b, 1].cpu().numpy())
        h_tgt = int(batch['target_im_info'][b, 0].cpu().numpy())
        w_tgt = int(batch['target_im_info'][b, 1].cpu().numpy())

        # Transform annotated polygon to mask using given coordinates of key points
        # target_mask_np.shape: (h_tgt, w_tgt), target_mask.shape: (1, 1, h_tgt, w_tgt)
        target_mask_np, target_mask = poly_str_to_mask(
            poly_x_str=batch['target_polygon'][0][b],
            poly_y_str=batch['target_polygon'][1][b],
            out_h=h_tgt,
            out_w=w_tgt,
            use_cuda=args.cuda)
        source_mask_np, source_mask = poly_str_to_mask(
            poly_x_str=batch['source_polygon'][0][b],
            poly_y_str=batch['source_polygon'][1][b],
            out_h=h_src,
            out_w=w_src,
            use_cuda=args.cuda)

        # Generate grid for warping
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, w_tgt),
                                     np.linspace(-1, 1, h_tgt))
        # grid_X, grid_Y.shape: (1, h_tgt, w_tgt, 1)
        grid_X = torch.Tensor(grid_X.astype(
            np.float32)).unsqueeze(0).unsqueeze(3)
        grid_Y = torch.Tensor(grid_Y.astype(
            np.float32)).unsqueeze(0).unsqueeze(3)
        grid_X.requires_grad = False
        grid_Y.requires_grad = False
        if args.cuda:
            grid_X = grid_X.cuda()
            grid_Y = grid_Y.cuda()
        # Reshape to vector, grid_X_vec, grid_Y_vec.shape: (1, 1, h_tgt * w_tgt)
        grid_X_vec = grid_X.view(1, 1, -1)
        grid_Y_vec = grid_Y.view(1, 1, -1)
        # grid_XY_vec.shape: (1, 2, h_tgt * w_tgt)
        grid_XY_vec = torch.cat((grid_X_vec, grid_Y_vec), 1)

        # Transform vector of points to grid
        def pointsToGrid(x, h_tgt=h_tgt, w_tgt=w_tgt):
            return x.contiguous().view(1, 2, h_tgt, w_tgt).permute(0, 2, 3, 1)

        idx = batch_start_idx + b

        if do_det:
            grid_det = pointsToGrid(
                pt.affPointTnf(theta=theta_det[b, :].unsqueeze(0),
                               points=grid_XY_vec))
            warped_mask_det = F.grid_sample(source_mask, grid_det)
            flow_det = th_sampling_grid_to_np_flow(source_grid=grid_det,
                                                   h_src=h_src,
                                                   w_src=w_src)

            results['det']['intersection_over_union'][
                idx] = intersection_over_union(warped_mask=warped_mask_det,
                                               target_mask=target_mask)
            results['det']['label_transfer_accuracy'][
                idx] = label_transfer_accuracy(warped_mask=warped_mask_det,
                                               target_mask=target_mask)
            results['det']['localization_error'][idx] = localization_error(
                source_mask_np=source_mask_np,
                target_mask_np=target_mask_np,
                flow_np=flow_det)

        if do_aff:
            if do_det:
                key = 'det_aff'
                grid_aff = pointsToGrid(
                    pt.affPointTnf(theta=theta_det[b, :].unsqueeze(0),
                                   points=pt.affPointTnf(
                                       theta=theta_aff[b, :].unsqueeze(0),
                                       points=grid_XY_vec)))
            else:
                key = 'aff'
                grid_aff = pointsToGrid(
                    pt.affPointTnf(theta=theta_aff[b, :].unsqueeze(0),
                                   points=grid_XY_vec))
            warped_mask_aff = F.grid_sample(source_mask, grid_aff)
            flow_aff = th_sampling_grid_to_np_flow(source_grid=grid_aff,
                                                   h_src=h_src,
                                                   w_src=w_src)

            results[key]['intersection_over_union'][
                idx] = intersection_over_union(warped_mask=warped_mask_aff,
                                               target_mask=target_mask)
            results[key]['label_transfer_accuracy'][
                idx] = label_transfer_accuracy(warped_mask=warped_mask_aff,
                                               target_mask=target_mask)
            results[key]['localization_error'][idx] = localization_error(
                source_mask_np=source_mask_np,
                target_mask_np=target_mask_np,
                flow_np=flow_aff)

        if do_tps:
            # Get sampling grid with predicted TPS parameters, grid_tps.shape: (1, h_tgt, w_tgt, 2)
            grid_tps = pointsToGrid(
                pt.tpsPointTnf(theta=theta_tps[b, :].unsqueeze(0),
                               points=grid_XY_vec))
            warped_mask_tps = F.grid_sample(
                source_mask, grid_tps)  # Sampling source_mask with warped grid
            # Transform sampling grid to flow
            flow_tps = th_sampling_grid_to_np_flow(source_grid=grid_tps,
                                                   h_src=h_src,
                                                   w_src=w_src)

            results['tps']['intersection_over_union'][
                idx] = intersection_over_union(warped_mask=warped_mask_tps,
                                               target_mask=target_mask)
            results['tps']['label_transfer_accuracy'][
                idx] = label_transfer_accuracy(warped_mask=warped_mask_tps,
                                               target_mask=target_mask)
            results['tps']['localization_error'][idx] = localization_error(
                source_mask_np=source_mask_np,
                target_mask_np=target_mask_np,
                flow_np=flow_tps)

        if do_aff_tps:
            if do_det:
                key = 'det_aff_tps'
                grid_aff_tps = pointsToGrid(
                    pt.affPointTnf(
                        theta=theta_det[b, :].unsqueeze(0),
                        points=pt.affPointTnf(
                            theta=theta_aff[b, :].unsqueeze(0),
                            points=pt.tpsPointTnf(
                                theta=theta_afftps[b, :].unsqueeze(0),
                                points=grid_XY_vec))))
            else:
                key = 'afftps'
                grid_aff_tps = pointsToGrid(
                    pt.affPointTnf(theta=theta_aff[b, :].unsqueeze(0),
                                   points=pt.tpsPointTnf(
                                       theta=theta_afftps[b, :].unsqueeze(0),
                                       points=grid_XY_vec)))
            warped_mask_aff_tps = F.grid_sample(source_mask, grid_aff_tps)
            flow_aff_tps = th_sampling_grid_to_np_flow(
                source_grid=grid_aff_tps, h_src=h_src, w_src=w_src)

            results[key]['intersection_over_union'][
                idx] = intersection_over_union(warped_mask=warped_mask_aff_tps,
                                               target_mask=target_mask)
            results[key]['label_transfer_accuracy'][
                idx] = label_transfer_accuracy(warped_mask=warped_mask_aff_tps,
                                               target_mask=target_mask)
            results[key]['localization_error'][idx] = localization_error(
                source_mask_np=source_mask_np,
                target_mask_np=target_mask_np,
                flow_np=flow_aff_tps)

    return results
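
intersection_over_union is not shown in these examples; a plain-tensor sketch of the standard computation it is assumed to perform on binary masks:

import torch

def iou_sketch(warped_mask, target_mask, eps=1e-8):
    w = (warped_mask > 0.5).float()
    t = (target_mask > 0.5).float()
    inter = (w * t).sum()
    union = ((w + t) > 0).float().sum()
    return (inter / (union + eps)).item()

a = torch.zeros(1, 1, 4, 4); a[..., :2, :] = 1   # top two rows
b = torch.zeros(1, 1, 4, 4); b[..., 1:3, :] = 1  # middle two rows
print(iou_sketch(a, b))  # 4 / 12 ≈ 0.333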
Example #6
def pck_metric(batch, batch_start_idx, theta_det, theta_aff, theta_tps,
               theta_afftps, results, args):
    alpha = args.pck_alpha
    do_det = theta_det is not None
    do_aff = theta_aff is not None
    do_tps = theta_tps is not None
    do_aff_tps = theta_afftps is not None

    source_im_size = batch['source_im_info'][:, 0:3]
    target_im_size = batch['target_im_info'][:, 0:3]

    source_points = batch['source_points']
    target_points = batch['target_points']

    # Instantiate point transformer
    pt = PointTnf(use_cuda=args.cuda, tps_reg_factor=args.tps_reg_factor)
    # pt = PointTnf(use_cuda=args.cuda)

    # warp points with estimated transformations
    target_points_norm = PointsToUnitCoords(P=target_points,
                                            im_size=target_im_size)

    if do_det:
        # Affine transformation only based on object detection
        warped_points_det_norm = pt.affPointTnf(theta=theta_det,
                                                points=target_points_norm)
        warped_points_det = PointsToPixelCoords(P=warped_points_det_norm,
                                                im_size=source_im_size)

    if do_aff:
        # do affine only
        warped_points_aff_norm = pt.affPointTnf(theta=theta_aff,
                                                points=target_points_norm)
        if do_det:
            warped_points_aff_norm = pt.affPointTnf(
                theta=theta_det, points=warped_points_aff_norm)
        warped_points_aff = PointsToPixelCoords(P=warped_points_aff_norm,
                                                im_size=source_im_size)

    if do_tps:
        # do tps only
        warped_points_tps_norm = pt.tpsPointTnf(theta=theta_tps,
                                                points=target_points_norm)
        warped_points_tps = PointsToPixelCoords(P=warped_points_tps_norm,
                                                im_size=source_im_size)

    if do_aff_tps:
        # do tps+affine
        warped_points_aff_tps_norm = pt.tpsPointTnf(theta=theta_afftps,
                                                    points=target_points_norm)
        warped_points_aff_tps_norm = pt.affPointTnf(
            theta=theta_aff, points=warped_points_aff_tps_norm)
        if do_det:
            warped_points_aff_tps_norm = pt.affPointTnf(
                theta=theta_det, points=warped_points_aff_tps_norm)
        warped_points_aff_tps = PointsToPixelCoords(
            P=warped_points_aff_tps_norm, im_size=source_im_size)

    L_pck = batch['L_pck']

    current_batch_size = batch['source_im_info'].size(0)
    indices = range(batch_start_idx, batch_start_idx + current_batch_size)

    # import pdb; pdb.set_trace()
    if do_det:
        pck_det = pck(source_points, warped_points_det, L_pck, alpha)

    if do_aff:
        pck_aff = pck(source_points, warped_points_aff, L_pck, alpha)

    if do_tps:
        pck_tps = pck(source_points, warped_points_tps, L_pck, alpha)

    if do_aff_tps:
        pck_aff_tps = pck(source_points, warped_points_aff_tps, L_pck, alpha)

    if do_det:
        results['det']['pck'][indices] = pck_det.unsqueeze(1).cpu().numpy()
    if do_aff:
        if do_det:
            key = 'det_aff'
        else:
            key = 'aff'
        results[key]['pck'][indices] = pck_aff.unsqueeze(1).cpu().numpy()
    if do_tps:
        results['tps']['pck'][indices] = pck_tps.unsqueeze(1).cpu().numpy()
    if do_aff_tps:
        if do_det:
            key = 'det_aff_tps'
        else:
            key = 'afftps'
        results[key]['pck'][indices] = pck_aff_tps.unsqueeze(1).cpu().numpy()

    return results
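
The pck helper used here is assumed to implement the usual criterion: a keypoint counts as correct when the warped point lies within alpha * L_pck of its annotation. A minimal per-image sketch (ignoring keypoint padding):

import torch

def pck_sketch(source_points, warped_points, L_pck, alpha=0.1):
    # points: (B, 2, N); L_pck: (B,) per-image reference length
    dist = torch.norm(source_points - warped_points, dim=1)  # (B, N)
    correct = dist <= alpha * L_pck.view(-1, 1)
    return correct.float().mean(dim=1)  # (B,) PCK per image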
Example #7
def vis_pf(vis,
           dataloader,
           theta,
           theta_weak,
           theta_inver,
           theta_weak_inver,
           results,
           results_weak,
           dataset_name,
           use_cuda=True):
    # Visualize watch images
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    pt = PointTnf(use_cuda=use_cuda)

    watch_images = torch.ones(len(dataloader) * 6, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * 6, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * 6).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_aff = theta['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps = theta['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff = theta_weak['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps = theta_weak['aff_tps'][batch_idx].unsqueeze(0)

        theta_aff_inver = theta_inver['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps_inver = theta_inver['aff_tps'][batch_idx].unsqueeze(0)
        theta_weak_aff_inver = theta_weak_inver['aff'][batch_idx].unsqueeze(0)
        theta_weak_aff_tps_inver = theta_weak_inver['aff_tps'][
            batch_idx].unsqueeze(0)

        # Warped image
        warped_aff = affTnf(batch['source_image'], theta_aff)
        warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps)
        warped_weak_aff = affTnf(batch['source_image'], theta_weak_aff)
        warped_weak_aff_tps = tpsTnf(warped_weak_aff, theta_weak_aff_tps)

        watch_images[batch_idx * 6, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * 6 + 1, :, 0:240, :] = warped_aff
        watch_images[batch_idx * 6 + 2, :, 0:240, :] = warped_aff_tps
        watch_images[batch_idx * 6 + 3, :, 0:240, :] = batch['target_image']
        watch_images[batch_idx * 6 + 4, :, 0:240, :] = warped_weak_aff
        watch_images[batch_idx * 6 + 5, :, 0:240, :] = warped_weak_aff_tps

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points,
                                                im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points,
                                                im_size=target_im_size)

        warped_points_aff_norm = pt.affPointTnf(theta=theta_aff_inver,
                                                points=source_points_norm)
        warped_points_aff = PointsToPixelCoords(P=warped_points_aff_norm,
                                                im_size=target_im_size)
        pck_aff, index_aff, N_pts = pck(target_points, warped_points_aff,
                                        dataset_name)
        warped_points_aff = relocate(warped_points_aff, target_im_size)

        warped_points_aff_tps_norm = pt.tpsPointTnf(theta=theta_aff_tps_inver,
                                                    points=source_points_norm)
        warped_points_aff_tps_norm = pt.affPointTnf(
            theta=theta_aff_inver, points=warped_points_aff_tps_norm)
        warped_points_aff_tps = PointsToPixelCoords(
            P=warped_points_aff_tps_norm, im_size=target_im_size)
        pck_aff_tps, index_aff_tps, _ = pck(target_points,
                                            warped_points_aff_tps,
                                            dataset_name)
        warped_points_aff_tps = relocate(warped_points_aff_tps, target_im_size)

        warped_points_weak_aff_norm = pt.affPointTnf(
            theta=theta_weak_aff_inver, points=source_points_norm)
        warped_points_weak_aff = PointsToPixelCoords(
            P=warped_points_weak_aff_norm, im_size=target_im_size)
        pck_weak_aff, index_weak_aff, _ = pck(target_points,
                                              warped_points_weak_aff,
                                              dataset_name)
        warped_points_weak_aff = relocate(warped_points_weak_aff,
                                          target_im_size)

        warped_points_weak_aff_tps_norm = pt.tpsPointTnf(
            theta=theta_weak_aff_tps_inver, points=source_points_norm)
        warped_points_weak_aff_tps_norm = pt.affPointTnf(
            theta=theta_weak_aff_inver, points=warped_points_weak_aff_tps_norm)
        warped_points_weak_aff_tps = PointsToPixelCoords(
            P=warped_points_weak_aff_tps_norm, im_size=target_im_size)
        pck_weak_aff_tps, index_weak_aff_tps, _ = pck(
            target_points, warped_points_weak_aff_tps, dataset_name)
        warped_points_weak_aff_tps = relocate(warped_points_weak_aff_tps,
                                              target_im_size)

        watch_keypoints[batch_idx * 6, :, :N_pts] = relocate(
            batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 +
                        1, :, :N_pts] = warped_points_aff[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 +
                        2, :, :N_pts] = warped_points_aff_tps[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 + 3, :, :N_pts] = relocate(
            batch['target_points'], target_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * 6 +
                        4, :, :N_pts] = warped_points_weak_aff[:, :, :N_pts]
        watch_keypoints[
            batch_idx * 6 +
            5, :, :N_pts] = warped_points_weak_aff_tps[:, :, :N_pts]

        num_points[batch_idx * 6:batch_idx * 6 + 6] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_aff)
        correct_index.append(index_aff_tps)
        correct_index.append(np.arange(N_pts))
        correct_index.append(index_weak_aff)
        correct_index.append(index_weak_aff_tps)

        image_names.append('Source')
        image_names.append('Aff')
        image_names.append('Aff_tps')
        image_names.append('Target')
        image_names.append('Rocco_aff')
        image_names.append('Rocco_aff_tps')

        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_aff))
        metrics.append('PCK: {:.2%}'.format(pck_aff_tps))
        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_weak_aff))
        metrics.append('PCK: {:.2%}'.format(pck_weak_aff_tps))

    opts = dict(jpgquality=100, title=dataset_name)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                0:240, :],
                                                   forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % 6 == 1 or (i + 1) % 6 == 4:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5,
                    (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0),
                    1)
        if (i + 1) % 6 == 4:
            for j in range(num_points[i]):
                cv2.drawMarker(
                    watch_images[i],
                    (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]),
                    colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(
                    watch_images[i],
                    (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]),
                    colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i],
                               (watch_keypoints[i + 3 - (i % 6), 0, j],
                                watch_keypoints[i + 3 - (i % 6), 1, j]),
                               colors[j], cv2.MARKER_DIAMOND, 12, 2,
                               cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=3, padding=3),
              opts=opts)
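
The color loop above converts matplotlib's RGBA tab20 entries (floats in [0, 1]) to OpenCV-style BGR tuples in [0, 255]; a compact equivalent:

import matplotlib.pyplot as plt

cmap = plt.get_cmap('tab20')
colors = [tuple(cmap(c)[i] * 255 for i in (2, 1, 0)) for c in range(20)]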
Example #8
def vis_pf(vis,
           dataloader,
           theta_1,
           theta_2,
           theta_inver_1,
           theta_inver_2,
           results_1,
           results_2,
           dataset_name,
           use_cuda=True):
    # Visualize watch images
    tpsTnf_1 = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    tpsTnf_2 = GeometricTnf2(geometric_model='tps', use_cuda=use_cuda)
    pt_1 = PointTnf(use_cuda=use_cuda)
    pt_2 = PointTPS(use_cuda=use_cuda)

    group_size = 4
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * group_size, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * group_size).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        # Theta and theta_inver
        theta_tps_1 = theta_1['tps'][batch_idx].unsqueeze(0)
        theta_tps_2 = theta_2['tps'][batch_idx].unsqueeze(0)

        thetai_tps_1 = theta_inver_1['tps'][batch_idx].unsqueeze(0)
        thetai_tps_2 = theta_inver_2['tps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_tps_1 = tpsTnf_1(batch['source_image'], theta_tps_1)
        warped_tps_2 = tpsTnf_2(batch['source_image'], theta_tps_2)

        watch_images[batch_idx * group_size, :,
                     0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 0:240, :] = warped_tps_1
        watch_images[batch_idx * group_size + 2, :, 0:240, :] = warped_tps_2
        watch_images[batch_idx * group_size + 3, :,
                     0:240, :] = batch['target_image']

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points,
                                                im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points,
                                                im_size=target_im_size)

        warped_points_tps_norm_1 = pt_1.tpsPointTnf(theta=thetai_tps_1,
                                                    points=source_points_norm)
        warped_points_tps_1 = PointsToPixelCoords(P=warped_points_tps_norm_1,
                                                  im_size=target_im_size)
        pck_tps_1, index_tps_1, N_pts = pck(target_points, warped_points_tps_1,
                                            dataset_name)
        warped_points_tps_1 = relocate(warped_points_tps_1, target_im_size)

        warped_points_tps_norm_2 = pt_2.tpsPointTnf(theta=thetai_tps_2,
                                                    points=source_points_norm)
        warped_points_tps_2 = PointsToPixelCoords(P=warped_points_tps_norm_2,
                                                  im_size=target_im_size)
        pck_tps_2, index_tps_2, _ = pck(target_points, warped_points_tps_2,
                                        dataset_name)
        warped_points_tps_2 = relocate(warped_points_tps_2, target_im_size)

        watch_keypoints[batch_idx * group_size, :, :N_pts] = relocate(
            batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size +
                        1, :, :N_pts] = warped_points_tps_1[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size +
                        2, :, :N_pts] = warped_points_tps_2[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 3, :, :N_pts] = relocate(
            batch['target_points'], target_im_size)[:, :, :N_pts]

        num_points[batch_idx * group_size:batch_idx * group_size +
                   group_size] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_tps_1)
        correct_index.append(index_tps_2)
        correct_index.append(np.arange(N_pts))

        image_names.append('Source')
        image_names.append('TPS')
        image_names.append('TPS_Jitter')
        image_names.append('Target')

        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(pck_tps_1))
        metrics.append('PCK: {:.2%}'.format(pck_tps_2))
        metrics.append('')

    opts = dict(jpgquality=100, title=dataset_name)
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :,
                                                                0:240, :],
                                                   forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3,
                                        1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5,
                    (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0),
                    1)
        if (i + 1) % group_size == 0:
            for j in range(num_points[i]):
                cv2.drawMarker(
                    watch_images[i],
                    (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]),
                    colors[j], cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(
                    watch_images[i],
                    (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]),
                    colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i],
                               (watch_keypoints[i + (group_size - 1) -
                                                (i % group_size), 0, j],
                                watch_keypoints[i + (group_size - 1) -
                                                (i % group_size), 1, j]),
                               colors[j], cv2.MARKER_DIAMOND, 12, 2,
                               cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=4, padding=5),
              opts=opts)
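
PointsToUnitCoords / PointsToPixelCoords are used throughout but not shown; a sketch of one common convention (pixel coordinates mapped to the [-1, 1] range of the sampling grids, with im_size rows ordered (H, W, ...)) — the repo's exact half-pixel handling may differ:

import torch

def points_to_unit_coords(P, im_size):
    # P: (B, 2, N) pixel coords (row 0 = x, row 1 = y); im_size: (B, >=2)
    wh = torch.stack((im_size[:, 1], im_size[:, 0]), dim=1).unsqueeze(2).float()
    return P * 2 / (wh - 1) - 1

def points_to_pixel_coords(P_norm, im_size):
    wh = torch.stack((im_size[:, 1], im_size[:, 0]), dim=1).unsqueeze(2).float()
    return (P_norm + 1) * (wh - 1) / 2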
Example #9
def vis_fn_dual(vis, train_loss, val_pck, train_lr, epoch, num_epochs, dataloader, theta, thetai, results, title, use_cuda=True):
    # Visualize watch images
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    pt = PointTnf(use_cuda=use_cuda)

    group_size = 4
    watch_images = torch.ones(len(dataloader) * group_size, 3, 280, 240)
    watch_keypoints = -torch.ones(len(dataloader) * group_size, 2, 20)
    if use_cuda:
        watch_images = watch_images.cuda()
        watch_keypoints = watch_keypoints.cuda()
    num_points = np.ones(len(dataloader) * group_size).astype(np.int8)
    correct_index = list()
    image_names = list()
    metrics = list()

    # Colors for keypoints
    cmap = plt.get_cmap('tab20')
    colors = list()
    for c in range(20):
        r = cmap(c)[0] * 255
        g = cmap(c)[1] * 255
        b = cmap(c)[2] * 255
        colors.append((b, g, r))
    fnt = cv2.FONT_HERSHEY_COMPLEX

    theta, thetai = swap(theta, thetai)
    # means for normalize of caffe resnet and vgg
    # pixel_means = torch.Tensor(np.array([[[[102.9801, 115.9465, 122.7717]]]]).astype(np.float32))
    for batch_idx, batch in enumerate(dataloader):
        if use_cuda:
            batch = batch_cuda(batch)

        batch['source_image'], batch['target_image'] = swap(batch['source_image'], batch['target_image'])
        batch['source_im_info'], batch['target_im_info'] = swap(batch['source_im_info'], batch['target_im_info'])
        batch['source_points'], batch['target_points'] = swap(batch['source_points'], batch['target_points'])

        # Theta and thetai
        theta_aff = theta['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps = theta['afftps'][batch_idx].unsqueeze(0)

        theta_aff_inver = thetai['aff'][batch_idx].unsqueeze(0)
        theta_aff_tps_inver = thetai['afftps'][batch_idx].unsqueeze(0)

        # Warped image
        warped_aff = affTnf(batch['source_image'], theta_aff)
        warped_aff_tps = tpsTnf(warped_aff, theta_aff_tps)

        watch_images[batch_idx * group_size, :, 0:240, :] = batch['source_image']
        watch_images[batch_idx * group_size + 1, :, 0:240, :] = warped_aff
        watch_images[batch_idx * group_size + 2, :, 0:240, :] = warped_aff_tps
        watch_images[batch_idx * group_size + 3, :, 0:240, :] = batch['target_image']

        # Warped keypoints
        source_im_size = batch['source_im_info'][:, 0:3]
        target_im_size = batch['target_im_info'][:, 0:3]

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_points_norm = PointsToUnitCoords(P=source_points, im_size=source_im_size)
        target_points_norm = PointsToUnitCoords(P=target_points, im_size=target_im_size)

        warped_points_aff_norm = pt.affPointTnf(theta=theta_aff_inver, points=source_points_norm)
        warped_points_aff = PointsToPixelCoords(P=warped_points_aff_norm, im_size=target_im_size)
        _, index_aff, N_pts = pck(target_points, warped_points_aff)
        warped_points_aff = relocate(warped_points_aff, target_im_size)

        warped_points_aff_tps_norm = pt.tpsPointTnf(theta=theta_aff_tps_inver, points=source_points_norm)
        warped_points_aff_tps_norm = pt.affPointTnf(theta=theta_aff_inver, points=warped_points_aff_tps_norm)
        warped_points_aff_tps = PointsToPixelCoords(P=warped_points_aff_tps_norm, im_size=target_im_size)
        _, index_aff_tps, _ = pck(target_points, warped_points_aff_tps)
        warped_points_aff_tps = relocate(warped_points_aff_tps, target_im_size)

        watch_keypoints[batch_idx * group_size, :, :N_pts] = relocate(batch['source_points'], source_im_size)[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 1, :, :N_pts] = warped_points_aff[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 2, :, :N_pts] = warped_points_aff_tps[:, :, :N_pts]
        watch_keypoints[batch_idx * group_size + 3, :, :N_pts] = relocate(batch['target_points'], target_im_size)[:, :, :N_pts]

        num_points[batch_idx * group_size:batch_idx * group_size + group_size] = N_pts

        correct_index.append(np.arange(N_pts))
        correct_index.append(index_aff)
        correct_index.append(index_aff_tps)
        correct_index.append(np.arange(N_pts))

        image_names.append('Source')
        image_names.append('Aff')
        image_names.append('AffTPS')
        image_names.append('Target')

        metrics.append('')
        metrics.append('PCK: {:.2%}'.format(float(results['aff']['pck'][batch_idx])))
        metrics.append('PCK: {:.2%}'.format(float(results['afftps']['pck'][batch_idx])))
        metrics.append('')

    opts = dict(jpgquality=100, title='Epoch ' + str(epoch) + ' source warped target')
    # Un-normalize for caffe resnet and vgg
    # watch_images = watch_images.permute(0, 2, 3, 1) + pixel_means
    # watch_images = watch_images[:, :, :, [2, 1, 0]].permute(0, 3, 1, 2)
    # watch_images = normalize_image(watch_images, forward=False) * 255.0
    watch_images[:, :, 0:240, :] = normalize_image(watch_images[:, :, 0:240, :], forward=False)
    watch_images *= 255.0
    watch_images = watch_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    watch_keypoints = watch_keypoints.cpu().numpy()

    for i in range(watch_images.shape[0]):
        pos_name = (80, 255)
        if (i + 1) % group_size == 1 or (i + 1) % group_size == 0:
            pos_pck = (0, 0)
        else:
            pos_pck = (70, 275)
        cv2.putText(watch_images[i], image_names[i], pos_name, fnt, 0.5, (0, 0, 0), 1)
        cv2.putText(watch_images[i], metrics[i], pos_pck, fnt, 0.5, (0, 0, 0), 1)
        if (i + 1) % group_size == 0:
            for j in range(num_points[i]):
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j],
                               cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)
        else:
            for j in correct_index[i]:
                cv2.drawMarker(watch_images[i], (watch_keypoints[i, 0, j], watch_keypoints[i, 1, j]), colors[j],
                               cv2.MARKER_DIAMOND, 12, 2, cv2.LINE_AA)
                cv2.drawMarker(watch_images[i],
                               (watch_keypoints[i + (group_size - 1) - (i % group_size), 0, j], watch_keypoints[i + (group_size - 1) - (i % group_size), 1, j]),
                               colors[j], cv2.MARKER_CROSS, 12, 2, cv2.LINE_AA)

    watch_images = torch.Tensor(watch_images.astype(np.float32))
    watch_images = watch_images.permute(0, 3, 1, 2)
    vis.image(torchvision.utils.make_grid(watch_images, nrow=4, padding=3), opts=opts)

    if epoch == num_epochs:
        epochs = np.arange(1, num_epochs+1)
        # Visualize train loss
        opts_loss = dict(xlabel='Epoch',
                         ylabel='Loss',
                         title='GM ResNet101 ' + title + ' Training Loss',
                         legend=['Loss'],
                         width=2000)
        vis.line(train_loss, epochs, opts=opts_loss)

        # Visualize val pck
        opts_pck = dict(xlabel='Epoch',
                        ylabel='Val PCK',
                        title='GM ResNet101 ' + title + ' Val PCK',
                        legend=['PCK'],
                        width=2000)
        vis.line(val_pck, epochs, opts=opts_pck)

        # Visualize train lr
        opts_lr = dict(xlabel='Epoch',
                       ylabel='Learning Rate',
                       title='GM ResNet101 ' + title + ' Training Learning Rate',
                       legend=['LR'],
                       width=2000)
        vis.line(train_lr, epochs, opts=opts_lr)
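
vis_fn_dual leans on a swap helper (applied to theta/thetai and to the source/target batch entries) that is assumed to simply exchange its two arguments, so the warp is evaluated in the reverse direction:

def swap(a, b):
    # hypothetical, but consistent with every call site above
    return b, a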
Example #10
def flow_metrics(batch,
                 batch_start_idx,
                 theta_aff,
                 theta_tps,
                 theta_aff_tps,
                 stats,
                 args,
                 use_cuda=True):
    result_path = args.flow_output_dir

    do_aff = theta_aff is not None
    do_tps = theta_tps is not None
    do_aff_tps = theta_aff_tps is not None

    pt = PointTnf(use_cuda=use_cuda)

    batch_size = batch['source_im_size'].size(0)
    for b in range(batch_size):
        h_src = int(batch['source_im_size'][b, 0].data.cpu().numpy())
        w_src = int(batch['source_im_size'][b, 1].data.cpu().numpy())
        h_tgt = int(batch['target_im_size'][b, 0].data.cpu().numpy())
        w_tgt = int(batch['target_im_size'][b, 1].data.cpu().numpy())

        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, w_tgt),
                                     np.linspace(-1, 1, h_tgt))
        grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
        grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
        grid_X = Variable(grid_X, requires_grad=False)
        grid_Y = Variable(grid_Y, requires_grad=False)
        if use_cuda:
            grid_X = grid_X.cuda()
            grid_Y = grid_Y.cuda()

        grid_X_vec = grid_X.view(1, 1, -1)
        grid_Y_vec = grid_Y.view(1, 1, -1)

        grid_XY_vec = torch.cat((grid_X_vec, grid_Y_vec), 1)

        def pointsToGrid(x, h_tgt=h_tgt, w_tgt=w_tgt):
            return x.contiguous().view(1, 2, h_tgt,
                                       w_tgt).transpose(1, 2).transpose(2, 3)

        idx = batch_start_idx + b

        if do_aff:
            grid_aff = pointsToGrid(
                pt.affPointTnf(theta_aff[b, :].unsqueeze(0), grid_XY_vec))
            flow_aff = th_sampling_grid_to_np_flow(source_grid=grid_aff,
                                                   h_src=h_src,
                                                   w_src=w_src)
            flow_aff_path = os.path.join(result_path, 'aff',
                                         batch['flow_path'][b])
            create_file_path(flow_aff_path)
            write_flo_file(flow_aff, flow_aff_path)
        if do_tps:
            grid_tps = pointsToGrid(
                pt.tpsPointTnf(theta_tps[b, :].unsqueeze(0), grid_XY_vec))
            flow_tps = th_sampling_grid_to_np_flow(source_grid=grid_tps,
                                                   h_src=h_src,
                                                   w_src=w_src)
            flow_tps_path = os.path.join(result_path, 'tps',
                                         batch['flow_path'][b])
            create_file_path(flow_tps_path)
            write_flo_file(flow_tps, flow_tps_path)
        if do_aff_tps:
            grid_aff_tps = pointsToGrid(
                pt.affPointTnf(
                    theta_aff[b, :].unsqueeze(0),
                    pt.tpsPointTnf(theta_aff_tps[b, :].unsqueeze(0),
                                   grid_XY_vec)))
            flow_aff_tps = th_sampling_grid_to_np_flow(
                source_grid=grid_aff_tps, h_src=h_src, w_src=w_src)
            flow_aff_tps_path = os.path.join(result_path, 'aff_tps',
                                             batch['flow_path'][b])
            create_file_path(flow_aff_tps_path)
            write_flo_file(flow_aff_tps, flow_aff_tps_path)

        idx = batch_start_idx + b
    return stats
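
write_flo_file is assumed to emit the standard Middlebury .flo layout: a float32 magic number (202021.25, the bytes 'PIEH'), int32 width and height, then row-major interleaved float32 (u, v) values. A sketch:

import numpy as np

def write_flo_sketch(flow, path):
    # flow: (h, w, 2) float array of (u, v) displacements
    h, w = flow.shape[:2]
    with open(path, 'wb') as f:
        np.float32(202021.25).tofile(f)  # magic number spelling 'PIEH'
        np.int32(w).tofile(f)
        np.int32(h).tofile(f)
        flow.astype(np.float32).tofile(f)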
Example #11
def pck_metric(batch,
               batch_start_idx,
               theta_aff,
               theta_tps,
               theta_aff_tps,
               stats,
               args,
               use_cuda=True):
    alpha = args.pck_alpha
    do_aff = theta_aff is not None
    do_tps = theta_tps is not None
    do_aff_tps = theta_aff_tps is not None

    source_im_size = batch['source_im_size']
    target_im_size = batch['target_im_size']

    source_points = batch['source_points']
    target_points = batch['target_points']

    # Instantiate point transformer
    # pt = PointTnf(use_cuda=use_cuda, tps_reg_factor=args.tps_reg_factor)
    pt = PointTnf(use_cuda=use_cuda)

    # warp points with estimated transformations
    target_points_norm = PointsToUnitCoords(target_points, target_im_size)

    if do_aff:
        # do affine only
        warped_points_aff_norm = pt.affPointTnf(theta_aff, target_points_norm)
        warped_points_aff = PointsToPixelCoords(warped_points_aff_norm,
                                                source_im_size)

    if do_tps:
        # do tps only
        warped_points_tps_norm = pt.tpsPointTnf(theta_tps, target_points_norm)
        warped_points_tps = PointsToPixelCoords(warped_points_tps_norm,
                                                source_im_size)

    if do_aff_tps:
        # do tps+affine
        warped_points_aff_tps_norm = pt.tpsPointTnf(theta_aff_tps,
                                                    target_points_norm)
        warped_points_aff_tps_norm = pt.affPointTnf(
            theta_aff, warped_points_aff_tps_norm)
        warped_points_aff_tps = PointsToPixelCoords(warped_points_aff_tps_norm,
                                                    source_im_size)

    L_pck = batch['L_pck'].data

    current_batch_size = batch['source_im_size'].size(0)
    indices = range(batch_start_idx, batch_start_idx + current_batch_size)

    # import pdb; pdb.set_trace()

    if do_aff:
        pck_aff = pck(source_points.data, warped_points_aff.data, L_pck, alpha)

    if do_tps:
        pck_tps = pck(source_points.data, warped_points_tps.data, L_pck, alpha)

    if do_aff_tps:
        pck_aff_tps = pck(source_points.data, warped_points_aff_tps.data,
                          L_pck, alpha)

    if do_aff:
        stats['aff']['pck'][indices] = pck_aff.unsqueeze(1).cpu().numpy()
    if do_tps:
        stats['tps']['pck'][indices] = pck_tps.unsqueeze(1).cpu().numpy()
    if do_aff_tps:
        stats['aff_tps']['pck'][indices] = pck_aff_tps.unsqueeze(
            1).cpu().numpy()

    return stats
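
The stats / results dicts that all of these metric functions index into are assumed to be preallocated per transform key before evaluation, along the lines of:

import numpy as np

N = 300  # hypothetical dataset size
stats = {key: {'pck': np.zeros((N, 1))} for key in ('aff', 'tps', 'aff_tps')}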