Example #1
def log_marginal_likelihood_estimate(self, x, num_samples):
    # Importance-weighted estimate of log p(x): draw num_samples posterior
    # samples per input and combine their importance weights in log space.
    weight = torch.zeros(x.size(0), num_samples)
    use_cuda = torch.cuda.is_available()
    for i in range(num_samples):
        z, recon_x, mu, logvar = self.forward(x)
        # Prior log-density log p(z), computed on CPU via self.log_prob.
        zloglikelihood = torch.from_numpy(self.log_prob(z.data.cpu().numpy())).float()
        if use_cuda:
            zloglikelihood = zloglikelihood.cuda()
        if self._dec_act is not None:
            # Bernoulli decoder: per-sample binary cross-entropy log p(x|z).
            dataloglikelihood = torch.sum(
                x * torch.log(torch.clamp(recon_x, min=1e-10)) +
                (1 - x) * torch.log(torch.clamp(1 - recon_x, min=1e-10)), 1)
        else:
            # Gaussian decoder (up to a constant): negative squared error.
            dataloglikelihood = -torch.sum((recon_x - x) ** 2, 1)
        # Posterior log-density log q(z|x) of the drawn sample.
        log_qz = log_likelihood_samplesImean_sigma(z, mu, logvar)
        weight[:, i] = (dataloglikelihood + zloglikelihood - log_qz).data.cpu()
    # Average the K importance weights in log space: log((1/K) * sum_i exp(w_i)).
    return log_sum_exp(weight, dim=1) - math.log(num_samples)
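This estimator and Example #2 below both call two helpers, log_sum_exp and log_likelihood_samplesImean_sigma, that are not shown in the snippets. The sketch below is one plausible reading of them, assuming the former is a numerically stable log-sum-exp along a dimension and the latter is the diagonal-Gaussian log-density log q(z|x) summed over the latent dimensions; the names and call signatures are taken from the code above, everything else is an assumption.

import math
import torch

def log_sum_exp(value, dim):
    # Numerically stable log(sum(exp(value))) along `dim` (assumed behavior).
    m, _ = torch.max(value, dim=dim, keepdim=True)
    return (m + torch.log(torch.sum(torch.exp(value - m), dim=dim, keepdim=True))).squeeze(dim)

def log_likelihood_samplesImean_sigma(z, mu, logvar):
    # Assumed: log N(z; mu, diag(exp(logvar))), summed over latent dimensions,
    # giving one value per sample in the batch.
    return -0.5 * torch.sum(logvar + math.log(2 * math.pi) + (z - mu) ** 2 / torch.exp(logvar), 1)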
Example #2
def log_marginal_likelihood_estimate(self, x, num_samples):
    # Same importance-weighted estimator, but the prior log-density of z
    # comes from a learned latent tree model rather than a fixed prior.
    weight = torch.zeros(x.size(0), num_samples)
    for i in range(num_samples):
        z, recon_x, mu, logvar = self.forward(x)
        zloglikelihood = self.latentTree(z)  # prior log-density log p(z)
        if self._dec_act is not None:
            # Bernoulli decoder: per-sample binary cross-entropy log p(x|z).
            dataloglikelihood = torch.sum(
                x * torch.log(torch.clamp(recon_x, min=1e-10)) +
                (1 - x) * torch.log(torch.clamp(1 - recon_x, min=1e-10)), 1)
        else:
            # Gaussian decoder (up to a constant): per-sample negative squared error.
            dataloglikelihood = -torch.sum((recon_x - x) ** 2, 1)
        log_qz = log_likelihood_samplesImean_sigma(z, mu, logvar)  # log q(z|x)
        weight[:, i] = (dataloglikelihood + zloglikelihood - log_qz).data.cpu()
    return log_sum_exp(weight, dim=1) - math.log(num_samples)
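A rough usage sketch for either estimator follows; the vae model instance, test_loader, and the flattening of inputs are assumptions and not part of the examples above.

vae.eval()
total_ll, count = 0.0, 0
with torch.no_grad():
    for x, _ in test_loader:
        x = x.view(x.size(0), -1)  # flatten inputs; move to the model's device if needed
        ll = vae.log_marginal_likelihood_estimate(x, num_samples=100)
        total_ll += ll.sum().item()
        count += x.size(0)
print('estimated log p(x) per example:', total_ll / count)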
Example #3
    def forward(self, predictions, targets):
        total_length = self.sample_duration
        loc_pred, conf_pred = predictions
        if not self.extra_layers:
            # loc_pred : [batch_size, 2]
            # conf_pred : [batch_size, 3]
            loc_target = targets[:, :-1].clone().detach()
            loc_target = encoding(loc_target, total_length)
            conf_target = targets[:, -1].clone().detach().to(torch.long)

            loss_loc = self.reg_loss(loc_pred, loc_target)
            loss_conf = self.conf_loss(conf_pred, conf_target)

        else:
            # loc_pred : [batch_size, default_bar_num, 2]
            # conf_pred : [batch_size, default_bar_num, 3]
            batch_size = targets.size(0)
            default_bar_num = self.default_bar.size(0)
            self.reg_loss = nn.SmoothL1Loss(reduction='sum')
            self.conf_loss = nn.CrossEntropyLoss(reduction='sum')

            loc_t = torch.Tensor(batch_size, default_bar_num, 2)
            conf_t = torch.LongTensor(batch_size, default_bar_num)
            # wrap targets
            with torch.no_grad():
                self.default_bar = self.default_bar.to(self.device)
                loc_t = loc_t.to(self.device)
                conf_t = conf_t.to(self.device)

            for idx in range(batch_size):
                truths = targets[idx, :-1].view(-1, 2).data  # [1, 2]
                labels = targets[idx, -1].view(-1, 1).to(torch.long).data  # [1, 1]
                default = self.default_bar.clone().detach()  # [default_bar_num, 2]
                # jaccard index
                # truths = truths.view(-1, 1, 2)      # [1, 1, 2]
                # labels = labels.view(-1, 1)         # [1, 1]
                # default = default.view(1, -1, 2)    # [1, default_bar_num, 2]
                overlaps = cal_iou(truths, default, use_default=True)  # [1, default_bar_num]
                # (Bipartite Matching)
                # [1(gt_num), 1] best prior for each ground truth
                best_prior_overlap, best_prior_idx = overlaps.max(
                    1, keepdim=True)  # default bar that overlaps each GT the most
                # [1, default_bar_num] best ground truth for each prior
                best_truth_overlap, best_truth_idx = overlaps.max(
                    0, keepdim=True)  # GT that overlaps each default bar the most
                best_prior_idx.squeeze_(1)  # [1(gt_num)]
                best_prior_overlap.squeeze_(1)  # [1(gt_num)]
                best_truth_idx.squeeze_(0)  # [default_bar_num]
                best_truth_overlap.squeeze_(0)  # [default_bar_num]
                best_truth_overlap.index_fill_(0, best_prior_idx,
                                               2)  # ensure best prior
                # TODO refactor: index  best_prior_idx with long tensor
                # ensure every gt matches with its prior of max overlap
                for j in range(best_prior_idx.size(0)):
                    best_truth_idx[best_prior_idx[j]] = j
                matches = truths[best_truth_idx]  # Shape: [default_bar_num, 2]
                conf = labels[best_truth_idx] + 1  # Shape: [default_bar_num]

                background_conf_idx = conf == 1  # get index of background
                if isinstance(self.neg_threshold, tuple):
                    neg_thresh_cut = self.neg_threshold[0]
                    neg_thresh_gradual = self.neg_threshold[1]
                    cut_idx = conf == 2
                    gradual_idx = conf == 3
                    cut_idx[best_truth_overlap >= neg_thresh_cut] = 0
                    gradual_idx[best_truth_overlap >= neg_thresh_gradual] = 0
                    conf[cut_idx] = 0
                    conf[gradual_idx] = 0
                else:
                    conf[best_truth_overlap < self.neg_threshold] = 0  # label as negative
                conf[background_conf_idx] = 1  # set label to background

                assert matches.size() == self.default_bar.size(),\
                    "matches_size : {}, default_bar_size : {}".format(matches.size(), default.size())
                loc = encoding(matches, total_length, default_bar=default)
                loc_t[idx] = loc  # [default_bar_num, 2] encoded offsets to learn
                conf_t[idx] = conf.squeeze(1)  # [default_bar_num] top class label for each prior

            pos = conf_t > 0
            loc_pos = conf_t > 1

            # Localization Loss (Smooth L1)
            # Shape: [batch, default_bar_num, 2]
            pos_idx = loc_pos.unsqueeze(loc_pos.dim()).expand_as(loc_pred)
            loc_p = loc_pred[pos_idx].view(-1, 2)
            loc_t = loc_t[pos_idx].view(-1, 2)
            loss_loc = self.reg_loss(loc_p, loc_t)

            # Compute max conf across batch for hard negative mining
            # batch_conf : [batch_size * default_bar_num, num_classes]
            # loss_conf : [batch_size * default_bar_num, 1]
            batch_conf = conf_pred.view(-1, self.num_classes)
            conf_t = torch.clamp(conf_t - 1, min=0)
            loss_conf = log_sum_exp(batch_conf) - batch_conf.gather(
                1, conf_t.view(-1, 1))

            # Hard Negative Mining
            loss_conf = loss_conf.view(batch_size, -1)  # [batch_size, default_bar_num]
            loss_conf[pos] = 0  # filter out bars matched to positives for now
            _, loss_idx = loss_conf.sort(1, descending=True)  # indices sorted by descending conf loss : [8, 26]
            _, idx_rank = loss_idx.sort(1)  # rank of each index : [batch_size, default_bar_num]
            num_pos = pos.long().sum(1, keepdim=True)
            num_neg = torch.clamp(self.negpos_ratio * num_pos,
                                  max=pos.size(1) - 1)  # [batch_size, 1]
            neg = idx_rank < num_neg.expand_as(idx_rank)  # [batch_size, default_bar_num]

            # Confidence Loss Including Positive and Negative Examples
            pos_idx = pos.unsqueeze(2).expand_as(conf_pred)
            neg_idx = neg.unsqueeze(2).expand_as(conf_pred)
            conf_p = conf_pred[(pos_idx | neg_idx)].view(-1, self.num_classes)
            targets_weighted = conf_t[(pos | neg)]
            loss_conf = self.conf_loss(conf_p, targets_weighted)

            # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N

            N = num_pos.data.sum()
            loss_loc /= N
            loss_conf /= N

            # N_pos = num_pos.data.sum()
            # N_neg = num_neg.data.sum()
            # loss_loc /= N_pos
            # loss_conf /= N_pos + N_neg

        return loss_loc, loss_conf
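A minimal training-step sketch follows; model, criterion (an instance of the loss module whose forward pass is shown above), optimizer, loader, and device are all assumptions, and giving the two loss terms equal weight is a guess rather than something the example specifies.

for clips, targets in loader:
    # Hypothetical names: only the (loss_loc, loss_conf) interface mirrors Example #3.
    clips, targets = clips.to(device), targets.to(device)
    loc_pred, conf_pred = model(clips)  # e.g. [B, N, 2] and [B, N, 3] with extra layers
    loss_loc, loss_conf = criterion((loc_pred, conf_pred), targets)
    loss = loss_loc + loss_conf  # equal weighting assumed
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()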