def loss(self, outputs, targets, weights, pad_mask, weight, log_sigma):
    """Gaussian negative log-likelihood of predicted images, weighted by matching.

    The distance returned by ``batch_cdist`` is turned into a Gaussian NLL of
    the form ``0.5 * d / sigma**2 + n * (log(sigma) + 0.5 * log(2*pi))``.
    Here ``n`` is the number of elements per image (all prediction dims
    after the first two). The result is then scaled by the ground-truth
    matching distribution, the padding mask and ``weights``.

    Args:
        outputs: model output. Reads ``outputs.tree.bf.images`` (predictions)
            and ``outputs.gt_match_dists`` (matching weights).
        targets: target images compared against the predictions.
        weights: additional multiplicative weights on the per-entry loss.
        pad_mask: padding mask. It is indexed as ``pad_mask[:, None]`` so that
            it broadcasts against the matching weights.
        weight: value forwarded to ``PenaltyLoss`` (presumably the loss
            weight; confirm against ``PenaltyLoss``).
        log_sigma: log standard deviation of the Gaussian likelihood.

    Returns:
        AttrDict with a single entry ``dense_img_rec``.
    """
    pred_images = outputs.tree.bf.images
    match_dists = outputs.gt_match_dists

    # Presumably pairwise (prediction x target) distances summed over the
    # element dims; the NLL below assumes a sum of squared errors. TODO confirm.
    dist = batch_cdist(pred_images, targets, reduction='sum')

    # Subtracting log(hacked weights) divides sigma by those weights.
    hacked_scale = WeightsHacker.hack_weights(torch.ones_like(dist))
    log_std = log_sigma - hacked_scale.log()

    # Number of elements per image; this scales the normalising constant.
    num_elems = np.prod(pred_images.shape[2:])
    precision = torch.pow(torch.exp(-log_std), 2)  # 1 / sigma**2
    log_normaliser = num_elems * (log_std + 0.5 * np.log(2 * np.pi))
    gaussian_nll = 0.5 * dist * precision + log_normaliser

    # NOTE(review): an earlier comment says the pad-mask product is now
    # redundant because both tree models already apply it; confirm
    # before removing.
    match_weights = match_dists * pad_mask[:, None]
    weighted_nll = gaussian_nll * match_weights * weights

    # Reduce over the last two dims.
    penalty = PenaltyLoss(weight, breakdown=2)
    result = AttrDict()
    result.dense_img_rec = penalty(weighted_nll, log_error_arr=True, reduction=[-1, -2])
    return result
def nll(self, estimates, targets, weights=1, log_error_arr=False):
    """Negative log-likelihood of ``targets`` under the distribution ``estimates``.

    :param estimates: a distribution object, scored by the ``NLL`` criterion
    :param targets: observed values; the NLL is reduced over every dim index
        that ``get_dim_inds(targets)`` returns except the first (presumably
        all dims but the batch dim -- confirm against ``get_dim_inds``)
    :param weights: multiplicative weights forwarded to the criterion
    :param log_error_arr: forwarded to the criterion unchanged
    :return: AttrDict with a single entry ``dense_img_rec``
    """
    nll_criterion = NLL(self._hp.dense_img_rec_weight, breakdown=1)
    reduce_dims = get_dim_inds(targets)[1:]
    dense_loss = nll_criterion(
        estimates,
        targets,
        weights=weights,
        reduction=reduce_dims,
        log_error_arr=log_error_arr,
    )

    result = AttrDict()
    result.dense_img_rec = dense_loss
    return result