Example #1
0
    def loss(self, outputs, targets, weights, pad_mask, weight, log_sigma):
        """Gaussian negative log-likelihood of predicted images, weighted by
        the ground-truth matching distribution.

        :param outputs: model output; reads ``outputs.tree.bf.images`` and
            ``outputs.gt_match_dists``.
        :param targets: ground-truth images compared against the predictions.
        :param weights: per-element weighting applied to the loss.
        :param pad_mask: mask zeroing out padded entries (broadcast over dim 1).
        :param weight: scalar weight handed to ``PenaltyLoss``.
        :param log_sigma: log standard deviation of the Gaussian likelihood.
        :return: ``AttrDict`` holding the ``dense_img_rec`` loss.
        """
        preds = outputs.tree.bf.images
        match_dists = outputs.gt_match_dists

        # Squared-distance term of the likelihood.
        nll = batch_cdist(preds, targets, reduction='sum')

        # Per-element log-sigma, adjusted by the (hacked) weight factors.
        sigmas_log = log_sigma - WeightsHacker.hack_weights(
            torch.ones_like(nll)).log()
        dims = np.prod(preds.shape[2:])
        inv_var = torch.pow(torch.exp(-sigmas_log), 2)
        nll = 0.5 * nll * inv_var + dims * (sigmas_log + 0.5 * np.log(2 * np.pi))

        # Weigh by matching probability; pad masking is redundant since both
        # tree models already handle it, but kept for safety.
        nll = nll * (match_dists * pad_mask[:, None]) * weights

        losses = AttrDict()
        losses.dense_img_rec = PenaltyLoss(weight, breakdown=2)(
            nll, log_error_arr=True, reduction=[-1, -2])

        # if self._hp.top_bias > 0.0:
        #     losses.n_top_bias_nodes = PenaltyLoss(
        #         self._hp.supervise_match_weight)(1 - WeightsHacker.get_n_top_bias_nodes(targets, weights))

        return losses
Example #2
0
    def nll(self, estimates, targets, weights=1, log_error_arr=False):
        """Negative log-likelihood of *targets* under *estimates*.

        :param estimates: a distribution object scored by the ``NLL`` criterion.
        :param targets: ground-truth values evaluated under the distribution.
        :param weights: optional per-element weighting (default 1).
        :param log_error_arr: forwarded to the criterion; whether to log the
            per-element error array.
        :return: ``AttrDict`` holding the ``dense_img_rec`` loss.
        """
        # Average over every dimension except the batch dimension.
        reduce_dims = get_dim_inds(targets)[1:]
        nll_criterion = NLL(self._hp.dense_img_rec_weight, breakdown=1)

        losses = AttrDict()
        losses.dense_img_rec = nll_criterion(
            estimates, targets,
            weights=weights,
            reduction=reduce_dims,
            log_error_arr=log_error_arr)

        return losses