Example #1
def visualise_projection(self,
                         points,
                         world_mat,
                         camera_mat,
                         img,
                         output_file='out.png'):
    r''' Visualizes the transformation and projection to the image plane.

        The points of the first batch element are transformed and projected
        onto the corresponding image. After performing the relevant
        transformations, the visualization is saved to the provided
        output_file path.

    Arguments:
        points (tensor): batch of point cloud points
        world_mat (tensor): batch of world matrices that rotate the point
                cloud into camera-based coordinates
        camera_mat (tensor): batch of camera matrices for the projection to
                the 2D image plane
        img (tensor): batch of ground-truth (GT) images
        output_file (string): path where the output should be saved
    '''
    # rotate the points into camera coordinates and project them to 2D
    points_transformed = common.transform_points(points, world_mat)
    points_img = common.project_to_camera(points_transformed, camera_mat)
    # visualize only the first element of the batch
    pimg2 = points_img[0].detach().cpu().numpy()
    image = img[0].cpu().numpy()
    plt.imshow(image.transpose(1, 2, 0))  # C x H x W -> H x W x C
    # map projected coordinates from [-1, 1] to pixel coordinates
    # (x scales with the image width, y with the image height)
    plt.plot((pimg2[:, 0] + 1) * image.shape[2] / 2,
             (pimg2[:, 1] + 1) * image.shape[1] / 2, 'x')
    plt.savefig(output_file)
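These snippets rely on common.transform_points and common.project_to_camera, which are not shown. A minimal sketch of what they could look like, assuming a B x 3 x 4 world matrix (rotation | translation) and intrinsics that already map onto the normalized [-1, 1] coordinates expected by F.grid_sample; the actual implementations in common may differ:

import torch

def transform_points(points, world_mat):
    # points: B x N x 3, world_mat: B x 3 x 4 = (R | t)
    # shapes are assumptions, not taken from the source
    R, t = world_mat[:, :, :3], world_mat[:, :, 3:]
    return points @ R.transpose(1, 2) + t.transpose(1, 2)

def project_to_camera(points, camera_mat):
    # points: B x N x 3 in camera coordinates, camera_mat: B x 2 x 3
    # (first two rows of an intrinsics matrix, also an assumption);
    # perspective divide by z gives B x N x 2 image coordinates
    p = points @ camera_mat.transpose(1, 2)
    return p / points[:, :, 2:]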
    def forward_local_second_step(self, data, c, local_feat_maps, pts):
        world_mat = data['world_mat']
        camera_mat = data['camera_mat']

        assert self.local
        pts = common.transform_points(pts, world_mat)
        points_img = common.project_to_camera(pts, camera_mat)
        points_img = points_img.unsqueeze(1)

        # get local feats
        local_feats = []
        for f in local_feat_maps:
            #f = f.detach()
            f = F.grid_sample(f, points_img, mode='nearest')
            f = f.squeeze(2)
            local_feats.append(f)

        local_feats = torch.cat(local_feats, dim=1)
        local_feats = local_feats.transpose(1, 2)  # batch * n_pts * f_dim

        local_feats = self.local_fc(local_feats)

        # x: B * c_dim
        # local: feats B * n_pts * c_dim
        return c, local_feats
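The unsqueeze(1) / squeeze(2) pair above is what turns F.grid_sample into a per-point feature lookup: the n_pts projected points are passed as a 1 x n_pts sampling grid over each feature map. A small self-contained illustration with made-up shapes:

import torch
import torch.nn.functional as F

B, C, H, W, n_pts = 2, 64, 56, 56, 1024
fmap = torch.randn(B, C, H, W)                      # one local feature map
points_img = torch.rand(B, n_pts, 2) * 2 - 1        # projected points in [-1, 1]
grid = points_img.unsqueeze(1)                      # B x 1 x n_pts x 2 sampling grid
feats = F.grid_sample(fmap, grid, mode='nearest')   # B x C x 1 x n_pts
feats = feats.squeeze(2)                            # B x C x n_pts, one column per point
print(feats.shape)                                  # torch.Size([2, 64, 1024])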
Example #3
    def encode_second_step(self, f3, f2, f1, pts, world_mat, camera_mat):
        pts = common.transform_points(pts, world_mat)
        points_img = common.project_to_camera(pts, camera_mat)
        points_img = points_img.unsqueeze(1)

        f2 = f2.detach()
        f2 = F.relu(f2)
        f2 = F.grid_sample(f2, points_img)
        f2 = f2.squeeze(2)
        f2 = self.f2_conv(f2)

        f1 = f1.detach()
        f1 = F.relu(f1)
        f1 = F.grid_sample(f1, points_img)
        f1 = f1.squeeze(2)
        f1 = self.f1_conv(f1)
        
        f3 = self.fc3(f3)

        if self.batch_norm:
            f3 = self.f3_bn(f3)
            f2 = self.f2_bn(f2)
            f1 = self.f1_bn(f1)

        f2 = f2.transpose(1, 2)
        f1 = f1.transpose(1, 2)
        # f2 : batch * n_pts * fmap_dim
        # f1 : batch * n_pts * fmap_dim
        return f3, f2, f1
Example #4
    def forward(self, x, fm, camera_mat, img=None, visualise=False):
        ''' Performs a forward pass through the GP layer.

        Args:
            x (tensor): coordinates of shape (batch_size, num_vertices, 3)
            fm (list): list of feature maps from which the image features
                        should be pooled
            camera_mat (tensor): camera matrices for the transformation to
                        the 2D image plane
            img (tensor): images (just for visualisation purposes)
            visualise (bool): whether to visualise the projected points
        '''
        points_img = common.project_to_camera(x, camera_mat)
        points_img = points_img.unsqueeze(1)
        feats = []
        feats.append(x)
        for fmap in fm:
            # bilinearly interpolate to get the corresponding features
            feat_pts = F.grid_sample(fmap, points_img)
            feat_pts = feat_pts.squeeze(2)
            feats.append(feat_pts.transpose(1, 2))
        # Just for visualisation purposes
        if visualise and (img is not None):
            self.visualise_projection(
                points_img.squeeze(1)[0].detach().cpu().numpy(),
                img[0].cpu().numpy())

        outputs = torch.cat(feats, dim=2)
        return outputs
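The concatenation at the end means every vertex carries its 3D coordinate plus one pooled feature vector per feature map, i.e. 3 + the sum of the feature-map channels. A toy walk-through of that pooling with assumed shapes (the feature-map sizes are illustrative, not from the source):

import torch
import torch.nn.functional as F

B, V = 2, 156
x = torch.randn(B, V, 3)                                   # vertex coordinates
points_img = (torch.rand(B, V, 2) * 2 - 1).unsqueeze(1)    # assumed projection result
fm = [torch.randn(B, c, s, s) for c, s in [(64, 56), (128, 28), (256, 14)]]

feats = [x]
for fmap in fm:
    feat_pts = F.grid_sample(fmap, points_img).squeeze(2)  # B x C x V
    feats.append(feat_pts.transpose(1, 2))                 # B x V x C
outputs = torch.cat(feats, dim=2)
print(outputs.shape)    # torch.Size([2, 156, 451]) = (B, V, 3 + 64 + 128 + 256)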
Example #5
    def forward_local(self, data, pts):
        assert self.local
        world_mat = data['world_mat']
        camera_mat = data['camera_mat']
        x = data[None]  # the input image tensor is stored under the key None

        pts = transform_points(pts, world_mat)
        points_img = project_to_camera(pts, camera_mat)
        points_img = points_img.unsqueeze(1)

        local_feat_maps = []
        if self.normalize:
            x = normalize_imagenet(x)

        x = self.features.conv1(x)
        x = self.features.bn1(x)
        x = self.features.relu(x)
        x = self.features.maxpool(x)  # 64 * 56 * 56

        x = self.features.layer1(x)
        local_feat_maps.append(x)  # 64 * 56 * 56
        x = self.features.layer2(x)
        local_feat_maps.append(x)  # 128 * 28 * 28
        x = self.features.layer3(x)
        local_feat_maps.append(x)  # 256 * 14 * 14
        x = self.features.layer4(x)
        local_feat_maps.append(x)  # 512 * 7 * 7

        x = self.features.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # get local feats
        local_feats = []
        for f in local_feat_maps:
            #f = f.detach()
            f = F.grid_sample(f, points_img, mode='nearest')
            f = f.squeeze(2)
            local_feats.append(f)

        local_feats = torch.cat(local_feats, dim=1)
        local_feats = local_feats.transpose(1, 2)  # batch * n_pts * f_dim

        local_feats = self.local_fc(local_feats)

        # x: B * c_dim
        # local: feats B * n_pts * c_dim
        return x, local_feats
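The stage-by-stage calls and the commented feature-map sizes match a torchvision ResNet-style backbone on a 224 x 224 input; a minimal sketch of that assumption (the actual self.features module may differ):

import torch
import torchvision.models as models

features = models.resnet18()    # assumed backbone
x = torch.randn(1, 3, 224, 224)
x = features.maxpool(features.relu(features.bn1(features.conv1(x))))
for layer in [features.layer1, features.layer2,
              features.layer3, features.layer4]:
    x = layer(x)
    print(tuple(x.shape))
# (1, 64, 56, 56), (1, 128, 28, 28), (1, 256, 14, 14), (1, 512, 7, 7)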
Example #6
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
        )

        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)

        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

        loss = 0
        if isinstance(out, tuple):
            out, trans_feat = out

            if isinstance(self.model.encoder, PointNetEncoder) or isinstance(
                    self.model.encoder, PointNetResEncoder):
                loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

        # reconstruction loss: chamfer distance or earth mover's distance,
        # added to any regularizer already accumulated in `loss`
        if self.loss_type == 'cd':
            dist1, dist2 = cd.chamfer_distance(out, gt_pc)
            loss = loss + (dist1.mean(1) + dist2.mean(1)).mean() / 2.
        else:
            out_pts_count = out.size(1)
            loss = loss + (emd.earth_mover_distance(
                out, gt_pc, transpose=False) / out_pts_count).mean()

        # view penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # project using the world and camera matrices
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                t = world_mat[:, :, 3:].transpose(1, 2)  # B x 1 x 3 translation
                out_pts = out + t
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:].transpose(1, 2)  # B x 1 x 3 translation
                out_pts = out * t[:, :, 2:]  # scale by the z-translation
                out_pts = out_pts + t
            else:
                raise NotImplementedError

            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                    inplace=True).mean()
            loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

            if self.view_penalty == 'mask_flow_and_depth':
                # depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale  # absolute depth from view
                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)

                # penalize predicted points that lie behind the observed
                # depth by more than depth_test_eps
                loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                         corresponding_z,
                                         inplace=True).mean()
                loss = loss + self.loss_depth_test_ratio * loss_depth_test

        return loss
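The mask-flow penalty samples the ground-truth mask/flow map at the projected point locations and applies a hinge, so points that project outside the silhouette (sampled value below 1 - eps) are penalized. A toy illustration with made-up tensors:

import torch
import torch.nn.functional as F

B, n_pts, H, W = 2, 512, 224, 224
gt_mask_flow = torch.rand(B, 1, H, W)                         # B x 1 x H x W map
out_pts_img = (torch.rand(B, n_pts, 2) * 2 - 1).unsqueeze(1)  # B x 1 x n_pts x 2
out_mask_flow = F.grid_sample(gt_mask_flow, out_pts_img)      # B x 1 x 1 x n_pts
mask_flow_eps = 0.05                                          # hypothetical margin
loss_mask_flow = F.relu(1. - mask_flow_eps - out_mask_flow).mean()
print(loss_mask_flow.item())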
Example #7
    def eval_step(self, data):
        ''' Performs an evaluation step.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer)
        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)
        batch_size = gt_pc.size(0)

        with torch.no_grad():
            if self.model.encoder_world_mat is not None:
                out = self.model(encoder_inputs, world_mat=world_mat)
            else:
                out = self.model(encoder_inputs)

            if isinstance(out, tuple):
                out, trans_feat = out

            eval_dict = {}
            if batch_size == 1:
                pointcloud_hat = out.cpu().squeeze(0).numpy()
                pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()

                eval_dict = self.mesh_evaluator.eval_pointcloud(
                    pointcloud_hat, pointcloud_gt)

            # reconstruction loss: chamfer distance or earth mover's distance
            if self.loss_type == 'cd':
                dist1, dist2 = cd.chamfer_distance(out, gt_pc)
                loss = (dist1.mean(1) + dist2.mean(1)) / 2.
            else:
                loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

            if self.gt_pointcloud_transfer in ('world_scale_model',
                                               'view_scale_model', 'view'):
                pointcloud_scale = data.get('pointcloud.scale').to(
                    device).view(batch_size, 1, 1)
                loss = loss / (pointcloud_scale**2)
                if self.gt_pointcloud_transfer == 'view':
                    if world_mat is None:
                        world_mat = get_world_mat(data, device=device)
                    t_scale = world_mat[:, 2:, 3:]
                    loss = loss * (t_scale**2)

            if self.loss_type == 'cd':
                loss = loss.mean()
                eval_dict['chamfer'] = loss.item()
            else:
                out_pts_count = out.size(1)
                loss = (loss / out_pts_count).mean()
                eval_dict['emd'] = loss.item()

            # view penalty loss
            if self.view_penalty:
                gt_mask_flow = data.get('inputs.mask_flow').to(
                    device)  # B * 1 * H * W
                if world_mat is None:
                    world_mat = get_world_mat(data, device=device)
                camera_mat = get_camera_mat(data, device=device)

                # project using the world and camera matrices
                if self.gt_pointcloud_transfer == 'world_scale_model':
                    out_pts = transform_points(out, world_mat)
                elif self.gt_pointcloud_transfer == 'view_scale_model':
                    t = world_mat[:, :, 3:].transpose(1, 2)  # B x 1 x 3 translation
                    out_pts = out + t
                elif self.gt_pointcloud_transfer == 'view':
                    t = world_mat[:, :, 3:].transpose(1, 2)  # B x 1 x 3 translation
                    out_pts = out * t[:, :, 2:]  # scale by the z-translation
                    out_pts = out_pts + t
                else:
                    raise NotImplementedError

                out_pts_img = project_to_camera(out_pts, camera_mat)
                out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

                out_mask_flow = F.grid_sample(gt_mask_flow,
                                              out_pts_img)  # B * 1 * 1 * n_pts
                loss_mask_flow = F.relu(1. - self.mask_flow_eps -
                                        out_mask_flow,
                                        inplace=True).mean()
                loss = self.loss_mask_flow_ratio * loss_mask_flow
                eval_dict['loss_mask_flow'] = loss.item()

                if self.view_penalty == 'mask_flow_and_depth':
                    # depth test loss
                    t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1,
                                                      1)
                    gt_mask = data.get('inputs.mask').byte().to(device)
                    depth_pred = data.get('inputs.depth_pred').to(
                        device) * t_scale

                    background_setting(depth_pred, gt_mask)
                    depth_z = out_pts[:, :, 2:].transpose(1, 2)
                    corresponding_z = F.grid_sample(
                        depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                    corresponding_z = corresponding_z.squeeze(1)

                    # penalize predicted points that lie behind the observed
                    # depth by more than depth_test_eps (e.g. 0.05)
                    loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                             corresponding_z,
                                             inplace=True).mean()
                    loss = self.loss_depth_test_ratio * loss_depth_test
                    eval_dict['loss_depth_test'] = loss.item()

        return eval_dict