Example No. 1
def depth_to_L(pr_depth_map, gt_mask):
    # Normalize foreground depth to [0, 1] for saving as a grayscale
    # ('L' mode) image. Note that background_setting fills the background
    # pixels in place with the foreground maximum, which maps them to 1.0.
    pr_depth_map_max = torch.max(pr_depth_map[gt_mask])
    pr_depth_map_min = torch.min(pr_depth_map[gt_mask])
    background_setting(pr_depth_map, gt_mask, pr_depth_map_max)
    pr_depth_map = (pr_depth_map - pr_depth_map_min) / (pr_depth_map_max -
                                                        pr_depth_map_min)
    return pr_depth_map
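
The helper background_setting is not shown on this page; it fills the background (mask == 0) pixels of a depth map in place. A minimal sketch of it, plus a toy call to depth_to_L; the stub and all tensor sizes are assumptions:

import torch

def background_setting(depth_map, mask, value=0.0):
    # Hypothetical stand-in: overwrite background pixels in place.
    depth_map[~mask.bool()] = value

# Toy 1 x 8 x 8 depth map with a square foreground region.
depth = torch.rand(1, 8, 8) + 2.0
mask = torch.zeros(1, 8, 8, dtype=torch.uint8)
mask[:, 2:6, 2:6] = 1

gray = depth_to_L(depth.clone(), mask)  # clone: the input is modified in place
assert gray[mask.bool()].min() == 0.0 and gray[mask.bool()].max() == 1.0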
Example No. 2
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device
        p = data.get('points').to(device)
        batch_size = p.size(0)
        occ = data.get('points.occ').to(device)

        inputs = data.get('inputs').to(device)
        gt_mask = data.get('inputs.mask').to(device).byte()

        if self.training_detach:
            with torch.no_grad():
                pr_depth_maps = self.model.predict_depth_map(inputs)
        else:
            pr_depth_maps = self.model.predict_depth_map(inputs)

        background_setting(pr_depth_maps, gt_mask)
        if self.depth_map_mix:
            gt_depth_maps = data.get('inputs.depth').to(device)
            background_setting(gt_depth_maps, gt_mask)
            alpha = torch.rand(batch_size, 1, 1, 1).to(device)
            pr_depth_maps = pr_depth_maps * alpha + gt_depth_maps * (1.0 -
                                                                     alpha)

        kwargs = {}
        c = self.model.encode(pr_depth_maps)
        q_z = self.model.infer_z(p, occ, c, **kwargs)
        z = q_z.rsample()

        # KL-divergence
        kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
        loss = kl.mean()

        # General points
        p_r = self.model.decode(p, z, c, **kwargs)
        logits = p_r.logits
        probs = p_r.probs

        # loss
        loss_i = get_occ_loss(logits, occ, self.loss_type)
        # loss strategies
        loss_i = occ_loss_postprocess(loss_i, occ, probs,
                                      self.loss_tolerance_episolon,
                                      self.sign_lambda, self.threshold,
                                      self.surface_loss_weight)

        loss = loss + loss_i.sum(-1).mean()

        return loss
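
The KL term relies on torch.distributions.kl_divergence returning one value per latent dimension; .sum(dim=-1) collapses the latent dimensions of each code and .mean() averages over the batch. A self-contained illustration with diagonal Gaussians (the shapes are arbitrary):

import torch
import torch.distributions as dist

batch_size, z_dim = 4, 16
q_z = dist.Normal(torch.randn(batch_size, z_dim),
                  torch.rand(batch_size, z_dim) + 0.1)
p0_z = dist.Normal(torch.zeros(z_dim), torch.ones(z_dim))  # standard prior

kl = dist.kl_divergence(q_z, p0_z)  # (batch_size, z_dim), one term per dim
loss_kl = kl.sum(dim=-1).mean()     # sum over z_dim, average over the batch
z = q_z.rsample()                   # reparameterized sample, keeps gradients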
Example No. 3
    def forward(self, data, device):
        gt_depth_maps = data.get('inputs.depth').to(device)
        gt_mask = data.get('inputs.mask').to(device).byte()
        background_setting(gt_depth_maps, gt_mask)
        encoder_inputs = gt_depth_maps

        if self.with_img:
            img = data.get('inputs').to(device)
            encoder_inputs = {'img': img, 'depth': encoder_inputs}

        out = self.features(encoder_inputs)
        out = self.pred_fc(out)
        return out
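
The features module above has to accept either a bare depth tensor or the {'img': ..., 'depth': ...} dict built when with_img is set. A hypothetical module illustrating that dual-input convention (names and layer sizes are made up):

import torch
import torch.nn as nn

class DualFeatures(nn.Module):
    def __init__(self, c_dim=128):
        super().__init__()
        self.depth_net = nn.Sequential(nn.Conv2d(1, c_dim, 3, padding=1),
                                       nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.img_net = nn.Sequential(nn.Conv2d(3, c_dim, 3, padding=1),
                                     nn.AdaptiveAvgPool2d(1), nn.Flatten())

    def forward(self, x):
        # Fuse image and depth features when both are given.
        if isinstance(x, dict):
            return self.depth_net(x['depth']) + self.img_net(x['img'])
        return self.depth_net(x)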
Example No. 4
    def visualize(self, data):
        ''' Performs a visualization step for the data.

        Args:
            data (dict): data dictionary
        '''
        device = self.device

        batch_size = data['points'].size(0)
        inputs = data.get('inputs').to(device)
        #gt_depth_maps = data.get('inputs.depth').to(device)
        gt_mask = data.get('inputs.mask').to(device).byte()

        shape = (32, 32, 32)
        p = make_3d_grid([-0.5] * 3, [0.5] * 3, shape).to(device)
        p = p.expand(batch_size, *p.size())

        kwargs = {}
        with torch.no_grad():
            pr_depth_maps = self.model.predict_depth_map(inputs)
            background_setting(pr_depth_maps, gt_mask)
            p_r = self.model.forward_halfway(p,
                                             pr_depth_maps,
                                             sample=self.eval_sample,
                                             **kwargs)

        occ_hat = p_r.probs.view(batch_size, *shape)
        voxels_out = (occ_hat >= self.threshold).cpu().numpy()

        for i in trange(batch_size):
            input_img_path = os.path.join(self.vis_dir, '%03d_in.png' % i)
            vis.visualize_data(inputs[i].cpu(), 'img', input_img_path)
            vis.visualize_voxels(voxels_out[i],
                                 os.path.join(self.vis_dir, '%03d.png' % i))

            depth_map_path = os.path.join(self.vis_dir,
                                          '%03d_pr_depth.png' % i)
            depth_map = pr_depth_maps[i].cpu()
            depth_map = depth_to_L(depth_map, gt_mask[i].cpu())
            vis.visualize_data(depth_map, 'img', depth_map_path)
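
make_3d_grid builds the flat list of query points covering the unit cube. Its implementation is not included on this page; one plausible version, consistent with the call make_3d_grid([-0.5] * 3, [0.5] * 3, (32, 32, 32)) above:

import torch

def make_3d_grid(bb_min, bb_max, shape):
    # Returns a (shape[0] * shape[1] * shape[2], 3) tensor of coordinates.
    size = shape[0] * shape[1] * shape[2]
    pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
    pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
    pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])
    pxs = pxs.view(-1, 1, 1).expand(*shape).reshape(size)
    pys = pys.view(1, -1, 1).expand(*shape).reshape(size)
    pzs = pzs.view(1, 1, -1).expand(*shape).reshape(size)
    return torch.stack([pxs, pys, pzs], dim=1)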
Example No. 5
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
        )

        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)

        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

        loss = 0
        if isinstance(out, tuple):
            out, trans_feat = out

            if isinstance(self.model.encoder,
                          (PointNetEncoder, PointNetResEncoder)):
                loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

        # chamfer distance loss
        if self.loss_type == 'cd':
            dist1, dist2 = cd.chamfer_distance(out, gt_pc)
            loss = loss + (dist1.mean(1) + dist2.mean(1)).mean() / 2.
        else:
            out_pts_count = out.size(1)
            loss = loss + (emd.earth_mover_distance(
                out, gt_pc, transpose=False) / out_pts_count).mean()

        # view penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # projection uses world mat & camera mat
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                t = world_mat[:, :, 3:]  # B * 3 * 1
                out_pts = out + t.transpose(1, 2)
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:]
                out_pts = out * t[:, 2:, :]  # rescale by t_z
                out_pts = out_pts + t.transpose(1, 2)
            else:
                raise NotImplementedError

            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                    inplace=True).mean()
            loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

            if self.view_penalty == 'mask_flow_and_depth':
                # depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale  # absolute depth from view
                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)

                # penalize points more than depth_test_eps behind the
                # predicted depth at their projected pixel
                loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                         corresponding_z,
                                         inplace=True).mean()
                loss = loss + self.loss_depth_test_ratio * loss_depth_test

        return loss
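
cd.chamfer_distance is typically a compiled extension that returns, for every point, the squared distance to its nearest neighbor in the other cloud. A pure-PyTorch sketch under that assumed convention, usable to sanity-check the reduction above:

import torch

def chamfer_distance_naive(a, b):
    # a: (B, N, 3), b: (B, M, 3) point clouds.
    d = torch.cdist(a, b)             # (B, N, M) pairwise distances
    dist1 = d.min(dim=2).values ** 2  # (B, N): a -> nearest point in b
    dist2 = d.min(dim=1).values ** 2  # (B, M): b -> nearest point in a
    return dist1, dist2

out, gt_pc = torch.rand(2, 1024, 3), torch.rand(2, 2048, 3)
dist1, dist2 = chamfer_distance_naive(out, gt_pc)
loss = (dist1.mean(1) + dist2.mean(1)).mean() / 2.  # same reduction as above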
Example No. 6
    def eval_step(self, data):
        ''' Performs an evaluation step.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer)
        world_mat = None
        if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)
        batch_size = gt_pc.size(0)

        with torch.no_grad():
            if self.model.encoder_world_mat is not None:
                out = self.model(encoder_inputs, world_mat=world_mat)
            else:
                out = self.model(encoder_inputs)

            if isinstance(out, tuple):
                out, trans_feat = out

            eval_dict = {}
            if batch_size == 1:
                pointcloud_hat = out.cpu().squeeze(0).numpy()
                pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()

                eval_dict = self.mesh_evaluator.eval_pointcloud(
                    pointcloud_hat, pointcloud_gt)

            # chamfer distance loss
            if self.loss_type == 'cd':
                dist1, dist2 = cd.chamfer_distance(out, gt_pc)
                loss = (dist1.mean(1) + dist2.mean(1)) / 2.
            else:
                loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

            if self.gt_pointcloud_transfer in ('world_scale_model',
                                               'view_scale_model', 'view'):
                pointcloud_scale = data.get('pointcloud.scale').to(
                    device).view(batch_size, 1, 1)
                loss = loss / (pointcloud_scale**2)
                if self.gt_pointcloud_transfer == 'view':
                    if world_mat is None:
                        world_mat = get_world_mat(data, device=device)
                    t_scale = world_mat[:, 2:, 3:]
                    loss = loss * (t_scale**2)

            if self.loss_type == 'cd':
                loss = loss.mean()
                eval_dict['chamfer'] = loss.item()
            else:
                out_pts_count = out.size(1)
                loss = (loss / out_pts_count).mean()
                eval_dict['emd'] = loss.item()

            # view penalty loss
            if self.view_penalty:
                gt_mask_flow = data.get('inputs.mask_flow').to(
                    device)  # B * 1 * H * W
                if world_mat is None:
                    world_mat = get_world_mat(data, device=device)
                camera_mat = get_camera_mat(data, device=device)

                # projection uses world mat & camera mat
                if self.gt_pointcloud_transfer == 'world_scale_model':
                    out_pts = transform_points(out, world_mat)
                elif self.gt_pointcloud_transfer == 'view_scale_model':
                    t = world_mat[:, :, 3:]  # B * 3 * 1
                    out_pts = out + t.transpose(1, 2)
                elif self.gt_pointcloud_transfer == 'view':
                    t = world_mat[:, :, 3:]
                    out_pts = out * t[:, 2:, :]  # rescale by t_z
                    out_pts = out_pts + t.transpose(1, 2)
                else:
                    raise NotImplementedError

                out_pts_img = project_to_camera(out_pts, camera_mat)
                out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

                out_mask_flow = F.grid_sample(gt_mask_flow,
                                              out_pts_img)  # B * 1 * 1 * n_pts
                loss_mask_flow = F.relu(1. - self.mask_flow_eps -
                                        out_mask_flow,
                                        inplace=True).mean()
                loss = self.loss_mask_flow_ratio * loss_mask_flow
                eval_dict['loss_mask_flow'] = loss.item()

                if self.view_penalty == 'mask_flow_and_depth':
                    # depth test loss
                    t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1,
                                                      1)
                    gt_mask = data.get('inputs.mask').byte().to(device)
                    depth_pred = data.get('inputs.depth_pred').to(
                        device) * t_scale

                    background_setting(depth_pred, gt_mask)
                    depth_z = out_pts[:, :, 2:].transpose(1, 2)
                    corresponding_z = F.grid_sample(
                        depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                    corresponding_z = corresponding_z.squeeze(1)

                    # penalize points more than depth_test_eps behind the
                    # predicted depth at their projected pixel
                    loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                             corresponding_z,
                                             inplace=True).mean()
                    loss = self.loss_depth_test_ratio * loss_depth_test
                    eval_dict['loss_depth_test'] = loss.item()

        return eval_dict
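
Both view-penalty terms hinge on the F.grid_sample shape convention: a (B, 1, n_pts, 2) grid of normalized [-1, 1] coordinates samples the (B, 1, H, W) mask into (B, 1, 1, n_pts) per-point values. A standalone shape check (sizes arbitrary; project_to_camera is assumed to emit normalized coordinates):

import torch
import torch.nn.functional as F

B, n_pts, H, W = 2, 1024, 224, 224
gt_mask_flow = torch.rand(B, 1, H, W)             # soft foreground mask
out_pts_img = torch.rand(B, 1, n_pts, 2) * 2 - 1  # projected points in [-1, 1]

out_mask_flow = F.grid_sample(gt_mask_flow, out_pts_img,
                              align_corners=False)  # (B, 1, 1, n_pts)
mask_flow_eps = 0.1  # placeholder for self.mask_flow_eps
loss_mask_flow = F.relu(1. - mask_flow_eps - out_mask_flow).mean()
print(out_mask_flow.shape, loss_mask_flow.item())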
Example No. 7
def compose_inputs(data,
                   mode='train',
                   device=None,
                   input_type='depth_pred',
                   use_gt_depth_map=False,
                   depth_map_mix=False,
                   with_img=False,
                   depth_pointcloud_transfer=None,
                   local=False):

    assert mode in ('train', 'val', 'test')
    raw_data = {}
    if input_type == 'depth_pred':
        gt_mask = data.get('inputs.mask').to(device).byte()
        raw_data['mask'] = gt_mask

        batch_size = gt_mask.size(0)
        if use_gt_depth_map:
            gt_depth_maps = data.get('inputs.depth').to(device)
            background_setting(gt_depth_maps, gt_mask)
            encoder_inputs = gt_depth_maps
            raw_data['depth'] = gt_depth_maps
        else:
            pr_depth_maps = data.get('inputs.depth_pred').to(device)
            background_setting(pr_depth_maps, gt_mask)
            raw_data['depth_pred'] = pr_depth_maps
            if depth_map_mix and mode == 'train':
                gt_depth_maps = data.get('inputs.depth').to(device)
                background_setting(gt_depth_maps, gt_mask)
                raw_data['depth'] = gt_depth_maps

                alpha = torch.rand(batch_size, 1, 1, 1).to(device)
                pr_depth_maps = pr_depth_maps * alpha + gt_depth_maps * (1.0 -
                                                                         alpha)
            encoder_inputs = pr_depth_maps

        if with_img:
            img = data.get('inputs').to(device)
            raw_data[None] = img
            encoder_inputs = {'img': img, 'depth': encoder_inputs}

        if local:
            camera_args = get_camera_args(data,
                                          'points.loc',
                                          'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']
            encoder_inputs = {
                None: encoder_inputs,
                'world_mat': Rt,
                'camera_mat': K,
            }
            raw_data['world_mat'] = Rt
            raw_data['camera_mat'] = K

        return encoder_inputs, raw_data
    elif input_type == 'depth_pointcloud':
        encoder_inputs = data.get('inputs.depth_pointcloud').to(device)

        if depth_pointcloud_transfer is not None:
            if depth_pointcloud_transfer in ('world', 'world_scale_model'):
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]
                world_mat = get_world_mat(data, transpose=None, device=device)
                raw_data['world_mat'] = world_mat

                R = world_mat[:, :, :3]
                # R's inverse is R^T
                encoder_inputs = transform_points(encoder_inputs,
                                                  R.transpose(1, 2))
                # or encoder_inputs = transform_points_back(encoder_inputs, R)

                if depth_pointcloud_transfer == 'world_scale_model':
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            elif depth_pointcloud_transfer in ('view', 'view_scale_model'):
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]

                if depth_pointcloud_transfer == 'view_scale_model':
                    world_mat = get_world_mat(data,
                                              transpose=None,
                                              device=device)
                    raw_data['world_mat'] = world_mat
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            else:
                raise NotImplementedError

        raw_data['depth_pointcloud'] = encoder_inputs
        if local:
            #assert depth_pointcloud_transfer.startswith('world')
            encoder_inputs = {None: encoder_inputs}

        return encoder_inputs, raw_data
    else:
        raise NotImplementedError
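
transform_points and transform_points_back come from the repo's shared utilities and are not shown here. A sketch consistent with both call sites above, where world_mat is the (B, 3, 4) extrinsic [R | t] and rotating by R.transpose(1, 2) applies the inverse rotation:

import torch

def transform_points(points, transform):
    # points: (B, N, 3); transform: (B, 3, 4) world matrix [R | t],
    # or a bare (B, 3, 3) rotation matrix as in the call above.
    if transform.size(2) == 4:
        R, t = transform[:, :, :3], transform[:, :, 3:]
        return points @ R.transpose(1, 2) + t.transpose(1, 2)
    return points @ transform.transpose(1, 2)

def transform_points_back(points, transform):
    # Inverse map, assuming R is a pure rotation (R^-1 == R^T).
    R, t = transform[:, :, :3], transform[:, :, 3:]
    return (points - t.transpose(1, 2)) @ R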
Example No. 8
    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh.

        Args:
            data (dict): data dictionary
            return_stats (bool): whether stats should be returned
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}

        if self.input_type in ('depth_pred', 'depth_pointcloud'):
            encoder_inputs, _ = compose_inputs(
                data,
                mode='test',
                device=self.device,
                input_type=self.input_type,
                use_gt_depth_map=self.use_gt_depth_map,
                depth_map_mix=False,
                with_img=self.with_img,
                depth_pointcloud_transfer=self.depth_pointcloud_transfer,
                local=self.local)
        else:
            # Preprocess if required
            inputs = data.get('inputs').to(device)
            gt_mask = data.get('inputs.mask').to(device).byte()
            if self.preprocessor is not None:
                t0 = time.time()
                with torch.no_grad():
                    inputs = self.preprocessor(inputs)
                stats_dict['time (preprocess)'] = time.time() - t0

            t0 = time.time()
            with torch.no_grad():
                depth = self.model.predict_depth_map(inputs)
            stats_dict['time (predict depth map)'] = time.time() - t0
            background_setting(depth, gt_mask)
            encoder_inputs = depth

        kwargs = {}
        # Encode inputs
        t0 = time.time()
        with torch.no_grad():
            if self.local:
                c = self.model.encoder.forward_local_first_step(encoder_inputs)
            else:
                c = self.model.encode(encoder_inputs)
        stats_dict['time (encode)'] = time.time() - t0

        z = self.model.get_z_from_prior((1, ), sample=self.sample).to(device)

        if self.local:
            mesh = self.generate_from_latent(z,
                                             c,
                                             data=encoder_inputs,
                                             stats_dict=stats_dict,
                                             **kwargs)
        else:
            mesh = self.generate_from_latent(z,
                                             c,
                                             stats_dict=stats_dict,
                                             **kwargs)

        if return_stats:
            return mesh, stats_dict
        else:
            return mesh
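
A hypothetical call site; generator and data come from the surrounding pipeline, and the mesh is assumed to be a trimesh.Trimesh as in Occupancy Networks-style code:

mesh, stats_dict = generator.generate_mesh(data)
for name, seconds in stats_dict.items():
    print('%s: %.3fs' % (name, seconds))
mesh.export('mesh.off')  # assuming a trimesh.Trimesh return value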