예제 #1
0
    def compute_loss(self, data):
        ''' Computes the training loss for one batch.

        The result is the batch-mean KL divergence between the latent
        posterior returned by ``self.model.infer_z`` and ``self.model.p0_z``,
        plus the batch-mean of the post-processed occupancy loss summed over
        its last axis.

        Args:
            data (dict): data dictionary; reads 'points', 'points.occ' and,
                if present, 'inputs'

        Returns:
            the scalar loss tensor
        '''
        dev = self.device
        model = self.model

        points = data.get('points').to(dev)
        occ_gt = data.get('points.occ').to(dev)
        # Without conditioning inputs, fall back to an empty (batch, 0) tensor.
        empty_inputs = torch.empty(points.size(0), 0)
        cond_inputs = data.get('inputs', empty_inputs).to(dev)

        model_kwargs = {}

        cond = model.encode_inputs(cond_inputs)
        posterior = model.infer_z(points, occ_gt, cond, **model_kwargs)
        latent = posterior.rsample()

        # KL term: summed over the last axis, then averaged over the batch.
        kl_term = dist.kl_divergence(posterior, model.p0_z).sum(dim=-1).mean()

        decoded = model.decode(points, latent, cond, **model_kwargs)
        # Read both views before computing the loss, as the original order did.
        logits, probs = decoded.logits, decoded.probs

        per_point = get_occ_loss(logits, occ_gt, self.loss_type)
        per_point = occ_loss_postprocess(
            per_point, occ_gt, probs,
            self.loss_tolerance_episolon,
            self.sign_lambda,
            self.threshold,
            self.surface_loss_weight)

        return kl_term + per_point.sum(-1).mean()
예제 #2
0
    def compute_loss(self, data):
        ''' Computes the training loss for one batch.

        The loss is the sum of three terms:
        - the batch-mean KL divergence between the posterior from
          ``self.model.infer_z`` and ``self.model.p0_z``;
        - ``0.001 *`` the feature-transform regulariser, added only when the
          encoder returned a tuple and is a ``PointNetEncoder`` or
          ``PointNetResEncoder``;
        - the batch-mean of the post-processed occupancy loss, summed over
          its last axis.

        Args:
            data (dict): data dictionary; 'points' and 'points.occ' are read
                directly, and the whole dict is also passed to
                ``compose_inputs``

        Returns:
            the scalar loss tensor
        '''
        device = self.device
        p = data.get('points').to(device)
        occ = data.get('points.occ').to(device)

        encoder_inputs, _ = compose_inputs(
            data,
            mode='train',
            device=device,
            input_type=self.input_type,
            use_gt_depth_map=self.use_gt_depth_map,
            depth_map_mix=self.depth_map_mix,
            with_img=self.with_img,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
            local=self.local)

        kwargs = {}
        c = self.model.encode(encoder_inputs, only_feature=False, p=p)

        # The encoder may return a tuple whose last element is the feature
        # transform. Split it off *before* inferring z, so that infer_z and
        # decode receive the same conditioning code. Previously infer_z was
        # handed the raw tuple, trans_feature included.
        has_trans_feature = isinstance(c, tuple)
        if has_trans_feature:
            trans_feature = c[-1]
            # With self.local, keep the first two elements; otherwise only
            # the first one.
            c = (c[0], c[1]) if self.local else c[0]

        q_z = self.model.infer_z(p, occ, c, **kwargs)
        z = q_z.rsample()

        # KL-divergence
        kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
        loss = kl.mean()

        # Additional loss for the feature transform (PointNet encoders only).
        if has_trans_feature and isinstance(
                self.model.encoder, (PointNetEncoder, PointNetResEncoder)):
            loss = loss + 0.001 * feature_transform_reguliarzer(trans_feature)

        # General points
        p_r = self.model.decode(p, z, c, **kwargs)
        logits = p_r.logits
        probs = p_r.probs

        # loss
        loss_i = get_occ_loss(logits, occ, self.loss_type)
        # loss strategies
        loss_i = occ_loss_postprocess(loss_i, occ, probs,
                                      self.loss_tolerance_episolon,
                                      self.sign_lambda, self.threshold,
                                      self.surface_loss_weight)

        loss = loss + loss_i.sum(-1).mean()

        return loss
