Example #1
    def compute_loss(self, pred, pos, neg, valid, summ_writer):
        # binary logistic loss, computed in a numerically stable way:
        # log(1 + exp(a)) = b + log(exp(-b) + exp(a-b)), with b = max(a, 0)
        label = pos * 2.0 - 1.0
        a = -label * pred
        b = F.relu(a)
        loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))

        mask_ = (pos + neg > 0.0).float()
        loss_vis = torch.mean(loss * mask_ * valid, dim=3)
        summ_writer.summ_oned('sub/prob_loss', loss_vis)

        # pos_loss = reduce_masked_mean(loss, pos*valid)
        # neg_loss = reduce_masked_mean(loss, neg*valid)
        # balanced_loss = pos_loss + neg_loss

        # pos_occ_loss = utils_basic.reduce_masked_mean(loss, pos*valid*occ)
        # pos_free_loss = utils_basic.reduce_masked_mean(loss, pos*valid*free)
        # neg_occ_loss = utils_basic.reduce_masked_mean(loss, neg*valid*occ)
        # neg_free_loss = utils_basic.reduce_masked_mean(loss, neg*valid*free)
        # balanced_loss = pos_occ_loss + pos_free_loss + neg_occ_loss + neg_free_loss

        pos_loss = utils_basic.reduce_masked_mean(loss, pos * valid)
        neg_loss = utils_basic.reduce_masked_mean(loss, neg * valid)
        balanced_loss = pos_loss + neg_loss

        return balanced_loss
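
Note: nearly every example in this collection leans on utils_basic.reduce_masked_mean to average a loss over only the masked-in elements. The helper itself is never shown in these snippets; the following is a minimal sketch of its assumed behavior (the real signature in utils_basic may differ):

import torch

def reduce_masked_mean(x, mask, dim=None, eps=1e-6):
    # weight each element by its mask value, then divide by the total
    # mask weight, so masked-out positions do not dilute the mean
    prod = x * mask
    if dim is None:
        return torch.sum(prod) / (eps + torch.sum(mask))
    return torch.sum(prod, dim=dim) / (eps + torch.sum(mask, dim=dim))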
Example #2
    def forward(self, feat0, feat1, valid0, valid1, summ_writer=None):
        total_loss = torch.tensor(0.0).cuda()
        B, C, D, H, W = list(feat0.shape)

        neg_input = feat1 - feat0

        feat1_flat = feat1.reshape(B, C, -1)
        valid1_flat = valid1.reshape(B, 1, -1)
        perm = np.random.permutation(D*H*W)
        feat1_flat_shuf = feat1_flat[:,:,perm]
        feat1_shuf = feat1_flat_shuf.reshape(B, C, D, H, W)
        valid1_flat_shuf = valid1_flat[:,:,perm]
        valid1_shuf = valid1_flat_shuf.reshape(B, 1, D, H, W)
        # pos_input = torch.cat([feat0, feat1_shuf-feat0], dim=1)
        pos_input = feat1_shuf - feat0

        # noncorresps should be rejected 
        pos_output = self.net(pos_input)
        # corresps should NOT be rejected
        neg_output = self.net(neg_input)

        pos_sig = torch.sigmoid(pos_output)
        neg_sig = torch.sigmoid(neg_output)

        if summ_writer is not None:
            summ_writer.summ_feat('reject/pos_input', pos_input, pca=True)
            summ_writer.summ_feat('reject/neg_input', neg_input, pca=True)
            summ_writer.summ_oned('reject/pos_sig', pos_sig, bev=True, norm=False)
            summ_writer.summ_oned('reject/neg_sig', neg_sig, bev=True, norm=False)

        pos_output_vec = pos_output.reshape(B, D*H*W)
        neg_output_vec = neg_output.reshape(B, D*H*W)
        pos_target_vec = torch.ones([B, D*H*W]).float().cuda()
        neg_target_vec = torch.zeros([B, D*H*W]).float().cuda()

        # if feat1_shuf is valid, then it is practically guaranteed to mismatch feat0
        pos_valid_vec = valid1_shuf.reshape(B, D*H*W)
        # both have to be valid to not reject
        neg_valid_vec = (valid0*valid1).reshape(B, D*H*W)

        pos_loss_vec = self.criterion(pos_output_vec, pos_target_vec)
        neg_loss_vec = self.criterion(neg_output_vec, neg_target_vec)
        
        pos_loss = utils_basic.reduce_masked_mean(pos_loss_vec, pos_valid_vec)
        neg_loss = utils_basic.reduce_masked_mean(neg_loss_vec, neg_valid_vec)

        ce_loss = pos_loss + neg_loss
        utils_misc.add_loss('reject3D/ce_loss_pos', 0, pos_loss, 0, summ_writer)
        utils_misc.add_loss('reject3D/ce_loss_neg', 0, neg_loss, 0, summ_writer)
        total_loss = utils_misc.add_loss('reject3D/ce_loss', total_loss, ce_loss, hyp.reject3D_ce_coeff, summ_writer)
            
        return total_loss, neg_sig, pos_sig
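
The shuffled copy of feat1 is what turns a correspondence problem into a classification problem: aligned voxel pairs are negatives for the rejector, randomly permuted pairs are positives (self.criterion is presumably an nn.BCEWithLogitsLoss(reduction='none'), given the raw logits and per-element masking). A toy demonstration of the permutation trick in isolation, with shrunken shapes and illustrative names only:

import numpy as np
import torch

B, C, D, H, W = 2, 4, 3, 3, 3
feat = torch.randn(B, C, D, H, W)
perm = np.random.permutation(D * H * W)
# flatten space, shuffle it, and fold it back up
feat_shuf = feat.reshape(B, C, -1)[:, :, perm].reshape(B, C, D, H, W)
# each location now holds a feature drawn from some other location, so
# (feat_shuf - feat) is almost surely a non-correspondence at every voxel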
Example #3
    def compute_loss(self, pred, seg, free, summ_writer):
        # pred is B x C x Z x Y x X
        # seg is B x Z x Y x X

        # ensure the "free" voxels are labelled zero
        seg = seg * (1 - free)

        # seg_bak = seg.clone()

        # note that label0 in seg is invalid, so i need to ignore these
        # but all "free" voxels count as the air class
        # seg = (seg-1).clamp(min=0)
        # seg = (seg-1).clamp(min=0)
        # # ignore_mask = (seg[seg==-1]).float()
        # seg[seg==-1] = 0
        loss = F.cross_entropy(pred, seg, reduction='none')
        # loss is B x Z x Y x X

        loss_any = ((seg > 0).float() + (free > 0).float()).clamp(0, 1)
        loss_vis = torch.mean(loss * loss_any, dim=2).unsqueeze(1)
        summ_writer.summ_oned('seg/prob_loss', loss_vis)

        loss = loss.reshape(-1)
        # seg_bak = seg_bak.reshape(-1)
        seg = seg.reshape(-1)
        free = free.reshape(-1)

        # print('loss', loss.shape)
        # print('seg_bak', seg_bak.shape)

        losses = []
        # total_loss = 0.0

        # next, i want to gather up the loss for each valid class, and balance these into a total
        for cls in list(range(self.num_classes)):
            if cls == 0:
                mask = free.clone()
            else:
                # mask = (seg_bak==cls).float()
                mask = (seg == cls).float()

            # print('mask', mask.shape)
            # print('loss', loss.shape)
            cls_loss = utils_basic.reduce_masked_mean(loss, mask)
            print('cls %d sum' % cls,
                  torch.sum(mask).detach().cpu().numpy(), 'loss',
                  cls_loss.detach().cpu().numpy())

            # print('cls_loss', cls_loss.shape)
            # print('cls %d loss' % cls, cls_loss.detach().cpu().numpy())
            # total_loss = total_loss + cls_loss
            if torch.sum(mask) >= 1:
                losses.append(cls_loss)
            # print('mask', mask.shape)
            # loss_ = loss[seg_bak==cls]
            # print('loss_', loss_.shape)
            # loss_ = loss[seg_bak==cls]
        total_loss = torch.mean(torch.stack(losses))
        return total_loss
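
The loop amounts to class-balanced cross-entropy: each class present in the batch contributes one masked mean, and the total is the mean over those per-class means. A compact standalone version of the same idea, minus the special free-space class (names here are illustrative, not from the repo):

import torch
import torch.nn.functional as F

def balanced_cross_entropy(pred, seg, num_classes, eps=1e-6):
    # pred: B x C x ... logits; seg: B x ... integer labels
    loss = F.cross_entropy(pred, seg, reduction='none').reshape(-1)
    seg = seg.reshape(-1)
    per_class = []
    for cls in range(num_classes):
        mask = (seg == cls).float()
        if mask.sum() >= 1:  # skip classes absent from this batch
            per_class.append((loss * mask).sum() / (mask.sum() + eps))
    return torch.stack(per_class).mean()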
Example #4
    def forward(self, feat, occ_g, free_g, summ_writer=None):
        B, C, Z, Y, X = list(feat.shape)
        total_loss = torch.tensor(0.0).cuda()
        
        summ_writer.summ_feat('genoccnet/feat_input', feat, pca=False)
        feat = feat.long() # discrete input

        logit = self.run_net(feat)

        # smooth loss
        dz, dy, dx = gradient3D(logit, absolute=True)
        smooth_vox = torch.mean(dz+dy+dx, dim=1, keepdim=True)
        summ_writer.summ_oned('genoccnet/smooth_loss', torch.mean(smooth_vox, dim=3))
        # smooth_loss = utils_basic.reduce_masked_mean(smooth_vox, valid)
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('genoccnet/smooth_loss', total_loss,
                                         smooth_loss, hyp.genocc_smooth_coeff, summ_writer)
        
        summ_writer.summ_feat('genoccnet/feat_output', logit, pca=False)
        occ_e = torch.argmax(logit, dim=1, keepdim=True)
        
        summ_writer.summ_occ('genoccnet/occ_e', occ_e)
        summ_writer.summ_occ('genoccnet/occ_g', occ_g)

        loss_pos = self.criterion(logit, (occ_g[:,0]).long())
        loss_neg = self.criterion(logit, (1-free_g[:,0]).long())

        summ_writer.summ_oned('genoccnet/loss',
                              torch.mean((loss_pos+loss_neg), dim=1, keepdim=True) * \
                              torch.clamp(occ_g+(1-free_g), 0, 1),
                              bev=True)
        
        loss_pos = utils_basic.reduce_masked_mean(loss_pos.unsqueeze(1), occ_g)
        loss_neg = utils_basic.reduce_masked_mean(loss_neg.unsqueeze(1), free_g)
        loss_bal = loss_pos + loss_neg
        total_loss = utils_misc.add_loss('genoccnet/loss_bal', total_loss, loss_bal, hyp.genocc_coeff, summ_writer)

        # sample = self.generate_sample(1, int(Z/2), int(Y/2), int(X/2))
        # occ_sample = torch.argmax(sample, dim=1, keepdim=True)
        # summ_writer.summ_occ('genoccnet/occ_sample', occ_sample)
        
        return logit, total_loss
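
gradient3D is assumed to return absolute spatial finite differences along z, y, and x, which the smoothness term then averages. A plausible sketch; the real helper may pad or crop differently:

import torch
import torch.nn.functional as F

def gradient3D(x, absolute=False):
    # x: B x C x Z x Y x X; one-sided differences, zero-padded at the far edge
    dz = x[:, :, 1:] - x[:, :, :-1]
    dy = x[:, :, :, 1:] - x[:, :, :, :-1]
    dx = x[:, :, :, :, 1:] - x[:, :, :, :, :-1]
    if absolute:
        dz, dy, dx = dz.abs(), dy.abs(), dx.abs()
    # F.pad takes pads for the last dims first: (X_lo, X_hi, Y_lo, Y_hi, Z_lo, Z_hi)
    dz = F.pad(dz, (0, 0, 0, 0, 0, 1))
    dy = F.pad(dy, (0, 0, 0, 1, 0, 0))
    dx = F.pad(dx, (0, 1, 0, 0, 0, 0))
    return dz, dy, dx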
Example #5
    def forward(self, feat, rgb_g, valid, summ_writer, name):
        total_loss = torch.tensor(0.0).cuda()
        if hyp.dataset_name == "clevr":
            valid = torch.ones_like(valid)

        feat = self.net(feat)
        emb_e = self.emb_layer(feat)
        rgb_e = self.rgb_layer(feat)
        # postproc
        emb_e = l2_normalize(emb_e, dim=1)
        rgb_e = torch.tanh(rgb_e) * 0.5

        loss_im = l1_on_axis(rgb_e - rgb_g, 1, keepdim=True)
        summ_writer.summ_oned('view/rgb_loss', loss_im * valid)
        rgb_loss = utils_basic.reduce_masked_mean(loss_im, valid)

        total_loss = utils_misc.add_loss('view/rgb_l1_loss', total_loss,
                                         rgb_loss, hyp.view_l1_coeff,
                                         summ_writer)

        # vis
        summ_writer.summ_rgbs(f'view/{name}', [rgb_e, rgb_g])

        return total_loss, rgb_e, emb_e
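
l1_on_axis here (and the utils_basic.l2_on_axis used in later examples) are assumed to reduce a norm along a single axis; minimal sketches, with the epsilon in the l2 version a guess for gradient stability:

import torch

def l1_on_axis(x, dim, keepdim=True):
    return torch.sum(torch.abs(x), dim=dim, keepdim=keepdim)

def l2_on_axis(x, dim, keepdim=True):
    return torch.sqrt(torch.sum(x ** 2, dim=dim, keepdim=keepdim) + 1e-12)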
Example #6
    def forward(self, image_input, summ_writer=None, is_train=True):
        B, H, W = list(image_input.shape)
        total_loss = torch.tensor(0.0).cuda()
        summ_writer.summ_oned('sigen2d/image_input',
                              image_input.unsqueeze(1) / 512.0,
                              norm=False)

        emb = self.embed(image_input)

        y = torch.randint(low=0, high=H, size=[B, self.num_choices, 1])
        x = torch.randint(low=0, high=W, size=[B, self.num_choices, 1])
        choice_mask = utils_improc.xy2mask(torch.cat([x, y], dim=2),
                                           H,
                                           W,
                                           norm=False)
        summ_writer.summ_oned('sigen2d/choice_mask', choice_mask, norm=False)

        # cover up the 3x3 region surrounding each choice
        xy = torch.cat([
            torch.cat([x - 1, y - 1], dim=2),
            torch.cat([x + 0, y - 1], dim=2),
            torch.cat([x + 1, y - 1], dim=2),
            torch.cat([x - 1, y], dim=2),
            torch.cat([x + 0, y], dim=2),
            torch.cat([x + 1, y], dim=2),
            torch.cat([x - 1, y + 1], dim=2),
            torch.cat([x + 0, y + 1], dim=2),
            torch.cat([x + 1, y + 1], dim=2)
        ],
                       dim=1)
        input_mask = 1.0 - utils_improc.xy2mask(xy, H, W, norm=False)

        # if is_train:
        #     input_mask = (torch.rand((B, 1, H, W)).cuda() > 0.5).float()
        # else:
        #     input_mask = torch.ones((B, 1, H, W)).cuda().float()
        # input_mask = input_mask * (1.0 - choice_mask)
        # input_mask = 1.0 - choice_mask

        emb = emb * input_mask
        summ_writer.summ_oned('sigen2d/input_mask', input_mask, norm=False)

        image_output_logits, _ = self.net(emb, input_mask)

        image_output = torch.argmax(image_output_logits, dim=1, keepdim=True)
        summ_writer.summ_feat('sigen2d/emb', emb, pca=True)
        summ_writer.summ_oned('sigen2d/image_output',
                              image_output.float() / 512.0,
                              norm=False)

        ce_loss_image = F.cross_entropy(image_output_logits,
                                        image_input,
                                        reduction='none').unsqueeze(1)
        # summ_writer.summ_oned('sigen2d/ce_loss', ce_loss_image)
        # ce_loss = torch.mean(ce_loss_image)

        # only apply loss at the choice pixels
        ce_loss = utils_basic.reduce_masked_mean(ce_loss_image, choice_mask)
        total_loss = utils_misc.add_loss('sigen2d/ce_loss', total_loss,
                                         ce_loss, hyp.sigen2d_coeff,
                                         summ_writer)

        return total_loss
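
utils_improc.xy2mask is assumed to scatter ones into a B x 1 x H x W grid at the given (x, y) integer coordinates, which is how both the choice mask and the 3x3 cover mask above are built. A rough sketch; clamping handles the out-of-range coordinates that the 3x3 offsets can produce at image borders:

import torch

def xy2mask(xy, H, W, norm=False):
    # xy: B x N x 2 in (x, y) order; the norm flag of the real helper is ignored here
    B, N, _ = xy.shape
    mask = torch.zeros(B, 1, H, W, device=xy.device)
    x = xy[:, :, 0].long().clamp(0, W - 1)
    y = xy[:, :, 1].long().clamp(0, H - 1)
    for b in range(B):
        mask[b, 0, y[b], x[b]] = 1.0
    return mask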
Example #7
    def forward(self,
                feats,
                xyzlist_cam,
                scorelist,
                vislist,
                occs,
                summ_writer,
                suffix=''):
        total_loss = torch.tensor(0.0).cuda()
        B, S, C, Z2, Y2, X2 = list(feats.shape)
        B, S, C, Z, Y, X = list(occs.shape)
        B2, S2, D = list(xyzlist_cam.shape)
        assert B == B2 and S == S2
        assert (D == 3)

        xyzlist_mem = utils_vox.Ref2Mem(xyzlist_cam, Z, Y, X)
        # these are B x S x 3
        scorelist = scorelist.unsqueeze(2)
        # this is B x S x 1
        vislist = vislist[:, 0].reshape(B, 1, 1)
        # we only care that the object was visible in frame0
        scorelist = scorelist * vislist

        if self.use_cost_vols:
            if summ_writer.save_this:
                summ_writer.summ_traj_on_occ('forecast/actual_traj',
                                             xyzlist_mem * scorelist,
                                             torch.max(occs, dim=1)[0],
                                             already_mem=True,
                                             sigma=2)

            Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
            Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
            occ_hint0 = utils_vox.voxelize_xyz(xyzlist_cam[:, 0:1], Z4, Y4, X4)
            occ_hint1 = utils_vox.voxelize_xyz(xyzlist_cam[:, 1:2], Z4, Y4, X4)
            occ_hint0 = occ_hint0 * scorelist[:, 0].reshape(B, 1, 1, 1, 1)
            occ_hint1 = occ_hint1 * scorelist[:, 1].reshape(B, 1, 1, 1, 1)
            occ_hint = torch.cat([occ_hint0, occ_hint1], dim=1)
            occ_hint = F.interpolate(occ_hint, scale_factor=4, mode='nearest')
            # this is B x 1 x Z x Y x X
            summ_writer.summ_occ('forecast/occ_hint',
                                 (occ_hint0 + occ_hint1).clamp(0, 1))

            crops = []
            for s in list(range(S)):
                crop = utils_vox.center_mem_on_xyz(occs[:, s],
                                                   xyzlist_cam[:, s],
                                                   Z2, Y2, X2)
                crops.append(crop)
            crops = torch.stack(crops, dim=0)
            summ_writer.summ_occs('forecast/crops', crops)

            # condition on the occ_hint
            feat = torch.cat([feat, occ_hint], dim=1)

            N = hyp.forecast_num_negs
            sampled_trajs_mem = self.sample_trajs_from_library(N, xyzlist_mem)

            if summ_writer.save_this:
                for n in list(range(np.min([N, 10]))):
                    xyzlist_mem = sampled_trajs_mem[0, n].unsqueeze(0)
                    # this is 1 x S x 3
                    summ_writer.summ_traj_on_occ(
                        'forecast/lib%d_xyzlist' % n,
                        xyzlist_mem,
                        torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                        already_mem=True)

            cost_vols = self.cost_forecaster(feat)
            # cost_vols = F.sigmoid(cost_vols)
            cost_vols = F.interpolate(cost_vols,
                                      scale_factor=2,
                                      mode='trilinear')

            # cost_vols is B x S x Z x Y x X
            summ_writer.summ_histogram('forecast/cost_vols_hist', cost_vols)
            cost_vols = cost_vols.clamp(
                -1000, 1000)  # raquel says this adds stability
            summ_writer.summ_histogram('forecast/cost_vols_clamped_hist',
                                       cost_vols)

            cost_vols_vis = torch.mean(cost_vols, dim=3).unsqueeze(2)
            # cost_vols_vis is B x S x 1 x Z x X
            summ_writer.summ_oneds('forecast/cost_vols_vis',
                                   torch.unbind(cost_vols_vis, dim=1))

            # smooth loss
            cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X)
            dz, dy, dx = gradient3D(cost_vols_, absolute=True)
            dt = torch.abs(cost_vols[:, 1:] - cost_vols[:, 0:-1])
            smooth_vox_spatial = torch.mean(dx + dy + dz, dim=1, keepdims=True)
            smooth_vox_time = torch.mean(dt, dim=1, keepdims=True)
            summ_writer.summ_oned('forecast/smooth_loss_spatial',
                                  torch.mean(smooth_vox_spatial, dim=3))
            summ_writer.summ_oned('forecast/smooth_loss_time',
                                  torch.mean(smooth_vox_time, dim=3))
            smooth_loss = torch.mean(smooth_vox_spatial) + torch.mean(
                smooth_vox_time)
            total_loss = utils_misc.add_loss('forecast/smooth_loss',
                                             total_loss, smooth_loss,
                                             hyp.forecast_smooth_coeff,
                                             summ_writer)

            def clamp_xyz(xyz, X, Y, Z):
                x, y, z = torch.unbind(xyz, dim=-1)
                x = x.clamp(0, X)
                y = y.clamp(0, Y)
                z = z.clamp(0, Z)
                xyz = torch.stack([x, y, z], dim=-1)
                return xyz

            # obj_xyzlist_mem is K x B x S x 3
            # xyzlist_mem is B x S x 3
            # sampled_trajs_mem is B x N x S x 3
            xyz_pos_ = xyzlist_mem.reshape(B * S, 1, 3)
            xyz_neg_ = sampled_trajs_mem.permute(0, 2, 1,
                                                 3).reshape(B * S, N, 3)
            # xyz_pos_ = clamp_xyz(xyz_pos_, X, Y, Z)
            # xyz_neg_ = clamp_xyz(xyz_neg_, X, Y, Z)
            xyz_ = torch.cat([xyz_pos_, xyz_neg_], dim=1)
            xyz_ = clamp_xyz(xyz_, X, Y, Z)
            cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X)
            x, y, z = torch.unbind(xyz_, dim=2)
            # x = x.clamp(0, X)
            # y = x.clamp(0, Y)
            # z = x.clamp(0, Z)
            cost_ = utils_samp.bilinear_sample3D(cost_vols_, x, y,
                                                 z).squeeze(1)
            # cost is B*S x 1+N
            cost_pos = cost_[:, 0:1]  # B*S x 1
            cost_neg = cost_[:, 1:]  # B*S x N

            cost_pos = cost_pos.unsqueeze(2)  # B*S x 1 x 1
            cost_neg = cost_neg.unsqueeze(1)  # B*S x 1 x N

            utils_misc.add_loss('forecast/mean_cost_pos', 0,
                                torch.mean(cost_pos), 0, summ_writer)
            utils_misc.add_loss('forecast/mean_cost_neg', 0,
                                torch.mean(cost_neg), 0, summ_writer)
            utils_misc.add_loss('forecast/mean_margin', 0,
                                torch.mean(cost_neg - cost_pos), 0,
                                summ_writer)

            xyz_pos = xyz_pos_.unsqueeze(2)  # B*S x 1 x 1 x 3
            xyz_neg = xyz_neg_.unsqueeze(1)  # B*S x 1 x N x 3
            dist = torch.norm(xyz_pos - xyz_neg, dim=3)  # B*S x 1 x N
            dist = dist / float(
                Z) * 5.0  # normalize for resolution, but upweight it a bit
            margin = F.relu(cost_pos - cost_neg + dist)
            margin = margin.reshape(B, S, N)
            # mean over time (in the paper this is a sum)
            margin = utils_basic.reduce_masked_mean(margin,
                                                    scorelist.repeat(1, 1, N),
                                                    dim=1)
            # max over the negatives
            maxmargin = torch.max(margin, dim=1)[0]  # B
            maxmargin_loss = torch.mean(maxmargin)
            total_loss = utils_misc.add_loss('forecast/maxmargin_loss',
                                             total_loss, maxmargin_loss,
                                             hyp.forecast_maxmargin_coeff,
                                             summ_writer)

            cost_neg = cost_neg.reshape(B, S, N)[0].detach().cpu().numpy()
            sampled_trajs_mem = sampled_trajs_mem.reshape(B, N, S, 3)[0:1]
            cost_neg = np.reshape(cost_neg, [S, N])
            cost_neg = np.sum(cost_neg, axis=0)
            inds = np.argsort(cost_neg, axis=0)

            for n in list(range(2)):

                xyzlist_e_mem = sampled_trajs_mem[0:1, inds[n]]
                xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X)
                # this is B x S x 3

                # if summ_writer.save_this and n==0:
                #     print('xyzlist_e_cam', xyzlist_e_cam[0:1])
                #     print('xyzlist_g_cam', xyzlist_cam[0:1])
                #     print('scorelist', scorelist[0:1])

                dist = torch.norm(xyzlist_cam[0:1] - xyzlist_e_cam[0:1], dim=2)
                # this is B x S
                meandist = utils_basic.reduce_masked_mean(
                    dist, scorelist[0:1].squeeze(2))
                utils_misc.add_loss('forecast/xyz_dist_%d' % n, 0, meandist, 0,
                                    summ_writer)
                # dist = torch.mean(torch.sum(torch.norm(xyzlist_cam[0:1] - xyzlist_e_cam[0:1], dim=2), dim=1))

                # mpe = torch.mean(torch.norm(xyzlist_cam[0:1,int(S/2)] - xyzlist_e_cam[0:1,int(S/2)], dim=1))
                # mpe = utils_basic.reduce_masked_mean(dist, scorelist[0:1])
                # utils_misc.add_loss('forecast/xyz_mpe_%d' % n, 0, dist, 0, summ_writer)

                # epe = torch.mean(torch.norm(xyzlist_cam[0:1,-1] - xyzlist_e_cam[0:1,-1], dim=1))
                # utils_misc.add_loss('forecast/xyz_epe_%d' % n, 0, dist, 0, summ_writer)

            if summ_writer.save_this:
                # plot the best and worst trajs
                # print('sorted costs:', cost_neg[inds])
                for n in list(range(2)):
                    ind = inds[n]
                    # print('plotting good traj with cost %.2f' % (cost_neg[ind]))
                    xyzlist_e_mem = sampled_trajs_mem[:, ind]
                    # this is 1 x S x 3
                    summ_writer.summ_traj_on_occ(
                        'forecast/best_sampled_traj%d' % n,
                        xyzlist_e_mem,
                        torch.max(occs[0:1], dim=1)[0],
                        # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                        already_mem=True,
                        sigma=1)

                for n in list(range(2)):
                    ind = inds[-(n + 1)]
                    # print('plotting bad traj with cost %.2f' % (cost_neg[ind]))
                    xyzlist_e_mem = sampled_trajs_mem[:, ind]
                    # this is 1 x S x 3
                    summ_writer.summ_traj_on_occ(
                        'forecast/worst_sampled_traj%d' % n,
                        xyzlist_e_mem,
                        torch.max(occs[0:1], dim=1)[0],
                        # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                        already_mem=True,
                        sigma=1)
        else:

            # use some timesteps as input
            feat_input = feats[:, :self.num_given].squeeze(2)
            # feat_input is B x self.num_given x ZZ x ZY x ZX

            ## regular bottle3D
            # vel_e = self.regressor(feat_input)

            ## sparse-invar bottle3D
            comp_mask = 1.0 - (feat_input == 0).all(dim=1,
                                                    keepdim=True).float()
            summ_writer.summ_feat('forecast/feat_input', feat_input, pca=False)
            summ_writer.summ_feat('forecast/feat_comp_mask',
                                  comp_mask,
                                  pca=False)
            vel_e = self.regressor(feat_input, comp_mask)

            vel_e = vel_e.reshape(B, self.num_need, 3)
            vel_g = (xyzlist_cam[:, self.num_given:]
                     - xyzlist_cam[:, self.num_given-1:-1])

            xyzlist_e = torch.zeros_like(xyzlist_cam)
            xyzlist_g = torch.zeros_like(xyzlist_cam)
            for s in list(range(S)):
                # print('s = %d' % s)
                if s < self.num_given:
                    # print('grabbing from gt ind %s' % s)
                    xyzlist_e[:, s] = xyzlist_cam[:, s]
                    xyzlist_g[:, s] = xyzlist_cam[:, s]
                else:
                    # print('grabbing from s-self.num_given, which is ind %d' % (s-self.num_given))
                    xyzlist_e[:, s] = xyzlist_e[:, s-1] + vel_e[:, s-self.num_given]
                    xyzlist_g[:, s] = xyzlist_g[:, s-1] + vel_g[:, s-self.num_given]

        xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X)
        xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X)
        summ_writer.summ_traj_on_occ('forecast/traj_e',
                                     xyzlist_e_mem,
                                     torch.max(occs, dim=1)[0],
                                     already_mem=True,
                                     sigma=2)
        summ_writer.summ_traj_on_occ('forecast/traj_g',
                                     xyzlist_g_mem,
                                     torch.max(occs, dim=1)[0],
                                     already_mem=True,
                                     sigma=2)

        scorelist_here = scorelist[:, self.num_given:, 0]
        sql2 = torch.sum((vel_g - vel_e)**2, dim=2)

        ## yes weightmask
        weightmask = torch.arange(0,
                                  self.num_need,
                                  dtype=torch.float32,
                                  device=torch.device('cuda'))
        weightmask = torch.exp(-weightmask**(1. / 4))
        # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860,
        #         0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353
        weightmask = weightmask.reshape(1, self.num_need)
        l2_loss = utils_basic.reduce_masked_mean(sql2,
                                                 scorelist_here * weightmask)

        utils_misc.add_loss('forecast/l2_loss', 0, l2_loss, 0, summ_writer)

        # # no weightmask:
        # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here)

        # total_loss = utils_misc.add_loss('forecast/l2_loss', total_loss, l2_loss, hyp.forecast_l2_coeff, summ_writer)

        dist = torch.norm(xyzlist_e - xyzlist_g, dim=2)
        meandist = utils_basic.reduce_masked_mean(dist, scorelist[:, :, 0])
        utils_misc.add_loss('forecast/xyz_dist_0', 0, meandist, 0, summ_writer)

        l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here)
        # utils_misc.add_loss('forecast/vel_dist_noexp', 0, l2_loss, 0, summ_writer)
        total_loss = utils_misc.add_loss('forecast/l2_loss_noexp', total_loss,
                                         l2_loss_noexp, hyp.forecast_l2_coeff,
                                         summ_writer)

        return total_loss
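
Stripped of the logging and reshapes, the ranking objective in the cost-volume branch is a distance-weighted max-margin loss: a sampled negative trajectory should cost more than the ground-truth one, by a margin that grows with how far away it lies. A distilled sketch of just that computation, with the time masking via reduce_masked_mean omitted for brevity:

import torch
import torch.nn.functional as F

def maxmargin_loss(cost_pos, cost_neg, dist):
    # cost_pos: M x 1 x 1; cost_neg, dist: M x 1 x N  (M = B*S)
    margin = F.relu(cost_pos - cost_neg + dist)  # how badly each negative violates
    # max over negatives: only the single worst violator per sample is penalized
    return torch.max(margin.reshape(margin.shape[0], -1), dim=1)[0].mean()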
Example #8
    def forward(self,
                feat0,
                feat1,
                flow_g=None,
                mask_g=None,
                summ_writer=None):
        total_loss = torch.tensor(0.0).cuda()

        B, C, D, H, W = list(feat0.shape)
        utils_basic.assert_same_shape(feat0, feat1)

        # feats = torch.cat([feat0, feat1], dim=0)
        # feats = self.compressor(feats)
        # feats = utils_basic.l2_normalize(feats, dim=1)
        # feat0, feat1 = feats[:B], feats[B:]

        flow_total = torch.zeros([B, 3, D, H, W]).float().cuda()
        feat1_aligned = feat1.clone()

        # summ_writer.summ_feats('flow/feats_aligned_%.2f' % 0.0, [feat0, feat1_aligned])
        feat_diff = torch.mean(
            utils_basic.l2_on_axis((feat1_aligned - feat0), 1, keepdim=True))
        utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0,
                            summ_writer)

        for sc in self.scales:
            flow = self.generate_flow(feat0, feat1_aligned, sc)

            mask = torch.zeros_like(flow[:, 0:1])
            # print('mask', mask.shape)
            mask[:, :, self.max_disp:-self.max_disp,
                 self.max_disp:-self.max_disp,
                 self.max_disp:-self.max_disp] = 1.0
            flow = flow * mask

            flow_total = flow_total + flow

            # compositional LK: warp the original thing using the cumulative flow
            feat1_aligned = utils_samp.backwarp_using_3D_flow(
                feat1, flow_total)
            valid1_region = utils_samp.backwarp_using_3D_flow(
                torch.ones_like(feat1[:, 0:1]), flow_total)
            # summ_writer.summ_feats('flow/feats_aligned_%.2f' % sc, [feat0, feat1_aligned],
            #                        valids=[torch.ones_like(valid1_region), valid1_region])
            feat_diff = utils_basic.reduce_masked_mean(
                utils_basic.l2_on_axis((feat1_aligned - feat0),
                                       1,
                                       keepdim=True), valid1_region)
            utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff,
                                0, summ_writer)

        if flow_g is not None:
            # ok done inference
            # now for losses/metrics:
            l1_diff_3chan = self.smoothl1(flow_total, flow_g)
            l1_diff = torch.mean(l1_diff_3chan, dim=1, keepdim=True)
            l2_diff_3chan = self.mse(flow_total, flow_g)
            l2_diff = torch.mean(l2_diff_3chan, dim=1, keepdim=True)

            nonzero_mask = ((torch.sum(torch.abs(flow_g), dim=1, keepdim=True)
                             > 0.01).float()) * mask_g * mask
            yeszero_mask = (1.0 - nonzero_mask) * mask_g * mask
            l1_loss = utils_basic.reduce_masked_mean(l1_diff, mask_g * mask)
            l2_loss = utils_basic.reduce_masked_mean(l2_diff, mask_g * mask)
            l1_loss_nonzero = utils_basic.reduce_masked_mean(
                l1_diff, nonzero_mask)
            l1_loss_yeszero = utils_basic.reduce_masked_mean(
                l1_diff, yeszero_mask)
            l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5
            l2_loss_nonzero = utils_basic.reduce_masked_mean(
                l2_diff, nonzero_mask)
            l2_loss_yeszero = utils_basic.reduce_masked_mean(
                l2_diff, yeszero_mask)
            l2_loss_balanced = (l2_loss_nonzero + l2_loss_yeszero) * 0.5

            # clip = np.squeeze(torch.max(torch.abs(torch.mean(flow_g[0], dim=0))).detach().cpu().numpy()).item()
            clip = 3.0
            if summ_writer is not None:
                # summ_writer.summ_3D_flow('flow/flow_e_%.2f' % sc, flow_total*mask_g, clip=clip)
                summ_writer.summ_3D_flow('flow/flow_e_%.2f' % sc,
                                         flow_total,
                                         clip=clip)
                summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc,
                                         flow_g,
                                         clip=clip)
                summ_writer.summ_oned('flow/mask_%.2f' % sc,
                                      mask,
                                      bev=True,
                                      norm=False)
                summ_writer.summ_feat('flow/flow_e_pca_%.2f' % sc,
                                      flow_total,
                                      pca=True)

            utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero, 0,
                                summ_writer)
            utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero, 0,
                                summ_writer)
            utils_misc.add_loss('flow/l1_loss_balanced', 0, l1_loss_balanced,
                                0, summ_writer)

            total_loss = utils_misc.add_loss('flow/l1_loss', total_loss,
                                             l1_loss, hyp.flow_l1_coeff,
                                             summ_writer)
            total_loss = utils_misc.add_loss('flow/l2_loss', total_loss,
                                             l2_loss, hyp.flow_l2_coeff,
                                             summ_writer)

            total_loss = utils_misc.add_loss('flow/warp', total_loss,
                                             feat_diff, hyp.flow_warp_coeff,
                                             summ_writer)

            # smooth loss
            dz, dy, dx = utils_basic.gradient3D(flow_total, absolute=True)
            smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
            if summ_writer is not None:
                summ_writer.summ_oned('flow/smooth_loss',
                                      torch.mean(smooth_vox, dim=3))
            smooth_loss = torch.mean(smooth_vox)
            total_loss = utils_misc.add_loss('flow/smooth_loss', total_loss,
                                             smooth_loss,
                                             hyp.flow_smooth_coeff,
                                             summ_writer)

        return total_loss, flow_total
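
utils_samp.backwarp_using_3D_flow is the workhorse of the compositional alignment above: it resamples feat1 at positions displaced by the accumulated flow. A plausible implementation on top of F.grid_sample; the voxel-unit and channel-order conventions here are guesses, not the repo's confirmed ones:

import torch
import torch.nn.functional as F

def backwarp_using_3D_flow(feat, flow):
    # feat: B x C x D x H x W; flow: B x 3 x D x H x W, channels (x, y, z) in voxels
    B, C, D, H, W = feat.shape
    zs = torch.arange(D, device=feat.device, dtype=feat.dtype)
    ys = torch.arange(H, device=feat.device, dtype=feat.dtype)
    xs = torch.arange(W, device=feat.device, dtype=feat.dtype)
    grid_z, grid_y, grid_x = torch.meshgrid(zs, ys, xs, indexing='ij')
    x = grid_x + flow[:, 0]
    y = grid_y + flow[:, 1]
    z = grid_z + flow[:, 2]
    # normalize to [-1, 1] as grid_sample expects; last dim ordered (x, y, z)
    grid = torch.stack([2.0 * x / max(W - 1, 1) - 1.0,
                        2.0 * y / max(H - 1, 1) - 1.0,
                        2.0 * z / max(D - 1, 1) - 1.0], dim=-1)
    return F.grid_sample(feat, grid, align_corners=True)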
Example #9
    def forward(self,
                pix_T_cam0,
                cam0_T_cam1,
                feat_mem1,
                rgb_g,
                vox_util,
                valid=None,
                summ_writer=None,
                test=False,
                suffix=''):
        total_loss = torch.tensor(0.0).cuda()

        B, C, H, W = list(rgb_g.shape)

        PH, PW = hyp.PH, hyp.PW
        if (PH < H) or (PW < W):
            # print('H, W', H, W)
            # print('PH, PW', PH, PW)
            sy = float(PH) / float(H)
            sx = float(PW) / float(W)
            pix_T_cam0 = utils_geom.scale_intrinsics(pix_T_cam0, sx, sy)

            # (the 0.5 scale factor assumes PH, PW are exactly half of H, W)
            if valid is not None:
                valid = F.interpolate(valid, scale_factor=0.5, mode='nearest')
            rgb_g = F.interpolate(rgb_g, scale_factor=0.5, mode='bilinear')

        # feat_prep = self.prep_layer(feat_mem1)
        # feat_proj = utils_vox.apply_pixX_T_memR_to_voxR(
        #     pix_T_cam0, cam0_T_cam1, feat_prep,
        #     hyp.view_depth, PH, PW)
        feat_proj = vox_util.apply_pixX_T_memR_to_voxR(pix_T_cam0, cam0_T_cam1,
                                                       feat_mem1,
                                                       hyp.view_depth, PH, PW)
        # logspace_slices=(hyp.dataset_name=='carla'))

        # def flatten_depth(feat_3d):
        #     B, C, Z, Y, X = list(feat_3d.shape)
        #     feat_2d = feat_3d.view(B, C*Z, Y, X)
        #     return feat_2d

        # feat_pool = self.pool_layer(feat_proj)
        # feat_im = flatten_depth(feat_pool)
        # rgb_e = self.decoder(feat_im)

        feat = self.net(feat_proj)
        rgb = self.rgb_layer(feat)
        emb = self.emb_layer(feat)
        emb = utils_basic.l2_normalize(emb, dim=1)

        # feat_im = self.net(feat_proj)
        # if hyp.do_emb2D:
        #     emb_e = self.emb_layer(feat)
        #     # postproc
        #     emb_e = l2_normalize(emb_e, dim=1)
        # else:
        #     emb_e = None

        if test:
            return None, rgb, None

        # loss_im = torch.mean(F.mse_loss(rgb, rgb_g, reduction='none'), dim=1, keepdim=True)
        loss_im = utils_basic.l1_on_axis(rgb - rgb_g, 1, keepdim=True)
        if valid is not None:
            rgb_loss = utils_basic.reduce_masked_mean(loss_im, valid)
        else:
            rgb_loss = torch.mean(loss_im)

        total_loss = utils_misc.add_loss('view/rgb_l1_loss', total_loss,
                                         rgb_loss, hyp.view_l1_coeff,
                                         summ_writer)

        # smooth loss
        dy, dx = utils_basic.gradient2D(rgb, absolute=True)
        smooth_im = torch.mean(dy + dx, dim=1, keepdim=True)
        if summ_writer is not None:
            summ_writer.summ_oned('view/smooth_loss', smooth_im)
        smooth_loss = torch.mean(smooth_im)
        total_loss = utils_misc.add_loss('view/smooth_loss', total_loss,
                                         smooth_loss, hyp.view_smooth_coeff,
                                         summ_writer)

        # vis
        if summ_writer is not None:
            summ_writer.summ_oned('view/rgb_loss', loss_im)
            summ_writer.summ_rgbs('view/rgb', [rgb.clamp(-0.5, 0.5), rgb_g])
            summ_writer.summ_rgb('view/rgb_e', rgb.clamp(-0.5, 0.5))
            summ_writer.summ_rgb('view/rgb_g', rgb_g.clamp(-0.5, 0.5))
            summ_writer.summ_feat('view/emb', emb, pca=True)
            if valid is not None:
                summ_writer.summ_rgb('view/rgb_e_valid',
                                     valid * rgb.clamp(-0.5, 0.5))
                summ_writer.summ_rgb('view/rgb_g_valid',
                                     valid * rgb_g.clamp(-0.5, 0.5))

        return total_loss, rgb, emb
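
utils_geom.scale_intrinsics presumably rescales the focal lengths and principal point so that projection stays consistent after the render-resolution change; a minimal sketch for batched 3x3 or 4x4 intrinsics:

def scale_intrinsics(pix_T_cam, sx, sy):
    K = pix_T_cam.clone()
    K[:, 0, 0] *= sx  # fx
    K[:, 1, 1] *= sy  # fy
    K[:, 0, 2] *= sx  # x0
    K[:, 1, 2] *= sy  # y0
    return K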
Example #10
    def forward(
        self,
        boxes_g,
        scores_g,
        feat_zyx,
        summ_writer,
        mask=None,
    ):
        if hyp.deeper_det:
            feat_zyx = self.resnet(feat_zyx)
        # st()
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = feat_zyx.shape
        _, N, _ = boxes_g.shape  # dim-2 is xc,yc,zc,lx,ly,lz,rx,ry,rz

        pred_dim = self.pred_dim  # 7 total: 6 deltas, 1 objectness

        feat = feat_zyx.permute(
            0, 1, 4, 3, 2)  # get feat in xyz order, now B x C x X x Y x Z

        corners = utils_geom.transform_boxes_to_corners(
            boxes_g)  # corners is B x N x 8 x 3, last dim in xyz order
        corners_max = torch.max(corners, dim=2)[0]  # B x N x 3
        corners_min = torch.min(corners, dim=2)[0]
        corners_min_max_g = torch.stack([corners_min, corners_max],
                                        dim=3)  # this is B x N x 3 x 2

        # trim down, to save some time
        N = hyp.K
        boxes_g = boxes_g[:, :N, :6]
        corners_min_max_g = corners_min_max_g[:, :N]
        scores_g = scores_g[:, :N]  # B x N

        # boxes_g is [-0.5~63.5, -0.5~15.5, -0.5~63.5]
        centers_g = boxes_g[:, :, :3]  # B x N x 3
        # centers_g is B x N x 3
        grid = meshgrid3D_xyz(
            B, Z, Y, X)[0]  # just one grid please, this is X x Y x Z x 3

        delta_positions_raw = centers_g.view(B, N, 1, 1, 1, 3) - grid.view(
            1, 1, X, Y, Z, 3)
        # tf.summary.histogram('delta_positions_raw', delta_positions_raw)
        delta_positions = delta_positions_raw / hyp.det_anchor_size
        # tf.summary.histogram('delta_positions', delta_positions)

        lengths_g = boxes_g[:, :, 3:6]  # B x N x 3
        # tf.summary.histogram('lengths_g', lengths_g)
        delta_lengths = torch.log(lengths_g / hyp.det_anchor_size)
        delta_lengths = torch.max(
            delta_lengths, -1e6 *
            torch.ones_like(delta_lengths))  # to avoid -infs turning into nans
        # tf.summary.histogram('delta_lengths', delta_lengths)

        lengths_g = lengths_g.view(B, N, 1, 1, 1,
                                   3).repeat(1, 1, X, Y, Z,
                                             1)  # B x N x X x Y x Z x 3
        delta_lengths = delta_lengths.view(B, N, 1, 1, 1, 3).repeat(
            1, 1, X, Y, Z, 1)  # B x N x X x Y x Z x 3
        valid_mask = scores_g.view(B, N, 1, 1, 1,
                                   1).repeat(1, 1, X, Y, Z,
                                             1)  # B x N x X x Y x Z x 1

        delta_gt = torch.cat([delta_positions, delta_lengths],
                             -1)  # B x N x X x Y x Z x 6

        object_dist = torch.max(torch.abs(delta_positions_raw) /
                                (lengths_g * 0.5 + 1e-5),
                                dim=5)[0]  # B x N x X x Y x Z
        object_dist_mask = (torch.ones_like(object_dist) -
                            binarize(object_dist, 0.5)).unsqueeze(
                                dim=5)  # B x N x X x Y x Z x 1
        object_dist_mask = object_dist_mask * valid_mask  # B x N x X x Y x Z x 1
        object_neg_dist_mask = torch.ones_like(object_dist) - binarize(
            object_dist, 0.8)
        object_neg_dist_mask = object_neg_dist_mask * valid_mask.squeeze(
            dim=5)  # B x N x X x Y x Z

        anchor_deltas_gt = None

        for obj_id in list(range(N)):
            if anchor_deltas_gt is None:
                anchor_deltas_gt = delta_gt[:, obj_id] * object_dist_mask[:, obj_id]
                current_mask = object_dist_mask[:, obj_id]
            else:
                # don't overwrite anchor positions that are already taken
                overlap = current_mask * object_dist_mask[:, obj_id]
                anchor_deltas_gt += (torch.ones_like(overlap) - overlap) \
                    * delta_gt[:, obj_id] * object_dist_mask[:, obj_id]
                current_mask = current_mask + object_dist_mask[:, obj_id]
                current_mask = binarize(current_mask, 0.5)

        # tf.summary.histogram('anchor_deltas_gt', anchor_deltas_gt)
        # ok nice, these do not have any extreme values

        pos_equal_one = binarize(torch.sum(object_dist_mask, dim=1),
                                 0.5).squeeze(dim=4)  # B x X x Y x Z
        neg_equal_one = binarize(torch.sum(object_neg_dist_mask, dim=1), 0.5)
        neg_equal_one = torch.ones_like(
            neg_equal_one) - neg_equal_one  # B x X x Y x Z
        pos_equal_one_sum = torch.sum(pos_equal_one, [1, 2, 3])  # B
        neg_equal_one_sum = torch.sum(neg_equal_one, [1, 2, 3])

        # set min to one in case no object, to avoid nan
        pos_equal_one_sum_safe = torch.max(
            pos_equal_one_sum, torch.ones_like(pos_equal_one_sum))  # B
        neg_equal_one_sum_safe = torch.max(
            neg_equal_one_sum, torch.ones_like(neg_equal_one_sum))  # B

        pred = self.conv1(feat)  # this is B x 7 x X x Y x Z
        pred = pred.permute(0, 2, 3, 4, 1)  # B x X x Y x Z x 7
        pred_anchor_deltas = pred[..., 1:]  # B x X x Y x Z x 6
        pred_objectness_logits = pred[..., 0]  # B x X x Y x Z
        pred_objectness = torch.sigmoid(pred_objectness_logits)  # B x X x Y x Z

        # pred_anchor_deltas = pred_anchor_deltas.cpu()
        # pred_objectness = pred_objectness.cpu()

        small_addon_for_BCE = 1e-6

        overall_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            input=pred_objectness_logits,
            target=pos_equal_one,
            reduction='none',
        )

        if mask is not None:
            overall_loss = overall_loss * mask

        cls_pos_loss = utils_basic.reduce_masked_mean(overall_loss,
                                                      pos_equal_one)
        cls_neg_loss = utils_basic.reduce_masked_mean(overall_loss,
                                                      neg_equal_one)

        loss_prob = torch.sum(hyp.alpha_pos * cls_pos_loss +
                              hyp.beta_neg * cls_neg_loss)

        pos_mask = pos_equal_one.unsqueeze(dim=4)  # B x X x Y x Z x 1

        loss_l1 = smooth_l1_loss(pos_mask * pred_anchor_deltas,
                                 pos_mask * anchor_deltas_gt)  # B x X x Y x Z x 6
        if mask is not None:
            loss_l1 = loss_l1 * mask.unsqueeze(-1)

        loss_reg = torch.sum(
            loss_l1 / pos_equal_one_sum_safe.view(-1, 1, 1, 1, 1)) / hyp.B

        total_loss = utils_misc.add_loss('det/detect_prob', total_loss,
                                         loss_prob, hyp.det_prob_coeff,
                                         summ_writer)
        total_loss = utils_misc.add_loss('det/detect_reg', total_loss,
                                         loss_reg, hyp.det_reg_coeff,
                                         summ_writer)

        # finally, turn the preds into hard boxes, with nms
        (
            bs_selected_boxes_co,
            bs_selected_scores,
            bs_overlaps,
        ) = rpn_proposal_graph(pred_objectness,
                               pred_anchor_deltas,
                               scores_g,
                               corners_min_max_g,
                               iou_thresh=0.2)
        # these are lists of length B, each one leading with dim "?", since there is a variable number of objs per frame

        N = hyp.K * 2
        tidlist = torch.linspace(1.0, N, N).long().to('cuda')
        tidlist = tidlist.unsqueeze(0).repeat(B, 1)
        padded_boxes_e = torch.zeros(B, N, 9).float().cuda()
        padded_scores_e = torch.zeros(B, N).float().cuda()
        if bs_selected_boxes_co is not None:

            for b in list(range(B)):

                # make the boxes 1 x N x 9 (instead of B x ? x 6)
                padded_boxes0_e = bs_selected_boxes_co[b].unsqueeze(0)
                padded_scores0_e = bs_selected_scores[b].unsqueeze(0)

                padded_boxes0_e = torch.cat([
                    padded_boxes0_e,
                    torch.zeros([1, N, 6], device=torch.device('cuda'))
                ],
                                            dim=1)  # 1 x ? x 6
                padded_scores0_e = torch.cat([
                    padded_scores0_e,
                    torch.zeros([1, N], device=torch.device('cuda'))
                ],
                                             dim=1)  # pad out

                padded_boxes0_e = padded_boxes0_e[:, :N]  # clip to N
                padded_scores0_e = padded_scores0_e[:, :N]  # clip to N

                padded_boxes0_e = torch.cat([
                    padded_boxes0_e,
                    torch.zeros([1, N, 3], device=torch.device('cuda'))
                ],
                                            dim=2)

                padded_boxes_e[b] = padded_boxes0_e[0]
                padded_scores_e[b] = padded_scores0_e[0]

        return total_loss, padded_boxes_e, padded_scores_e, tidlist, bs_selected_scores, bs_overlaps
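
binarize is assumed to threshold a tensor into hard {0, 1} values, which is how the positive and negative anchor masks above are formed; a one-line sketch:

def binarize(x, thresh):
    return (x > thresh).float()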
Example #11
    def run_test(self, feed):
        results = dict()

        global_step = feed['global_step']
        total_loss = torch.tensor(0.0).cuda()

        __p = lambda x: utils_basic.pack_seqdim(x, self.B)
        __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

        self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(
            self.lrt_camX0s)

        self.original_centroid = self.scene_centroid.clone()

        obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
        obj_length = obj_lengths[:, 0]
        for b in list(range(self.B)):
            if self.score_s[b, 0] < 1.0:
                # we need the template to exist
                print('returning early, since score_s[%d,0] = %.1f' %
                      (b, self.score_s[b, 0].cpu().numpy()))
                return total_loss, results, True
            # if torch.sum(self.score_s[b]) < (self.S/2):
            if not (torch.sum(self.score_s[b]) == self.S):
                # the full traj should be valid
                print(
                    'returning early, since sum(score_s) = %d, while S = %d' %
                    (torch.sum(self.score_s).cpu().numpy(), self.S))
                return total_loss, results, True

        if hyp.do_feat3D:

            feat_memX0_input = torch.cat([
                self.occ_memX0s[:, 0],
                self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
            ],
                                         dim=1)
            _, feat_memX0, valid_memX0 = self.featnet3D(feat_memX0_input)
            B, C, Z, Y, X = list(feat_memX0.shape)
            S = self.S

            obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
                self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)
            # only take the occupied voxels
            occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y,
                                                   X)
            # obj_mask_memX0 = obj_mask_memX0s[:,0] * occ_memX0
            obj_mask_memX0 = obj_mask_memX0s[:, 0]

            # discard the known freespace
            _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, 0:1],
                self.xyz_camXs[:, 0:1],
                Z,
                Y,
                X,
                agg=True)
            free_memX0 = free_memX0_.squeeze(1)
            obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)

            for b in list(range(self.B)):
                if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                    print(
                        'returning early, since there are not enough valid object points'
                    )
                    return total_loss, results, True

            # for b in list(range(self.B)):
            #     sum_b = torch.sum(obj_mask_memX0[b])
            #     print('sum_b', sum_b.detach().cpu().numpy())
            #     if sum_b > 1000:
            #         obj_mask_memX0[b] *= occ_memX0[b]
            #         sum_b = torch.sum(obj_mask_memX0[b])
            #         print('reducing this to', sum_b.detach().cpu().numpy())

            feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1)
            # this is B x C x huge
            feat0_vec = feat0_vec.permute(0, 2, 1)
            # this is B x huge x C

            obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
            occ_mask0_vec = occ_memX0.reshape(B, -1).round()
            free_mask0_vec = free_memX0.reshape(B, -1).round()
            # these are B x huge

            orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
            # this is B x huge x 3

            obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(
                self.lrt_camX0s)
            obj_length = obj_lengths[:, 0]
            cam0_T_obj = cams_T_obj0[:, 0]
            # this is B x S x 4 x 4

            mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
            cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)

            lrt_camIs_g = self.lrt_camX0s.clone()
            lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
            # we will fill this up

            ious = torch.zeros([B, S]).float().cuda()
            point_counts = np.zeros([B, S])
            inb_counts = np.zeros([B, S])

            feat_vis = []
            occ_vis = []

            for s in range(self.S):
                if not (s == 0):
                    # remake the vox util and all the mem data
                    self.scene_centroid = utils_geom.get_clist_from_lrtlist(
                        lrt_camIs_e[:, s - 1:s])[:, 0]
                    delta = self.scene_centroid - self.original_centroid
                    self.vox_util = vox_util.Vox_util(
                        self.Z,
                        self.Y,
                        self.X,
                        self.set_name,
                        scene_centroid=self.scene_centroid,
                        assert_cube=True)
                    self.occ_memXs = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camXs), self.Z,
                                                   self.Y, self.X))
                    self.occ_memX0s = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camX0s),
                                                   self.Z, self.Y, self.X))

                    self.unp_memXs = __u(
                        self.vox_util.unproject_rgb_to_mem(
                            __p(self.rgb_camXs), self.Z, self.Y, self.X,
                            __p(self.pix_T_cams)))
                    self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
                        self.camX0s_T_camXs, self.unp_memXs)
                    self.summ_writer.summ_occ('track/reloc_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                else:
                    self.summ_writer.summ_occ('track/init_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                    delta = torch.zeros([B, 3]).float().cuda()
                # print('scene centroid:', self.scene_centroid.detach().cpu().numpy())
                occ_vis.append(
                    self.summ_writer.summ_occ('',
                                              self.occ_memX0s[:, s],
                                              only_return=True))

                # inb = __u(self.vox_util.get_inbounds(__p(self.xyz_camX0s), self.Z4, self.Y4, self.X, already_mem=False))
                inb = self.vox_util.get_inbounds(self.xyz_camX0s[:, s],
                                                 self.Z4,
                                                 self.Y4,
                                                 self.X,
                                                 already_mem=False)
                num_inb = torch.sum(inb.float(), axis=1)
                # print('num_inb', num_inb, num_inb.shape)
                inb_counts[:, s] = num_inb.cpu().numpy()

                feat_memI_input = torch.cat([
                    self.occ_memX0s[:, s],
                    self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
                ],
                                            dim=1)
                _, feat_memI, valid_memI = self.featnet3D(feat_memI_input)

                self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s,
                                           feat_memI_input,
                                           pca=True)
                self.summ_writer.summ_feat('3D_feats/feat_%d' % s,
                                           feat_memI,
                                           pca=True)
                feat_vis.append(
                    self.summ_writer.summ_feat('',
                                               feat_memI,
                                               pca=True,
                                               only_return=True))

                # collect freespace here, to discard bad matches
                _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                    self.camX0s_T_camXs[:, s:s + 1],
                    self.xyz_camXs[:, s:s + 1],
                    Z,
                    Y,
                    X,
                    agg=True)
                free_memI = free_memI_.squeeze(1)

                feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1)
                # this is B x C x huge
                feat_vec = feat_vec.permute(0, 2, 1)
                # this is B x huge x C

                memI_T_mem0 = utils_geom.eye_4x4(B)
                # we will fill this up

                # # put these on cpu, to save mem
                # feat0_vec = feat0_vec.detach().cpu()
                # feat_vec = feat_vec.detach().cpu()

                # to simplify the impl, we will iterate over the batch dim
                for b in list(range(B)):
                    feat_vec_b = feat_vec[b]
                    feat0_vec_b = feat0_vec[b]
                    obj_mask0_vec_b = obj_mask0_vec[b]
                    occ_mask0_vec_b = occ_mask0_vec[b]
                    free_mask0_vec_b = free_mask0_vec[b]
                    orig_xyz_b = orig_xyz[b]
                    # these are huge x C

                    careful = False
                    if careful:
                        # start with occ points, since these are definitely observed
                        obj_inds_b = torch.where(
                            (occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # also take random non-free non-occ points in the mask
                        ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (
                            1.0 - free_mask0_vec_b)
                        alt_inds_b = torch.where(ok_mask > 0)
                        alt_vec_b = feat0_vec_b[alt_inds_b]
                        alt_xyz0 = orig_xyz_b[alt_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        num = len(alt_xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            # print('have %d pts; taking a random set of %d pts inside' % (num, max_pts))
                            perm = np.random.permutation(num)
                            alt_vec_b = alt_vec_b[perm[:max_pts]]
                            alt_xyz0 = alt_xyz0[perm[:max_pts]]

                        obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                        xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                        if s == 0:
                            print('have %d pts in total' % (len(xyz0)))
                    else:
                        # take any points within the mask
                        obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        # trim down to max_pts
                        num = len(xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            print(
                                'have %d pts; taking a random set of %d pts inside'
                                % (num, max_pts))
                            perm = np.random.permutation(num)
                            obj_vec_b = obj_vec_b[perm[:max_pts]]
                            xyz0 = xyz0[perm[:max_pts]]

                    obj_vec_b = obj_vec_b.permute(1, 0)
                    # this is C x med

                    corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                    # this is huge x med

                    heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                    # this is med x 1 x Z x Y x X

                    # # for numerical stability, we sub the max, and mult by the resolution
                    # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                    # heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                    # heat_b = heat_b - heat_b_max
                    # heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                    # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                    # # heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                    # heat_b_min = torch.min(heat_b_)
                    # free_b = free_memI[b:b+1]
                    # print('free_b', free_b.shape)
                    # print('heat_b', heat_b.shape)
                    # heat_b[free_b > 0.0] = heat_b_min

                    # make the min zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_min
                    # zero out the freespace
                    heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                    # make the max zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_max
                    # scale up, for numerical stability
                    heat_b = heat_b * float(len(heat_b[0].reshape(-1)))
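                    # with the max pinned at zero, multiplying by the voxel
                    # count sharpens the softmax inside argmax3D below, so the
                    # soft argmax concentrates near the peak rather than being
                    # pulled toward the volume center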

                    xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                    # xyzI = utils_basic.argmax3D(heat_b*float(Z*10), hard=False, stack=True)
                    # this is med x 3

                    xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y,
                                                     X)
                    xyzI_cam += delta
                    xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                    memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)

                    # record #points, since ransac depends on this
                    point_counts[b, s] = len(xyz0)
                # done stepping through batch

                mem0_T_memI = utils_geom.safe_inverse(memI_T_mem0)
                cam0_T_camI = utils_basic.matmul3(cam_T_mem, mem0_T_memI,
                                                  mem_T_cam)

                # eval
                camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0,
                                                 mem_T_cam, cam0_T_obj)
                # this is B x 4 x 4
                lrt_camIs_e[:, s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:, s:s + 1]).squeeze(1)
            results['ious'] = ious
            # if ious[0,-1] > 0.5:
            #     print('returning early, since acc is too high')
            #     return total_loss, results, True

            self.summ_writer.summ_rgbs('track/feats', feat_vis)
            self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)

            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())

            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())
            self.summ_writer.summ_scalar('track/point_counts',
                                         np.mean(point_counts))
            # self.summ_writer.summ_scalar('track/inb_counts', torch.mean(inb_counts).cpu().item())
            self.summ_writer.summ_scalar('track/inb_counts',
                                         np.mean(inb_counts))

            lrt_camX0s_e = lrt_camIs_e.clone()
            lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(
                self.camXs_T_camX0s, lrt_camX0s_e)

            if self.include_vis:
                visX_e = []
                for s in list(range(self.S)):
                    visX_e.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_e' % s,
                                                      self.rgb_camXs[:, s],
                                                      lrt_camXs_e[:, s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
                visX_g = []
                for s in list(range(self.S)):
                    visX_g.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_g' % s,
                                                      self.rgb_camXs[:, s],
                                                      self.lrt_camXs[:, s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)

            dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
            # this is B x S
            mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
            median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
            # this is []
            self.summ_writer.summ_scalar('track/centroid_dist_mean',
                                         mean_dist.cpu().item())
            self.summ_writer.summ_scalar('track/centroid_dist_median',
                                         median_dist.cpu().item())

            # if self.include_vis:
            if True:
                self.summ_writer.summ_traj_on_occ('track/traj_e',
                                                  obj_clist_camX0_e,
                                                  self.occ_memX0s[:, 0],
                                                  self.vox_util,
                                                  already_mem=False,
                                                  sigma=2)
                self.summ_writer.summ_traj_on_occ('track/traj_g',
                                                  self.obj_clist_camX0,
                                                  self.occ_memX0s[:, 0],
                                                  self.vox_util,
                                                  already_mem=False,
                                                  sigma=2)
            total_loss += mean_dist  # we won't backprop, but it's nice to plot and print this anyway

        else:
            ious = torch.zeros([self.B, self.S]).float().cuda()
            for s in list(range(self.S)):
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    self.lrt_camX0s[:, 0:1],
                    self.lrt_camX0s[:, s:s + 1]).squeeze(1)
            results['ious'] = ious
            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())
            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())

            lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
            self.summ_writer.summ_traj_on_occ('track/traj_e',
                                              obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g',
                                              self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)

        self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, False
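
# Note: utils_track.rigid_transform_3D (used above to fit memI_T_mem0 per
# batch element) is not shown in these examples. A minimal sketch of that
# step, assuming the standard SVD-based (Kabsch / orthogonal Procrustes) fit
# between two corresponded N x 3 point sets; the function name and the 4x4
# packing are assumptions, not the repo's actual implementation.
import torch

def rigid_transform_3D_sketch(xyz0, xyz1):
    # fit R, t such that R @ xyz0 + t ~= xyz1, then pack them into a 4x4
    c0 = xyz0.mean(dim=0, keepdim=True)  # 1 x 3 centroid
    c1 = xyz1.mean(dim=0, keepdim=True)
    H = (xyz0 - c0).t() @ (xyz1 - c1)  # 3 x 3 cross-covariance
    U, S, V = torch.svd(H)
    R = V @ U.t()
    if torch.det(R) < 0:  # repair a reflection, if any
        V = V.clone()
        V[:, 2] = -V[:, 2]
        R = V @ U.t()
    t = c1.squeeze(0) - R @ c0.squeeze(0)
    T = torch.eye(4, device=xyz0.device, dtype=xyz0.dtype)
    T[:3, :3] = R
    T[:3, 3] = t
    return T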
Example #12
    def run_train(self, feed):
        results = dict()

        global_step = feed['global_step']
        total_loss = torch.tensor(0.0).cuda()

        __p = lambda x: utils_basic.pack_seqdim(x, self.B)
        __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

        if hyp.do_feat3D:
            feat_memX0s_input = torch.cat([
                self.occ_memX0s,
                self.unp_memX0s * self.occ_memX0s,
            ],
                                          dim=2)
            feat3D_loss, feat_memX0s_, valid_memX0s_ = self.featnet3D(
                __p(feat_memX0s_input[:, 1:]),
                self.summ_writer,
            )
            feat_memX0s = __u(feat_memX0s_)
            valid_memX0s = __u(valid_memX0s_)
            total_loss += feat3D_loss

            feat_memX0 = utils_basic.reduce_masked_mean(
                feat_memX0s,
                valid_memX0s.repeat(1, 1, hyp.feat3D_dim, 1, 1, 1),
                dim=1)
            valid_memX0 = torch.sum(valid_memX0s, dim=1).clamp(0, 1)
            self.summ_writer.summ_feat('3D_feats/feat_memX0',
                                       feat_memX0,
                                       valid=valid_memX0,
                                       pca=True)
            self.summ_writer.summ_feat('3D_feats/valid_memX0',
                                       valid_memX0,
                                       pca=False)

            if hyp.do_emb3D:
                _, altfeat_memX0, altvalid_memX0 = self.featnet3D_slow(
                    feat_memX0s_input[:, 0])
                self.summ_writer.summ_feat('3D_feats/altfeat_memX0',
                                           altfeat_memX0,
                                           valid=altvalid_memX0,
                                           pca=True)
                self.summ_writer.summ_feat('3D_feats/altvalid_memX0',
                                           altvalid_memX0,
                                           pca=False)

        if hyp.do_emb3D:
            if hyp.do_feat3D:
                _, _, Z_, Y_, X_ = list(feat_memX0.shape)
            else:
                Z_, Y_, X_ = self.Z2, self.Y2, self.X2
            # Z_, Y_, X_ = self.Z, self.Y, self.X

            occ_memX0s, free_memX0s, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs, self.xyz_camXs, Z_, Y_, X_, agg=False)

            not_ok = torch.zeros_like(occ_memX0s[:, 0])
            # it's not ok for a voxel to be marked occ only once
            not_ok += (torch.sum(occ_memX0s, dim=1) == 1.0).float()
            # it's not ok for a voxel to be marked occ AND free
            occ_agg = torch.sum(occ_memX0s, dim=1).clamp(0, 1)
            free_agg = torch.sum(free_memX0s, dim=1).clamp(0, 1)
            have_either = (occ_agg + free_agg).clamp(0, 1)
            have_both = occ_agg * free_agg
            not_ok += have_either * have_both
            # it's not ok for a voxel to be totally unobserved
            not_ok += (have_either == 0.0).float()
            not_ok = not_ok.clamp(0, 1)
            self.summ_writer.summ_occ('rely/not_ok', not_ok)
            self.summ_writer.summ_occ(
                'rely/not_ok_occ',
                not_ok * torch.max(self.occ_memX0s_half, dim=1)[0])
            self.summ_writer.summ_occ(
                'rely/ok_occ',
                (1.0 - not_ok) * torch.max(self.occ_memX0s_half, dim=1)[0])
            self.summ_writer.summ_occ(
                'rely/aggressive_occ',
                torch.max(self.occ_memX0s_half, dim=1)[0])

            be_safe = False
            if hyp.do_feat3D and be_safe:
                # update the valid masks
                valid_memX0 = valid_memX0 * (1.0 - not_ok)
                altvalid_memX0 = altvalid_memX0 * (1.0 - not_ok)

        if hyp.do_occ:
            _, _, Z_, Y_, X_ = list(feat_memX0.shape)
            occ_memX0_sup, free_memX0_sup, _, free_memX0s = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs, self.xyz_camXs, Z_, Y_, X_, agg=True)

            self.summ_writer.summ_occ('occ_sup/occ_sup', occ_memX0_sup)
            self.summ_writer.summ_occ('occ_sup/free_sup', free_memX0_sup)
            self.summ_writer.summ_occs('occ_sup/freeX0s_sup',
                                       torch.unbind(free_memX0s, dim=1))
            self.summ_writer.summ_occs(
                'occ_sup/occX0s_sup', torch.unbind(self.occ_memX0s_half,
                                                   dim=1))
            occ_loss, occ_memX0_pred = self.occnet(altfeat_memX0,
                                                   occ_memX0_sup,
                                                   free_memX0_sup,
                                                   altvalid_memX0,
                                                   self.summ_writer)
            total_loss += occ_loss

        if hyp.do_emb3D:
            # compute 3D ML
            emb_loss_3D = self.embnet3D(feat_memX0, altfeat_memX0,
                                        valid_memX0.round(),
                                        altvalid_memX0.round(),
                                        self.summ_writer)
            total_loss += emb_loss_3D

        self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, False
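
# Note: utils_basic.reduce_masked_mean appears throughout these examples but
# is not defined in them. A minimal sketch, assuming mask broadcasts against
# x and that a small stabilizer sits in the denominator (the exact EPS value
# is an assumption):
import torch

EPS = 1e-6

def reduce_masked_mean_sketch(x, mask, dim=None, keepdim=False):
    # mean of x restricted to the entries where mask is on
    prod = x * mask
    if dim is None:
        return torch.sum(prod) / (torch.sum(mask) + EPS)
    numer = torch.sum(prod, dim=dim, keepdim=keepdim)
    denom = torch.sum(mask, dim=dim, keepdim=keepdim) + EPS
    return numer / denom
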
def get_gt_flow(obj_lrtlist_camRs,
                obj_scorelist,
                camRs_T_camXs,
                Z,
                Y,
                X,
                K=2,
                mod='',
                vis=True,
                summ_writer=None):
    # this constructs the flow field from the given
    # box trajectories (obj_lrtlist_camRs), which were collected from a
    # moving camR and therefore do not account for egomotion
    # (the egomotion is encoded separately in camRs_T_camXs).
    # so, we first generate the flow for all the objects,
    # then fill the background with the ego flow
    N, B, S, D = list(obj_lrtlist_camRs.shape)
    assert (S == 2)  # as a flow util, this expects S=2

    flows = []
    masks = []
    for k in list(range(K)):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z,
            Y,
            X,
            coeff=1.0)
        # this is B x 1(N) x 1(C) x Z x Y x X
        # obj_masklistR0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X

        camR_T_cam0 = camRs_T_camXs[:, 0]
        camR_T_cam1 = camRs_T_camXs[:, 1]
        cam0_T_camR = utils_geom.safe_inverse(camR_T_cam0)
        cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)
        # camR0_T_camR1 = camR0_T_camRs[:,1]
        # camR1_T_camR0 = utils_geom.safe_inverse(camR0_T_camR1)

        # obj_masklistA1 = utils_vox.apply_4x4_to_vox(camR1_T_camR0, obj_masklistA0)
        # if vis and (summ_writer is not None):
        #     summ_writer.summ_occ('flow/obj%d_maskA0' % k, obj_masklistA0)
        #     summ_writer.summ_occ('flow/obj%d_maskA1' % k, obj_masklistA1)

        if vis and (summ_writer is not None):
            # summ_writer.summ_occ('flow/obj%d_mask0' % k, obj_mask0)
            summ_writer.summ_oned('flow/obj%d_mask0_%s' % (k, mod),
                                  torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])
        # this is B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)
        obj1_T_ref = utils_geom.safe_inverse(ref_T_obj1)
        # these are B x 4 x 4

        mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
        ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0,
                                          camR_T_cam0)
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

        xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
        xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)

        xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
        xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)

        # only use these displaced points within the obj mask
        # obj_mask03 = obj_mask0.view(B, Z, Y, X, 1).repeat(1, 1, 1, 1, 3)
        obj_mask0 = obj_mask0.view(B, Z, Y, X, 1)
        # # xyz_mem1[(obj_mask03 < 1.0).bool()] = xyz_mem0
        # cond = (obj_mask03 < 1.0).float()
        cond = (obj_mask0 > 0.0).float()
        xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0

        flow = xyz_mem1 - xyz_mem0
        flow = flow.permute(0, 4, 1, 2, 3)
        obj_mask0 = obj_mask0.permute(0, 4, 1, 2, 3)

        # if vis and k==0:
        if vis and (summ_writer is not None):
            summ_writer.summ_3D_flow('flow/gt_%d_%s' % (k, mod),
                                     flow,
                                     clip=4.0)

        masks.append(obj_mask0)
        flows.append(flow)

    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam0_T_camR = utils_geom.safe_inverse(camR_T_cam0)
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)

    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

    cam1_T_cam0 = utils_basic.matmul2(cam1_T_camR, camR_T_cam0)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)

    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)

    flow = xyz_mem1 - xyz_mem0
    flow = flow.permute(0, 4, 1, 2, 3)

    bkg_flow = flow

    # allow zero motion in the bkg
    any_mask = torch.max(torch.stack(masks, axis=0), axis=0)[0]
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, axis=0)
    masks = torch.stack(masks, axis=0)
    masks = masks.repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)

    if vis and (summ_writer is not None):
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)

    # flow is shaped B x 3 x D x H x W
    return flow
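
# Note: the background-flow computation above reduces to "build a voxel grid,
# push it through mem1_T_mem0, subtract". A compact sketch of that core,
# assuming gridcloud3D-style (x, y, z) voxel coordinates and row-vector
# points; the helper name is an assumption (and torch >= 1.10 is assumed for
# meshgrid's indexing argument):
import torch

def flow_from_transform_sketch(mem1_T_mem0, B, Z, Y, X):
    # dense 3D flow induced by one rigid transform in voxel coordinates
    z = torch.arange(Z, dtype=torch.float32)
    y = torch.arange(Y, dtype=torch.float32)
    x = torch.arange(X, dtype=torch.float32)
    gz, gy, gx = torch.meshgrid(z, y, x, indexing='ij')
    xyz0 = torch.stack([gx, gy, gz], dim=-1).reshape(1, -1, 3).repeat(B, 1, 1)
    ones = torch.ones_like(xyz0[:, :, 0:1])
    xyz0_h = torch.cat([xyz0, ones], dim=2)  # B x Z*Y*X x 4, homogeneous
    xyz1 = torch.bmm(xyz0_h, mem1_T_mem0.transpose(1, 2))[:, :, :3]
    flow = (xyz1 - xyz0).reshape(B, Z, Y, X, 3)
    return flow.permute(0, 4, 1, 2, 3)  # B x 3 x Z x Y x X, as above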
Example #14
    def forward(self, feat0, feat1, flow_g, mask_g, is_synth, summ_writer):
        total_loss = torch.tensor(0.0).cuda()

        B, C, D, H, W = list(feat0.shape)
        utils_basic.assert_same_shape(feat0, feat1)

        # feats = torch.cat([feat0, feat1], dim=0)
        # feats = self.compressor(feats)
        # feats = utils_basic.l2_normalize(feats, dim=1)
        # feat0, feat1 = feats[:B], feats[B:]

        flow_total_forw = torch.zeros_like(flow_g)
        flow_total_back = torch.zeros_like(flow_g)

        feat0_aligned = feat0.clone()
        feat1_aligned = feat1.clone()

        # cycle_losses = []
        # l1_losses = []

        # torch does not like it when we overwrite, so let's pre-allocate
        l1_loss_cumu = torch.tensor(0.0).cuda()
        l2_loss_cumu = torch.tensor(0.0).cuda()
        warp_loss_cumu = torch.tensor(0.0).cuda()

        summ_writer.summ_feats('flow/feats_aligned_%.2f' % 0.0,
                               [feat0, feat1_aligned])
        feat_diff = torch.mean(
            utils_basic.l2_on_axis((feat1_aligned - feat0), 1, keepdim=True))
        utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0,
                            summ_writer)

        # print('feat0, feat1_aligned, mask')
        # print(feat0.shape)
        # print(feat1_aligned.shape)
        # print(mask_g.shape)

        hinge_loss_vox = utils_basic.l2_on_axis((feat1_aligned - feat0),
                                                1,
                                                keepdim=True)
        hinge_loss_vox = F.relu(0.2 - hinge_loss_vox)
        summ_writer.summ_oned('flow/hinge_loss',
                              torch.mean(hinge_loss_vox, dim=3))
        hinge_mask_vox = (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) >
                          1.0).float()
        summ_writer.summ_oned('flow/hinge_mask',
                              torch.mean(hinge_mask_vox, dim=3),
                              norm=False)
        # hinge_loss = torch.mean(hinge_loss_vox)
        hinge_loss = utils_basic.reduce_masked_mean(hinge_loss_vox,
                                                    hinge_mask_vox)
        total_loss = utils_misc.add_loss('flow/hinge', total_loss, hinge_loss,
                                         hyp.flow_hinge_coeff, summ_writer)
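        # in effect: where the GT flow exceeds one voxel, this pushes the
        # (not-yet-aligned) feature distance above the 0.2 margin, i.e.,
        # features should differ where the scene actually moves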

        for sc in self.scales:

            # flow_forw, new_feat0, new_feat1 = self.generate_flow(feat0, feat1_aligned, sc)
            # flow_back, new_feat1, new_feat1 = self.generate_flow(feat1, feat0_aligned, sc)
            # flow_forw, heat = self.generate_flow(feat0, feat1_aligned, sc)

            flow_forw = self.generate_flow(feat0, feat1_aligned, sc)
            flow_back = self.generate_flow(feat1, feat0_aligned, sc)

            flow_total_forw = flow_total_forw + flow_forw
            flow_total_back = flow_total_back + flow_back

            # compositional LK: warp the original thing using the cumulative flow
            feat1_aligned = utils_samp.backwarp_using_3D_flow(
                feat1, flow_total_forw)
            feat0_aligned = utils_samp.backwarp_using_3D_flow(
                feat0, flow_total_back)
            valid1_region = utils_samp.backwarp_using_3D_flow(
                torch.ones_like(feat1[:, 0:1]), flow_total_forw)
            # valid0_region tracks feat0_aligned, which is warped with flow_total_back
            valid0_region = utils_samp.backwarp_using_3D_flow(
                torch.ones_like(feat0[:, 0:1]), flow_total_back)

            summ_writer.summ_feats('flow/feats_aligned_%.2f' % sc,
                                   [feat0, feat1_aligned],
                                   valids=[valid0_region, valid1_region])
            # summ_writer.summ_oned('flow/mean_heat_%.2f' % sc, torch.mean(heat, dim=3))
            # feat_diff = torch.mean(utils_basic.l2_on_axis((feat1_aligned-feat0), 1, keepdim=True))
            feat_diff = utils_basic.reduce_masked_mean(
                utils_basic.l2_on_axis((feat1_aligned - feat0),
                                       1,
                                       keepdim=True),
                valid1_region * valid0_region)
            utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff,
                                0, summ_writer)

            if sc == 1.0:

                warp_loss_cumu = warp_loss_cumu + feat_diff * sc

                l1_diff_3chan = self.smoothl1(flow_total_forw, flow_g)
                l1_diff = torch.mean(l1_diff_3chan, dim=1, keepdim=True)
                l2_diff_3chan = self.mse(flow_total_forw, flow_g)
                l2_diff = torch.mean(l2_diff_3chan, dim=1, keepdim=True)

                nonzero_mask = (
                    (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) >
                     0.01).float()) * mask_g
                yeszero_mask = (1.0 - nonzero_mask) * mask_g
                l1_loss_nonzero = utils_basic.reduce_masked_mean(
                    l1_diff, nonzero_mask)
                l1_loss_yeszero = utils_basic.reduce_masked_mean(
                    l1_diff, yeszero_mask)
                l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5
                l2_loss_nonzero = utils_basic.reduce_masked_mean(
                    l2_diff, nonzero_mask)
                l2_loss_yeszero = utils_basic.reduce_masked_mean(
                    l2_diff, yeszero_mask)
                l2_loss_balanced = (l2_loss_nonzero + l2_loss_yeszero) * 0.5
                # l1_loss_cumu = l1_loss_cumu + l1_loss_balanced*sc

                l1_loss_cumu = l1_loss_cumu + l1_loss_balanced * sc
                l2_loss_cumu = l2_loss_cumu + l2_loss_balanced * sc

                # warp flow
                flow_back_aligned_to_forw = utils_samp.backwarp_using_3D_flow(
                    flow_total_back, flow_total_forw.detach())
                flow_forw_aligned_to_back = utils_samp.backwarp_using_3D_flow(
                    flow_total_forw, flow_total_back.detach())

                cancelled_flow_forw = flow_total_forw + flow_back_aligned_to_forw
                cancelled_flow_back = flow_total_back + flow_forw_aligned_to_back

                cycle_forw = self.smoothl1_mean(
                    cancelled_flow_forw, torch.zeros_like(cancelled_flow_forw))
                cycle_back = self.smoothl1_mean(
                    cancelled_flow_back, torch.zeros_like(cancelled_flow_back))
                cycle_loss = cycle_forw + cycle_back
                total_loss = utils_misc.add_loss('flow/cycle_loss', total_loss,
                                                 cycle_loss,
                                                 hyp.flow_cycle_coeff,
                                                 summ_writer)

                summ_writer.summ_3D_flow('flow/flow_e_forw_%.2f' % sc,
                                         flow_total_forw * mask_g,
                                         clip=0.0)
                summ_writer.summ_3D_flow('flow/flow_e_back_%.2f' % sc,
                                         flow_total_back,
                                         clip=0.0)
                summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc,
                                         flow_g,
                                         clip=0.0)

                utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero,
                                    0, summ_writer)
                utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero,
                                    0, summ_writer)
                utils_misc.add_loss('flow/l1_loss_balanced', 0,
                                    l1_loss_balanced, 0, summ_writer)
                utils_misc.add_loss('flow/l2_loss_balanced', 0,
                                    l2_loss_balanced, 0, summ_writer)

                # total_loss = utils_misc.add_loss('flow/l1_loss_balanced', total_loss, l1_loss_balanced, hyp.flow_l1_coeff, summ_writer)
                # total_loss = utils_misc.add_loss('flow/l1_loss', total_loss, l1_loss, hyp.flow_l1_coeff*(sc==1.0), summ_writer)

        if is_synth:
            total_loss = utils_misc.add_loss('flow/synth_l1_cumu', total_loss,
                                             l1_loss_cumu,
                                             hyp.flow_synth_l1_coeff,
                                             summ_writer)
            total_loss = utils_misc.add_loss('flow/synth_l2_cumu', total_loss,
                                             l2_loss_cumu,
                                             hyp.flow_synth_l2_coeff,
                                             summ_writer)
        else:
            total_loss = utils_misc.add_loss('flow/l1_cumu', total_loss,
                                             l1_loss_cumu, hyp.flow_l1_coeff,
                                             summ_writer)
            total_loss = utils_misc.add_loss('flow/l2_cumu', total_loss,
                                             l2_loss_cumu, hyp.flow_l2_coeff,
                                             summ_writer)

        # total_loss = utils_misc.add_loss('flow/warp', total_loss, feat_diff, hyp.flow_warp_coeff, summ_writer)
        total_loss = utils_misc.add_loss('flow/warp_cumu', total_loss,
                                         warp_loss_cumu, hyp.flow_warp_coeff,
                                         summ_writer)

        # feat1_aligned = utils_samp.backwarp_using_3D_flow(feat1, flow_g)
        # valid_region = utils_samp.backwarp_using_3D_flow(torch.ones_like(feat1[:,0:1]), flow_g)
        # summ_writer.summ_feats('flow/feats_aligned_g', [feat0, feat1_aligned],
        #                        valids=[valid_region, valid_region])
        # feat_diff = utils_basic.reduce_masked_mean(utils_basic.l2_on_axis((feat1_aligned-feat0), 1, keepdim=True), valid_region)
        # total_loss = utils_misc.add_loss('flow/warp_g', total_loss, feat_diff, hyp.flow_warp_g_coeff, summ_writer)
        # # hinge_loss_vox = F.relu(0.2 - hinge_loss_vox)

        # total_loss = utils_misc.add_loss('flow/cycle_loss', total_loss, torch.sum(torch.stack(cycle_losses)), hyp.flow_cycle_coeff, summ_writer)
        # total_loss = utils_misc.add_loss('flow/l1_loss', total_loss, torch.sum(torch.stack(l1_losses)), hyp.flow_l1_coeff, summ_writer)

        # smooth loss
        dx, dy, dz = utils_basic.gradient3D(flow_total_forw, absolute=True)
        smooth_vox_forw = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        dx, dy, dz = utils_basic.gradient3D(flow_total_back, absolute=True)
        smooth_vox_back = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        summ_writer.summ_oned('flow/smooth_loss_forw',
                              torch.mean(smooth_vox_forw, dim=3))
        smooth_loss = torch.mean((smooth_vox_forw + smooth_vox_back) * 0.5)
        total_loss = utils_misc.add_loss('flow/smooth_loss', total_loss,
                                         smooth_loss, hyp.flow_smooth_coeff,
                                         summ_writer)

        # flow_e = F.sigmoid(flow_e_)
        # flow_e_binary = torch.round(flow_e)

        # # collect some accuracy stats
        # flow_match = flow_g*torch.eq(flow_e_binary, flow_g).float()
        # free_match = free_g*torch.eq(1.0-flow_e_binary, free_g).float()
        # either_match = torch.clamp(flow_match+free_match, 0.0, 1.0)
        # either_have = torch.clamp(flow_g+free_g, 0.0, 1.0)
        # acc_flow = reduce_masked_mean(flow_match, flow_g*valid)
        # acc_free = reduce_masked_mean(free_match, free_g*valid)
        # acc_total = reduce_masked_mean(either_match, either_have*valid)

        # summ_writer.summ_scalar('flow/acc_flow', acc_flow.cpu().item())
        # summ_writer.summ_scalar('flow/acc_free', acc_free.cpu().item())
        # summ_writer.summ_scalar('flow/acc_total', acc_total.cpu().item())

        # # vis
        # summ_writer.summ_flow('flow/flow_g', flow_g)
        # summ_writer.summ_flow('flow/free_g', free_g)
        # summ_writer.summ_flow('flow/flow_e', flow_e)
        # summ_writer.summ_flow('flow/valid', valid)

        # prob_loss = self.compute_loss(flow_e_, flow_g, free_g, valid, summ_writer)
        # total_loss = utils_misc.add_loss('flow/prob_loss', total_loss, prob_loss, hyp.flow_coeff, summ_writer)

        # return total_loss, flow_e
        return total_loss, flow_total_forw
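
# Note: utils_samp.backwarp_using_3D_flow is the workhorse of the
# compositional warping above. A minimal sketch, assuming voxel-unit flow
# ordered (x, y, z) along the channel dim, built on F.grid_sample (whose 5D
# grid wants normalized [-1, 1] coordinates with the last axis ordered
# x, y, z); the function name and conventions are assumptions:
import torch
import torch.nn.functional as F

def backwarp_using_3D_flow_sketch(feat1, flow, mode='bilinear'):
    # output[p] = feat1[p + flow[p]] via trilinear sampling
    B, C, D, H, W = feat1.shape
    dev = feat1.device
    z = torch.arange(D, dtype=torch.float32, device=dev)
    y = torch.arange(H, dtype=torch.float32, device=dev)
    x = torch.arange(W, dtype=torch.float32, device=dev)
    gz, gy, gx = torch.meshgrid(z, y, x, indexing='ij')
    base = torch.stack([gx, gy, gz], dim=0).unsqueeze(0)  # 1 x 3 x D x H x W
    pos = base + flow  # absolute sampling positions, voxel units
    pos_x = 2.0 * pos[:, 0] / max(W - 1, 1) - 1.0  # normalize to [-1, 1]
    pos_y = 2.0 * pos[:, 1] / max(H - 1, 1) - 1.0
    pos_z = 2.0 * pos[:, 2] / max(D - 1, 1) - 1.0
    grid = torch.stack([pos_x, pos_y, pos_z], dim=-1)  # B x D x H x W x 3
    return F.grid_sample(feat1, grid, mode=mode, align_corners=True)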
Example #15
    def forward(self,
                feat,
                obj_lrtlist_cams,
                obj_scorelist_s,
                summ_writer,
                suffix=''):
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = list(feat.shape)
        K, B2, S, D = list(obj_lrtlist_cams.shape)
        assert (B == B2)

        # obj_scorelist_s is K x B x S

        # __p = lambda x: utils_basic.pack_seqdim(x, B)
        # __u = lambda x: utils_basic.unpack_seqdim(x, B)

        # obj_lrtlist_cams is K x B x S x 19
        obj_lrtlist_cams_ = obj_lrtlist_cams.reshape(K * B, S, 19)
        obj_clist_cam_ = utils_geom.get_clist_from_lrtlist(obj_lrtlist_cams_)
        obj_clist_cam = obj_clist_cam_.reshape(K, B, S, 1, 3)
        # obj_clist_cam is K x B x S x 1 x 3
        obj_clist_cam = obj_clist_cam.squeeze(3)
        # obj_clist_cam is K x B x S x 3
        clist_cam = obj_clist_cam.reshape(K * B, S, 3)
        clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
        # this is K*B x S x 3
        clist_mem = clist_mem.reshape(K, B, S, 3)

        energy_vol = self.conv3d(feat)
        # energy_vol is B x 1 x Z x Y x X
        summ_writer.summ_oned('pri/energy_vol', torch.mean(energy_vol, dim=3))
        summ_writer.summ_histogram('pri/energy_vol_hist', energy_vol)

        # for k in range(K):
        # let's start with the first object
        # loglike_per_traj = self.get_traj_loglike(clist_mem[0], energy_vol)
        # # this is B
        # ce_loss = -1.0*torch.mean(loglike_per_traj)
        # # this is []

        loglike_per_traj = self.get_trajs_loglike(clist_mem, obj_scorelist_s,
                                                  energy_vol)
        # this is B x K
        valid = torch.max(obj_scorelist_s.permute(1, 0, 2), dim=2)[0]
        ce_loss = -1.0 * utils_basic.reduce_masked_mean(
            loglike_per_traj, valid)
        # this is []

        total_loss = utils_misc.add_loss('pri/ce_loss', total_loss, ce_loss,
                                         hyp.pri_ce_coeff, summ_writer)

        reg_loss = torch.sum(torch.abs(energy_vol))
        total_loss = utils_misc.add_loss('pri/reg_loss', total_loss, reg_loss,
                                         hyp.pri_reg_coeff, summ_writer)

        # smooth loss
        dx, dy, dz = utils_basic.gradient3D(energy_vol, absolute=True)
        smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        summ_writer.summ_oned('pri/smooth_loss', torch.mean(smooth_vox, dim=3))
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('pri/smooth_loss', total_loss,
                                         smooth_loss, hyp.pri_smooth_coeff,
                                         summ_writer)

        # pri_e = F.sigmoid(energy_vol)
        # energy_volbinary = torch.round(pri_e)

        # # collect some accuracy stats
        # pri_match = pri_g*torch.eq(energy_volbinary, pri_g).float()
        # free_match = free_g*torch.eq(1.0-energy_volbinary, free_g).float()
        # either_match = torch.clamp(pri_match+free_match, 0.0, 1.0)
        # either_have = torch.clamp(pri_g+free_g, 0.0, 1.0)
        # acc_pri = reduce_masked_mean(pri_match, pri_g*valid)
        # acc_free = reduce_masked_mean(free_match, free_g*valid)
        # acc_total = reduce_masked_mean(either_match, either_have*valid)

        # summ_writer.summ_scalar('pri/acc_pri%s' % suffix, acc_pri.cpu().item())
        # summ_writer.summ_scalar('pri/acc_free%s' % suffix, acc_free.cpu().item())
        # summ_writer.summ_scalar('pri/acc_total%s' % suffix, acc_total.cpu().item())

        # # vis
        # summ_writer.summ_pri('pri/pri_g%s' % suffix, pri_g, reduce_axes=[2,3])
        # summ_writer.summ_pri('pri/free_g%s' % suffix, free_g, reduce_axes=[2,3])
        # summ_writer.summ_pri('pri/pri_e%s' % suffix, pri_e, reduce_axes=[2,3])
        # summ_writer.summ_pri('pri/valid%s' % suffix, valid, reduce_axes=[2,3])

        # prob_loss = self.compute_loss(energy_vol, pri_g, free_g, valid, summ_writer)
        # total_loss = utils_misc.add_loss('pri/prob_loss%s' % suffix, total_loss, prob_loss, hyp.pri_coeff, summ_writer)

        return total_loss  #, pri_e
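
# Note: get_trajs_loglike is not shown. One plausible reading, consistent
# with the CE loss above: treat energy_vol as unnormalized log-probabilities
# over voxels, and score each trajectory by the log-softmax value at its
# (rounded) waypoints, averaged over valid timesteps. A hedged sketch of
# that reading only; the real helper may sample sub-voxel or sum over time:
import torch
import torch.nn.functional as F

def get_trajs_loglike_sketch(clist_mem, obj_scorelist_s, energy_vol):
    # clist_mem: K x B x S x 3 waypoints in voxel coords (x, y, z)
    # obj_scorelist_s: K x B x S validity; energy_vol: B x 1 x Z x Y x X
    K, B, S, _ = clist_mem.shape
    _, _, Z, Y, X = energy_vol.shape
    logprob = F.log_softmax(energy_vol.reshape(B, -1), dim=1).reshape(B, Z, Y, X)
    loglike = torch.zeros(B, K, device=energy_vol.device)
    for k in range(K):
        for b in range(B):
            lls = []
            for s in range(S):
                if obj_scorelist_s[k, b, s] > 0:
                    x, y, z = clist_mem[k, b, s].round().long()
                    x = x.clamp(0, X - 1)
                    y = y.clamp(0, Y - 1)
                    z = z.clamp(0, Z - 1)
                    lls.append(logprob[b, z, y, x])
            if lls:
                loglike[b, k] = torch.stack(lls).mean()
    return loglike  # B x K, matching loglike_per_traj above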