Example No. 1
    def visualize(self, data):
        ''' Performs a visualization step for the data.

        Args:
            data (dict): data dictionary
        '''
        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
        )
        world_mat = None
        if (self.model.encoder_world_mat is not None or
                self.gt_pointcloud_transfer in ('view', 'view_scale_model')):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)
        batch_size = gt_pc.size(0)

        self.model.eval()
        with torch.no_grad():
            if self.model.encoder_world_mat is not None:
                out = self.model(encoder_inputs, world_mat=world_mat)
            else:
                out = self.model(encoder_inputs)

        if isinstance(out, tuple):
            out, _ = out

        for i in trange(batch_size):
            pc = gt_pc[i].cpu()
            vis.visualize_pointcloud(pc,
                                     out_file=os.path.join(
                                         self.vis_dir, '%03d_gt_pc.png' % i))

            pc = out[i].cpu()
            vis.visualize_pointcloud(pc,
                                     out_file=os.path.join(
                                         self.vis_dir, '%03d_pr_pc.png' % i))

            pc = encoder_inputs[i].cpu()
            vis.visualize_pointcloud(pc,
                                     out_file=os.path.join(
                                         self.vis_dir,
                                         '%03d_input_half_pc.png' % i))
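
The snippet relies on an external helper, vis.visualize_pointcloud. A minimal
sketch of such a helper, assuming it simply scatter-plots an (N, 3) point array
and writes it to a PNG (the real implementation may differ):

import matplotlib.pyplot as plt

def visualize_pointcloud(points, out_file):
    # points: (N, 3) array-like of xyz coordinates
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=1)
    ax.set_axis_off()
    fig.savefig(out_file, dpi=120)
    plt.close(fig)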
Example No. 2
    def compute_loss(self, data):
        ''' Computes the loss.

        Args:
            data (dict): data dictionary
        '''
        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
        )

        world_mat = None
        if (self.model.encoder_world_mat is not None or
                self.gt_pointcloud_transfer in ('view', 'view_scale_model')):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)

        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

        loss = 0
        if isinstance(out, tuple):
            out, trans_feat = out

            # PointNet-style encoders also return a feature-transform
            # matrix; regularize it towards orthogonality
            if isinstance(self.model.encoder,
                          (PointNetEncoder, PointNetResEncoder)):
                loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

        # chamfer distance loss
        if self.loss_type == 'cd':
            dist1, dist2 = cd.chamfer_distance(out, gt_pc)
            loss = loss + (dist1.mean(1) + dist2.mean(1)).mean() / 2.
        else:
            # earth mover's distance loss, normalized by the number of
            # predicted points
            out_pts_count = out.size(1)
            loss = loss + (emd.earth_mover_distance(
                out, gt_pc, transpose=False) / out_pts_count).mean()

        # view penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # project the predicted points with the world & camera matrices
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                t = world_mat[:, :, 3:]  # (B, 3, 1) translation column
                out_pts = out + t.transpose(1, 2)
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:]
                out_pts = out * t[:, 2:, :]
                out_pts = out_pts + t.transpose(1, 2)
            else:
                raise NotImplementedError

            # project_to_camera is expected to return normalized image
            # coordinates in [-1, 1], as required by F.grid_sample
            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            # penalize points that project outside the object mask
            loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                    inplace=True).mean()
            loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

            if self.view_penalty == 'mask_flow_and_depth':
                # depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale  # absolute depth from view
                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)

                # penalize predicted depths that exceed the observed depth
                # at the projected pixel by more than depth_test_eps
                loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                         corresponding_z,
                                         inplace=True).mean()
                loss = loss + self.loss_depth_test_ratio * loss_depth_test

        return loss
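
For reference, cd.chamfer_distance is a compiled extension; a naive
pure-PyTorch stand-in matching the interface assumed above (per-point squared
nearest-neighbor distances in both directions) could look like this sketch:

import torch

def chamfer_distance_naive(x, y):
    # x: (B, N, 3), y: (B, M, 3) point clouds
    d = torch.cdist(x, y)              # (B, N, M) pairwise distances
    dist1 = d.min(dim=2).values ** 2   # each point in x -> nearest in y
    dist2 = d.min(dim=1).values ** 2   # each point in y -> nearest in x
    return dist1, dist2

With this definition, (dist1.mean(1) + dist2.mean(1)).mean() / 2. reproduces
the symmetric Chamfer loss computed above.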
Example No. 3
    def eval_step(self, data):
        ''' Performs an evaluation step.

        Args:
            data (dict): data dictionary
        '''
        self.model.eval()

        device = self.device

        encoder_inputs, raw_data = compose_inputs(
            data,
            mode='train',
            device=self.device,
            input_type=self.input_type,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer)
        world_mat = None
        if (self.model.encoder_world_mat is not None or
                self.gt_pointcloud_transfer in ('view', 'view_scale_model')):
            if 'world_mat' in raw_data:
                world_mat = raw_data['world_mat']
            else:
                world_mat = get_world_mat(data, device=device)
        gt_pc = compose_pointcloud(data,
                                   device,
                                   self.gt_pointcloud_transfer,
                                   world_mat=world_mat)
        batch_size = gt_pc.size(0)

        with torch.no_grad():
            if self.model.encoder_world_mat is not None:
                out = self.model(encoder_inputs, world_mat=world_mat)
            else:
                out = self.model(encoder_inputs)

            if isinstance(out, tuple):
                out, _ = out

            eval_dict = {}
            if batch_size == 1:
                pointcloud_hat = out.cpu().squeeze(0).numpy()
                pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()

                eval_dict = self.mesh_evaluator.eval_pointcloud(
                    pointcloud_hat, pointcloud_gt)

            # chamfer distance loss
            if self.loss_type == 'cd':
                dist1, dist2 = cd.chamfer_distance(out, gt_pc)
                loss = (dist1.mean(1) + dist2.mean(1)) / 2.
            else:
                loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

            if self.gt_pointcloud_transfer in ('world_scale_model',
                                               'view_scale_model', 'view'):
                # undo the model-scale normalization; view the scale as
                # (B,) so it broadcasts against the per-batch loss vector
                pointcloud_scale = data.get('pointcloud.scale').to(
                    device).view(batch_size)
                loss = loss / (pointcloud_scale**2)
                if self.gt_pointcloud_transfer == 'view':
                    if world_mat is None:
                        world_mat = get_world_mat(data, device=device)
                    t_scale = world_mat[:, 2, 3].view(batch_size)
                    loss = loss * (t_scale**2)

            if self.loss_type == 'cd':
                loss = loss.mean()
                eval_dict['chamfer'] = loss.item()
            else:
                out_pts_count = out.size(1)
                loss = (loss / out_pts_count).mean()
                eval_dict['emd'] = loss.item()

            # view penalty loss
            if self.view_penalty:
                gt_mask_flow = data.get('inputs.mask_flow').to(
                    device)  # B * 1 * H * W
                if world_mat is None:
                    world_mat = get_world_mat(data, device=device)
                camera_mat = get_camera_mat(data, device=device)

                # project the predicted points with the world & camera matrices
                if self.gt_pointcloud_transfer == 'world_scale_model':
                    out_pts = transform_points(out, world_mat)
                elif self.gt_pointcloud_transfer == 'view_scale_model':
                    t = world_mat[:, :, 3:]  # (B, 3, 1) translation column
                    out_pts = out + t.transpose(1, 2)
                elif self.gt_pointcloud_transfer == 'view':
                    t = world_mat[:, :, 3:]
                    out_pts = out * t[:, 2:, :]
                    out_pts = out_pts + t.transpose(1, 2)
                else:
                    raise NotImplementedError

                # project_to_camera is expected to return normalized image
                # coordinates in [-1, 1], as required by F.grid_sample
                out_pts_img = project_to_camera(out_pts, camera_mat)
                out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

                out_mask_flow = F.grid_sample(gt_mask_flow,
                                              out_pts_img)  # B * 1 * 1 * n_pts
                # penalize points that project outside the object mask
                loss_mask_flow = F.relu(1. - self.mask_flow_eps -
                                        out_mask_flow,
                                        inplace=True).mean()
                loss = self.loss_mask_flow_ratio * loss_mask_flow
                eval_dict['loss_mask_flow'] = loss.item()

                if self.view_penalty == 'mask_flow_and_depth':
                    # depth test loss
                    t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1,
                                                      1)
                    gt_mask = data.get('inputs.mask').byte().to(device)
                    depth_pred = data.get('inputs.depth_pred').to(
                        device) * t_scale

                    background_setting(depth_pred, gt_mask)
                    depth_z = out_pts[:, :, 2:].transpose(1, 2)
                    corresponding_z = F.grid_sample(
                        depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                    corresponding_z = corresponding_z.squeeze(1)

                    # penalize predicted depths that exceed the observed
                    # depth at the projected pixel by more than depth_test_eps
                    loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                             corresponding_z,
                                             inplace=True).mean()
                    loss = self.loss_depth_test_ratio * loss_depth_test
                    eval_dict['loss_depth_test'] = loss.item()

        return eval_dict
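
A typical way to consume eval_step (the trainer and val_loader names below are
illustrative, not part of the snippet) is to average the returned metrics over
a validation set:

from collections import defaultdict

def evaluate(trainer, val_loader):
    # accumulate and average the per-batch metrics from eval_step
    totals, n_batches = defaultdict(float), 0
    for batch in val_loader:
        for k, v in trainer.eval_step(batch).items():
            totals[k] += v
        n_batches += 1
    return {k: v / n_batches for k, v in totals.items()}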
Example No. 4
def compose_inputs(data,
                   mode='train',
                   device=None,
                   input_type='depth_pred',
                   use_gt_depth_map=False,
                   depth_map_mix=False,
                   with_img=False,
                   depth_pointcloud_transfer=None,
                   local=False):
    ''' Composes the encoder inputs from a data dictionary.

    Args:
        data (dict): data dictionary
        mode (str): one of 'train', 'val', 'test'
        device: pytorch device
        input_type (str): 'depth_pred' or 'depth_pointcloud'
    '''
    assert mode in ('train', 'val', 'test')
    raw_data = {}
    if input_type == 'depth_pred':
        gt_mask = data.get('inputs.mask').to(device).byte()
        raw_data['mask'] = gt_mask

        batch_size = gt_mask.size(0)
        if use_gt_depth_map:
            gt_depth_maps = data.get('inputs.depth').to(device)
            background_setting(gt_depth_maps, gt_mask)
            encoder_inputs = gt_depth_maps
            raw_data['depth'] = gt_depth_maps
        else:
            pr_depth_maps = data.get('inputs.depth_pred').to(device)
            background_setting(pr_depth_maps, gt_mask)
            raw_data['depth_pred'] = pr_depth_maps
            if depth_map_mix and mode == 'train':
                gt_depth_maps = data.get('inputs.depth').to(device)
                background_setting(gt_depth_maps, gt_mask)
                raw_data['depth'] = gt_depth_maps

                # randomly blend predicted and ground-truth depth maps
                alpha = torch.rand(batch_size, 1, 1, 1).to(device)
                pr_depth_maps = (pr_depth_maps * alpha +
                                 gt_depth_maps * (1.0 - alpha))
            encoder_inputs = pr_depth_maps

        if with_img:
            img = data.get('inputs').to(device)
            # by this codebase's convention, the bare input is stored
            # under the None key
            raw_data[None] = img
            encoder_inputs = {'img': img, 'depth': encoder_inputs}

        if local:
            camera_args = get_camera_args(data,
                                          'points.loc',
                                          'points.scale',
                                          device=device)
            Rt = camera_args['Rt']
            K = camera_args['K']
            encoder_inputs = {
                None: encoder_inputs,
                'world_mat': Rt,
                'camera_mat': K,
            }
            raw_data['world_mat'] = Rt
            raw_data['camera_mat'] = K

        return encoder_inputs, raw_data
    elif input_type == 'depth_pointcloud':
        encoder_inputs = data.get('inputs.depth_pointcloud').to(device)

        if depth_pointcloud_transfer is not None:
            if depth_pointcloud_transfer in ('world', 'world_scale_model'):
                # swap the x/y axes of the depth point cloud
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]
                world_mat = get_world_mat(data, transpose=None, device=device)
                raw_data['world_mat'] = world_mat

                R = world_mat[:, :, :3]
                # R's inverse is R^T
                encoder_inputs = transform_points(encoder_inputs,
                                                  R.transpose(1, 2))
                # or encoder_inputs = transform_points_back(encoder_inputs, R)

                if depth_pointcloud_transfer == 'world_scale_model':
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            elif depth_pointcloud_transfer in ('view', 'view_scale_model'):
                encoder_inputs = encoder_inputs[:, :, [1, 0, 2]]

                if depth_pointcloud_transfer == 'view_scale_model':
                    world_mat = get_world_mat(data,
                                              transpose=None,
                                              device=device)
                    raw_data['world_mat'] = world_mat
                    t = world_mat[:, :, 3:]
                    encoder_inputs = encoder_inputs * t[:, 2:, :]
            else:
                raise NotImplementedError

        raw_data['depth_pointcloud'] = encoder_inputs
        if local:
            # assert depth_pointcloud_transfer.startswith('world')
            encoder_inputs = {None: encoder_inputs}

        return encoder_inputs, raw_data
    else:
        raise NotImplementedError
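
compose_inputs leans on transform_points to move point clouds between
coordinate frames. A sketch of the behavior assumed here (batched rotation of
(B, N, 3) points by a (B, 3, 3) matrix, optionally with the translation column
of a (B, 3, 4) [R | t] world matrix):

import torch

def transform_points_sketch(points, mat):
    # points: (B, N, 3); mat: (B, 3, 3) rotation or (B, 3, 4) [R | t]
    R = mat[:, :, :3]
    out = points @ R.transpose(1, 2)               # p' = R p, batched
    if mat.size(2) == 4:
        out = out + mat[:, :, 3:].transpose(1, 2)  # add translation
    return out

Under this reading, transform_points(encoder_inputs, R.transpose(1, 2)) in the
snippet applies the inverse rotation, since a rotation's inverse is its
transpose.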