예제 #3
0
    def compute_loss(self, data):
        ''' Computes the training loss for one batch, conditioned on depth
        maps predicted from the input images.

        Depth maps come from ``self.model.predict_depth_map``. When
        ``self.training_detach`` is set, they are computed without gradient
        tracking. They are then passed through ``background_setting`` together
        with the input mask. If ``self.depth_map_mix`` is set, they are blended
        with the ground-truth depth maps using a random per-sample weight. The
        loss is the batch-mean KL term plus the batch-mean of the
        post-processed occupancy loss, summed over its last axis.

        Args:
            data (dict): data dictionary; reads 'points', 'points.occ',
                'inputs', 'inputs.mask' and, when mixing, 'inputs.depth'

        Returns:
            the scalar loss tensor
        '''
        dev = self.device
        points = data.get('points').to(dev)
        n_batch = points.size(0)
        occ_gt = data.get('points.occ').to(dev)

        images = data.get('inputs').to(dev)
        mask = data.get('inputs.mask').to(dev).byte()

        if not self.training_detach:
            depth = self.model.predict_depth_map(images)
        else:
            # Detached: no gradient flows back into the depth predictor.
            with torch.no_grad():
                depth = self.model.predict_depth_map(images)

        # NOTE(review): the return value of background_setting is ignored,
        # so it presumably modifies the tensor in place — confirm.
        background_setting(depth, mask)

        if self.depth_map_mix:
            depth_gt = data.get('inputs.depth').to(dev)
            background_setting(depth_gt, mask)
            # One blend weight per sample, drawn uniformly from [0, 1).
            weight = torch.rand(n_batch, 1, 1, 1).to(dev)
            depth = depth * weight + depth_gt * (1.0 - weight)

        model_kwargs = {}

        cond = self.model.encode(depth)
        posterior = self.model.infer_z(points, occ_gt, cond, **model_kwargs)
        latent = posterior.rsample()

        # KL term: summed over the last axis, then averaged over the batch.
        kl_per_sample = dist.kl_divergence(posterior, self.model.p0_z)
        kl_term = kl_per_sample.sum(dim=-1).mean()

        decoded = self.model.decode(points, latent, cond, **model_kwargs)
        logits, probs = decoded.logits, decoded.probs

        per_point = get_occ_loss(logits, occ_gt, self.loss_type)
        per_point = occ_loss_postprocess(
            per_point, occ_gt, probs,
            self.loss_tolerance_episolon,
            self.sign_lambda,
            self.threshold,
            self.surface_loss_weight)

        return kl_term + per_point.sum(-1).mean()
예제 #4
0
    def compute_loss(self, data):
        ''' Computes the training loss for one batch.

        Occupancy targets are optionally binarised at 0.5 (``self.binary_occ``).
        The encoder yields three feature tensors ``f3, f2, f1``. When
        ``self.use_local_feature`` is set, the query points and camera
        ``Rt``/``K`` from ``get_camera_args`` are also passed to the encoder.
        Only ``f3`` conditions ``infer_z``; all three go to ``decode``. The
        loss is the batch-mean KL term plus the batch-mean of the
        post-processed occupancy loss, summed over its last axis.

        Args:
            data (dict): data dictionary; reads 'points', 'points.occ',
                optionally 'inputs', and is passed to ``get_camera_args``
                when local features are used

        Returns:
            the scalar loss tensor
        '''
        dev = self.device
        points = data.get('points').to(dev)

        raw_occ = data.get('points.occ')
        if self.binary_occ:
            # Threshold soft occupancy values into {0., 1.}.
            raw_occ = (raw_occ >= 0.5).float()
        occ_gt = raw_occ.to(dev)

        # Without conditioning inputs, fall back to an empty (batch, 0) tensor.
        empty_inputs = torch.empty(points.size(0), 0)
        cond_inputs = data.get('inputs', empty_inputs).to(dev)
        model_kwargs = {}

        encode_args = (cond_inputs,)
        if self.use_local_feature:
            cam = get_camera_args(data, 'points.loc', 'points.scale',
                                  device=dev)
            encode_args = (cond_inputs, points, cam['Rt'], cam['K'])
        f3, f2, f1 = self.model.encode_inputs(*encode_args)

        posterior = self.model.infer_z(points, occ_gt, f3, **model_kwargs)
        latent = posterior.rsample()

        # KL term: summed over the last axis, then averaged over the batch.
        kl_per_sample = dist.kl_divergence(posterior, self.model.p0_z)
        kl_term = kl_per_sample.sum(dim=-1).mean()

        decoded = self.model.decode(points, latent, f3, f2, f1,
                                    **model_kwargs)
        logits, probs = decoded.logits, decoded.probs

        per_point = get_occ_loss(logits, occ_gt, self.loss_type)
        per_point = occ_loss_postprocess(
            per_point, occ_gt, probs,
            self.loss_tolerance_episolon,
            self.sign_lambda,
            self.threshold,
            self.surface_loss_weight)

        return kl_term + per_point.sum(-1).mean()