def generate_mesh(self, data, fix_normals=False):
        ''' Generates a mesh.

        Arguments:
            data (dict): data dictionary
            fix_normals (bool): whether the mesh normals should be fixed
        '''

        img = data.get('inputs').to(self.device)
        camera_args = common.get_camera_args(
            data, 'pointcloud.loc', 'pointcloud.scale', device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']
        with torch.no_grad():
            outputs1, outputs2 = self.model(img, camera_mat)
            out_1, out_2, out_3 = outputs1

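        # The network predicts vertices in camera coordinates; map the final
        # prediction (out_3) back to world/object coordinates.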
        transformed_pred = common.transform_points_back(out_3, world_mat)
        vertices = transformed_pred.squeeze().cpu().numpy()

        faces = self.base_mesh[:, 1:]  # drop the leading 'f' token of each OBJ face line
        faces = faces.astype(int) - 1  # OBJ indices are 1-based; trimesh expects 0-based
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
        if fix_normals:
            # Fix normals due to wrong base ellipsoid
            trimesh.repair.fix_normals(mesh)
        return mesh
Example #2
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device
        p = data.get('points').to(device)
        if self.binary_occ:
            occ = (data.get('points.occ') >= 0.5).float().to(device)
        else:
            occ = data.get('points.occ').to(device)
        inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)
        kwargs = {}

        if self.use_local_feature:
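            # Local-feature path: the camera matrices (Rt, K) are presumably
            # used to project the query points p into the image and sample
            # pixel-aligned features.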
            camera_args = get_camera_args(data,
                                          'points.loc',
                                          'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']
            f3, f2, f1 = self.model.encode_inputs(inputs, p, Rt, K)
        else:
            f3, f2, f1 = self.model.encode_inputs(inputs)
        
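        # Infer the approximate posterior over the latent code from the query
        # points and their occupancies, and draw a reparameterised sample.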
        q_z = self.model.infer_z(p, occ, f3, **kwargs)
        z = q_z.rsample()

        # KL-divergence
        kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
        loss = kl.mean()

        # General points
        p_r = self.model.decode(p, z, f3, f2, f1, **kwargs)
        logits = p_r.logits
        probs = p_r.probs
        if self.loss_type == 'cross_entropy':
            loss_i = F.binary_cross_entropy_with_logits(
                logits, occ, reduction='none')
        elif self.loss_type == 'l2':
            probs_pred = torch.sigmoid(logits)
            loss_i = torch.pow((probs_pred - occ), 2)
        elif self.loss_type == 'l1':
            probs_pred = torch.sigmoid(logits)
            loss_i = torch.abs(probs_pred - occ)
        else:
            probs_pred = torch.sigmoid(logits)
            loss_i = F.binary_cross_entropy(probs_pred, occ, reduction='none')

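        # Optional loss shaping: floor very small per-point losses, down-weight
        # points already on the correct side of the threshold, and re-weight
        # points with fractional occupancy (i.e. near the surface).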
        if self.loss_tolerance_episolon != 0.:
            loss_i = torch.clamp(loss_i, min=self.loss_tolerance_episolon, max=100)

        if self.sign_lambda != 0.:
            w = 1. - self.sign_lambda * torch.sign(occ - 0.5) * torch.sign(probs - self.threshold)
            loss_i = loss_i * w

        if self.surface_loss_weight != 1.:
            w = ((occ > 0.) & (occ < 1.)).float()
            w = w * (self.surface_loss_weight - 1) + 1
            loss_i = loss_i * w
        loss = loss + loss_i.sum(-1).mean()
        return loss
Example #3
    def train_step(self, data):
        r''' Performs a training step of the model.

        Arguments:
            data (dict): data dictionary
        '''

        self.model.train()
        points = data.get('pointcloud').to(self.device)
        normals = data.get('pointcloud.normals').to(self.device)
        img = data.get('inputs').to(self.device)
        camera_args = common.get_camera_args(data,
                                             'pointcloud.loc',
                                             'pointcloud.scale',
                                             device=self.device)

        # Transform GT data into camera coordinate system
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']
        points_transformed = common.transform_points(points, world_mat)
        # Transform GT normals to camera coordinate system
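        # Normals are direction vectors, so only the 3x3 rotation part of the
        # extrinsic matrix applies (no translation).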
        world_normal_mat = world_mat[:, :, :3]
        normals = common.transform_points(normals, world_normal_mat)

        outputs1, outputs2 = self.model(img, camera_mat)
        loss = self.compute_loss(outputs1, outputs2, points_transformed,
                                 normals, img)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
Example #4
    def visualize(self, data):
        r''' Visualises the GT point cloud and predicted vertices (as a point cloud).

        Arguments:
            data (dict): data dictionary
        '''

        points_gt = data.get('pointcloud').to(self.device)
        img = data.get('inputs').to(self.device)
        camera_args = common.get_camera_args(data,
                                             'pointcloud.loc',
                                             'pointcloud.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']

        if not os.path.isdir(self.vis_dir):
            os.mkdir(self.vis_dir)

        with torch.no_grad():
            outputs1, outputs2 = self.model(img, camera_mat)

        pred_vertices_1, pred_vertices_2, pred_vertices_3 = outputs1
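        # Map the final prediction (pred_vertices_3) back to world coordinates
        # so it is directly comparable to the ground-truth point cloud.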
        points_out = common.transform_points_back(pred_vertices_3, world_mat)
        points_out = points_out.cpu().numpy()
        input_img_path = os.path.join(self.vis_dir, 'input.png')
        save_image(img.cpu(), input_img_path, nrow=4)

        points_gt = points_gt.cpu().numpy()
        batch_size = img.size(0)
        for i in range(batch_size):
            out_file = os.path.join(self.vis_dir, '%03d.png' % i)
            out_file_gt = os.path.join(self.vis_dir, '%03d_gt.png' % i)
            vis.visualize_pointcloud(points_out[i], out_file=out_file)
            vis.visualize_pointcloud(points_gt[i], out_file=out_file_gt)
Example #5
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''

        npp = 0.1  # noise level for the colornoise input augmentation below

        device = self.device
        p = data.get('points').to(device)
        occ = data.get('points.occ').to(device)
        inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)
        inputs = self.colornoise(inputs, npp)

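        # Note: the raw matrices loaded below are immediately overridden by the
        # camera arguments recomputed from the points' loc/scale fields.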
        world_mat = data.get('inputs.world_mat').to(device)
        camera_mat = data.get('inputs.camera_mat').to(device)
        camera_args = common.get_camera_args(data,
                                             'points.loc',
                                             'points.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']

        self.vis(False, inputs[0].cpu().numpy())
        # print("world_mat",world_mat.shape)
        # print("camera_mat",camera_mat.shape)
        # exit(1)

        kwargs = {}

        G, c = self.model.encode_inputs(inputs)
        # Encoder outputs (per sample): c[0]: 64x56x56, c[1]: 128x28x28,
        # c[2]: 256x14x14, c[3]: 512x7x7, c[4]: 256x2x2, global code G: 1024
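        # gproj presumably projects the query points p into the image-space
        # feature maps via the camera matrices and gathers per-point local
        # features v.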
        v = self.model.gproj(p, G, c, world_mat, camera_mat, inputs, False)
        # v = self.model.gproj(p, c, camera_mat, inputs, True)
        # v point number, 1219
        # v+G point number, 2243
        q_z = self.model.infer_z(p, occ, c, **kwargs)
        z = q_z.rsample()

        # KL-divergence
        kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
        loss = kl.mean()

        # General points
        logits = self.model.decode(p, z, v, **kwargs).logits

        # exit(1)

        loss_i = F.binary_cross_entropy_with_logits(logits,
                                                    occ,
                                                    reduction='none')
        loss = loss + loss_i.sum(-1).mean()

        return loss
Example #6
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device
        p = data.get('points').to(device)
        occ = data.get('points.occ').to(device)
        inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)
        I = data.get('inputs.image', torch.empty(p.size(0), 0)).to(device)

        world_mat = data.get('inputs.world_mat').to(device)
        camera_mat = data.get('inputs.camera_mat').to(device)
        camera_args = common.get_camera_args(data,
                                             'points.loc',
                                             'points.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']

        # print("world_mat",world_mat.shape)
        # print("camera_mat",camera_mat.shape)
        # exit(1)

        kwargs = {}

        c = self.model.encode_inputs(inputs)
        # print("c",c[0].shape)
        # print("c",c[1].shape)
        # print("c",c[2].shape)
        # print("c",c[3].shape)
        v = self.model.gproj(p, c, world_mat, camera_mat, inputs, False)
        # v = self.model.gproj(p, c, camera_mat, inputs, True)

        q_z = self.model.infer_z(p, occ, c, **kwargs)
        z = q_z.rsample()

        # KL-divergence
        kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
        loss = kl.mean()

        # General points
        logits = self.model.decode(p, z, v, **kwargs).logits

        # exit(1)

        loss_i = F.binary_cross_entropy_with_logits(logits,
                                                    occ,
                                                    reduction='none')
        loss = loss + loss_i.sum(-1).mean()

        return loss
Example #7
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device
        p = data.get('points').to(device)
        if self.binary_occ:
            occ = (data.get('points.occ') >= 0.5).float().to(device)
        else:
            occ = data.get('points.occ').to(device)
        inputs = data.get('inputs', torch.empty(p.size(0), 0)).to(device)
        kwargs = {}

        if self.use_local_feature:
            camera_args = get_camera_args(data,
                                          'points.loc',
                                          'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']
            f3, f2, f1 = self.model.encode_inputs(inputs, p, Rt, K)
        else:
            f3, f2, f1 = self.model.encode_inputs(inputs)

        q_z = self.model.infer_z(p, occ, f3, **kwargs)
        z = q_z.rsample()

        # KL-divergence
        kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
        loss = kl.mean()

        # General points
        p_r = self.model.decode(p, z, f3, f2, f1, **kwargs)
        logits = p_r.logits
        probs = p_r.probs

        # Per-point occupancy loss (cross-entropy, L1 or L2, depending on
        # self.loss_type)
        loss_i = get_occ_loss(logits, occ, self.loss_type)
        # Loss shaping: tolerance clamping, sign re-weighting and surface
        # re-weighting (cf. the explicit version in Example #2)
        loss_i = occ_loss_postprocess(loss_i, occ, probs,
                                      self.loss_tolerance_episolon,
                                      self.sign_lambda, self.threshold,
                                      self.surface_loss_weight)

        loss = loss + loss_i.sum(-1).mean()

        return loss
Example #8
    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh.

        Args:
            data (dict): data dictionary
            return_stats (bool): whether stats should be returned
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}

        world_mat = data.get('inputs.world_mat').to(device)
        camera_mat = data.get('inputs.camera_mat').to(device)
        camera_args = common.get_camera_args(data,
                                             'points.loc',
                                             'points.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']

        inputs = data.get('inputs', torch.empty(1, 0)).to(device)
        kwargs = {}

        # Preprocess if required
        if self.preprocessor is not None:
            t0 = time.time()
            with torch.no_grad():
                inputs = self.preprocessor(inputs)
            stats_dict['time (preprocess)'] = time.time() - t0

        # Encode inputs
        t0 = time.time()
        with torch.no_grad():
            G, c = self.model.encode_inputs(inputs)
        stats_dict['time (encode inputs)'] = time.time() - t0

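        # Draw a latent code from the prior (presumably the prior mean when
        # self.sample is False, a random sample otherwise).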
        z = self.model.get_z_from_prior((1, ), sample=self.sample).to(device)
        mesh = self.generate_from_latent(z,
                                         G,
                                         c,
                                         stats_dict=stats_dict,
                                         world_mat=world_mat,
                                         camera_mat=camera_mat,
                                         **kwargs)

        if return_stats:
            return mesh, stats_dict
        else:
            return mesh
Example #9
    def eval_step(self, data):
        r''' Performs an evaluation step.

        Arguments:
            data (dict): data dictionary
        '''
        self.model.eval()
        points = data.get('pointcloud').to(self.device)
        img = data.get('inputs').to(self.device)
        normals = data.get('pointcloud.normals').to(self.device)

        # Transform GT points to camera coordinates
        camera_args = common.get_camera_args(data,
                                             'pointcloud.loc',
                                             'pointcloud.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']
        points_transformed = common.transform_points(points, world_mat)
        # Transform GT normals to camera coordinates
        world_normal_mat = world_mat[:, :, :3]
        normals = common.transform_points(normals, world_normal_mat)

        with torch.no_grad():
            outputs1, outputs2 = self.model(img, camera_mat)

        pred_vertices_1, pred_vertices_2, pred_vertices_3 = outputs1

        loss = self.compute_loss(outputs1, outputs2, points_transformed,
                                 normals, img)
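        # Bidirectional Chamfer distance; give_id=True also returns the
        # nearest-neighbour indices reused by the normal loss below.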
        lc1, lc2, id31, id32 = chamfer_distance(pred_vertices_3,
                                                points_transformed,
                                                give_id=True)
        l_c = (lc1 + lc2).mean()
        l_e = self.edge_length_loss(pred_vertices_3, 3)
        l_n = self.normal_loss(pred_vertices_3, normals, id31, 3)
        l_l, move_loss = self.laplacian_loss(pred_vertices_3,
                                             outputs2[2],
                                             block_id=3)

        eval_dict = {
            'loss': loss.item(),
            'chamfer': l_c.item(),
            'edge': l_e.item(),
            'normal': l_n.item(),
            'laplace': l_l.item(),
            'move': move_loss.item()
        }
        return eval_dict
Example #10
    def visualize(self, data):
        ''' Performs a visualization step for the data.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        batch_size = data['points'].size(0)
        inputs = data.get('inputs', torch.empty(batch_size, 0)).to(device)

        if self.use_local_feature:
            camera_args = get_camera_args(data,
                                          'points.loc',
                                          'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']

        shape = (32, 32, 32)
        p = make_3d_grid([-0.5] * 3, [0.5] * 3, shape).to(device)
        p = p.expand(batch_size, *p.size())
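        # Evaluate occupancy probabilities on a regular 32^3 grid spanning the
        # unit cube, then threshold them into an output voxel grid.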

        kwargs = {}
        with torch.no_grad():
            if self.use_local_feature:
                p_r = self.model(p,
                                 inputs,
                                 Rt,
                                 K,
                                 sample=self.eval_sample,
                                 **kwargs)
            else:
                p_r = self.model(p, inputs, sample=self.eval_sample, **kwargs)

        occ_hat = p_r.probs.view(batch_size, *shape)
        voxels_out = (occ_hat >= self.threshold).cpu().numpy()

        for i in trange(batch_size):
            input_img_path = os.path.join(self.vis_dir, '%03d_in.png' % i)
            vis.visualize_data(inputs[i].cpu(), self.input_type,
                               input_img_path)
            vis.visualize_voxels(voxels_out[i],
                                 os.path.join(self.vis_dir, '%03d.png' % i))
Example #11
    def generate_pointcloud(self, data):
        ''' Generates a point cloud by returning only the predicted vertices.

        Arguments:
            data (dict): data dictionary
        '''

        img = data.get('inputs').to(self.device)
        camera_args = common.get_camera_args(
            data, 'pointcloud.loc', 'pointcloud.scale', device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']

        with torch.no_grad():
            outputs1, _ = self.model(img, camera_mat)
            _, _, out_3 = outputs1
        transformed_pred = common.transform_points_back(out_3, world_mat)
        pc_out = transformed_pred.squeeze().cpu().numpy()
        return pc_out
Example #12
    def eval_step(self, data):
        ''' Performs an evaluation step.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device
        threshold = self.threshold
        eval_dict = {}

        # Compute elbo
        points = data.get('points').to(device)
        occ = data.get('points.occ').to(device)

        inputs = data.get('inputs', torch.empty(points.size(0), 0)).to(device)
        voxels_occ = data.get('voxels')

        points_iou = data.get('points_iou').to(device)
        occ_iou = data.get('points_iou.occ').to(device)
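        # 'points_iou' is a separate set of query points held out for the IoU
        # metric.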

        world_mat = data.get('inputs.world_mat').to(device)
        camera_mat = data.get('inputs.camera_mat').to(device)
        camera_args = common.get_camera_args(data,
                                             'points.loc',
                                             'points.scale',
                                             device=self.device)
        world_mat, camera_mat = camera_args['Rt'], camera_args['K']

        kwargs = {}

        with torch.no_grad():
            elbo, rec_error, kl = self.model.compute_elbo(
                points, occ, inputs, world_mat, camera_mat, **kwargs)

        eval_dict['loss'] = -elbo.mean().item()
        eval_dict['rec_error'] = rec_error.mean().item()
        eval_dict['kl'] = kl.mean().item()

        # Compute iou
        batch_size = points.size(0)

        with torch.no_grad():
            p_out = self.model(points_iou,
                               inputs,
                               world_mat,
                               camera_mat,
                               sample=self.eval_sample,
                               **kwargs)

        occ_iou_np = (occ_iou >= 0.5).cpu().numpy()
        occ_iou_hat_np = (p_out.probs >= threshold).cpu().numpy()
        iou = compute_iou(occ_iou_np, occ_iou_hat_np).mean()
        eval_dict['iou'] = iou

        # Estimate voxel iou
        if voxels_occ is not None:
            voxels_occ = voxels_occ.to(device)
            points_voxels = make_3d_grid((-0.5 + 1 / 64, ) * 3,
                                         (0.5 - 1 / 64, ) * 3, (32, ) * 3)
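            # 32^3 query points placed at the voxel centres of the unit cube.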
            points_voxels = points_voxels.expand(batch_size,
                                                 *points_voxels.size())
            points_voxels = points_voxels.to(device)
            with torch.no_grad():
                p_out = self.model(points_voxels,
                                   inputs,
                                   world_mat,
                                   camera_mat,
                                   sample=self.eval_sample,
                                   **kwargs)

            voxels_occ_np = (voxels_occ >= 0.5).cpu().numpy()
            occ_hat_np = (p_out.probs >= threshold).cpu().numpy()
            iou_voxels = compute_iou(voxels_occ_np, occ_hat_np).mean()

            eval_dict['iou_voxels'] = iou_voxels

        return eval_dict
Example #13
def compose_inputs(data,
                   mode='train',
                   device=None,
                   input_type='depth_pred',
                   use_gt_depth_map=False,
                   depth_map_mix=False,
                   with_img=False,
                   depth_pointcloud_transfer=None,
                   local=False):

    assert mode in ('train', 'val', 'test')
    raw_data = {}
    if input_type == 'depth_pred':
        gt_mask = data.get('inputs.mask').to(device).byte()
        raw_data['mask'] = gt_mask

        batch_size = gt_mask.size(0)
        if use_gt_depth_map:
            gt_depth_maps = data.get('inputs.depth').to(device)
            background_setting(gt_depth_maps, gt_mask)
            encoder_inputs = gt_depth_maps
            raw_data['depth'] = gt_depth_maps
        else:
            pr_depth_maps = data.get('inputs.depth_pred').to(device)
            background_setting(pr_depth_maps, gt_mask)
            raw_data['depth_pred'] = pr_depth_maps
            if depth_map_mix and mode == 'train':
                gt_depth_maps = data.get('inputs.depth').to(device)
                background_setting(gt_depth_maps, gt_mask)
                raw_data['depth'] = gt_depth_maps

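                # Randomly blend predicted and ground-truth depth maps with a
                # per-sample mixing coefficient (train-time augmentation).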
                alpha = torch.rand(batch_size, 1, 1, 1).to(device)
                pr_depth_maps = (pr_depth_maps * alpha +
                                 gt_depth_maps * (1.0 - alpha))
            encoder_inputs = pr_depth_maps

        if with_img:
            img = data.get('inputs').to(device)
            raw_data[None] = img
            encoder_inputs = {'img': img, 'depth': encoder_inputs}

        if local:
            camera_args = get_camera_args(data,
                                          'points.loc',
                                          'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']
            encoder_inputs = {
                None: encoder_inputs,
                'world_mat': Rt,
                'camera_mat': K,
            }
            raw_data['world_mat'] = Rt
            raw_data['camera_mat'] = K

        return encoder_inputs, raw_data
    elif input_type == 'depth_pointcloud':
        encoder_inputs = data.get('inputs.depth_pointcloud').to(device)

        if depth_pointcloud_transfer is not None:
            if depth_pointcloud_transfer in ('world', 'world_scale_model'):
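                # Swap the first two coordinate axes of the depth point cloud
                # before applying the world transform.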
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]
                world_mat = get_world_mat(data, transpose=None, device=device)
                raw_data['world_mat'] = world_mat

                R = world_mat[:, :, :3]
                # R's inverse is R^T
                encoder_inputs = transform_points(encoder_inputs,
                                                  R.transpose(1, 2))
                # or encoder_inputs = transform_points_back(encoder_inputs, R)

                if depth_pointcloud_transfer == 'world_scale_model':
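                    # Scale by the z component of the translation (roughly the
                    # camera distance).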
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            elif depth_pointcloud_transfer in ('view', 'view_scale_model'):
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]

                if depth_pointcloud_transfer == 'view_scale_model':
                    world_mat = get_world_mat(data,
                                              transpose=None,
                                              device=device)
                    raw_data['world_mat'] = world_mat
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            else:
                raise NotImplementedError

        raw_data['depth_pointcloud'] = encoder_inputs
        if local:
            #assert depth_pointcloud_transfer.startswith('world')
            encoder_inputs = {None: encoder_inputs}

        return encoder_inputs, raw_data
    else:
        raise NotImplementedError
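
A minimal usage sketch for compose_inputs (hypothetical call site; 'batch',
'device' and the chosen flag values are assumptions, not taken from the source):

# Hypothetical call site (sketch): build encoder inputs for a depth-prediction
# model with local camera features enabled.
encoder_inputs, raw_data = compose_inputs(batch,
                                          mode='train',
                                          device=device,
                                          input_type='depth_pred',
                                          use_gt_depth_map=False,
                                          depth_map_mix=True,
                                          with_img=False,
                                          local=True)
# With these flags, encoder_inputs is a dict of the form
# {None: depth_maps, 'world_mat': Rt, 'camera_mat': K}, and raw_data keeps the
# intermediate tensors (mask, depth maps, camera matrices) for later use.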