Example no. 1
    def compute_loss(self, feat_map, pred_map, traj_past, traj_futu,
                     traj_futu_e, energy_map):
        # traj_past is B x T+1 x 2
        # traj_futu is B x T x 2
        # traj_futu_e is K x B x T x 2
        # energy_map is B x 1 x Z x X

        K, B, T_futu, D = list(traj_futu_e.shape)

        # ------------
        # forward loss
        # ------------
        # run the actual traj_futu through the net, to get its log prob
        CE_pqs = -1.0 * self.log_prob(feat_map, pred_map, traj_past, traj_futu)
        # this is B
        CE_pq = torch.mean(CE_pqs)
        # summ_writer.summ_scalar('rpo/CE_pq', CE_pq.cpu().item())

        # ------------
        # reverse loss
        # ------------
        B, C, Z, X = list(energy_map.shape)
        xyz_cam = traj_futu_e.reshape([B * K, T_futu, 2])
        xyz_mem = utils_vox.Ref2Mem(self.add_fake_y(xyz_cam), Z, 10, X)
        # since we have multiple samples per image, we need to tile up energy_map
        energy_map = energy_map.unsqueeze(0).repeat(K, 1, 1, 1,
                                                    1).reshape(K * B, C, Z, X)
        CE_qphat_all = -1.0 * utils_misc.get_traj_loglike(xyz_mem, energy_map)
        # this is B*K x T_futu
        CE_qphat = torch.mean(CE_qphat_all)

        _forward_loss = CE_pq
        _reverse_loss = CE_qphat
        return _forward_loss, _reverse_loss
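
For reference, a minimal self-contained sketch (not the original utils_misc.get_traj_loglike, whose details may differ) of how a per-timestep trajectory log-likelihood under a 2D energy map can be computed: normalize the map with a spatial log-softmax and bilinearly sample it at the (x, z) trajectory locations.

import torch
import torch.nn.functional as F

def traj_loglike_sketch(xyz_mem, energy_map):
    # xyz_mem: B x T x 3 trajectory in memory coords; energy_map: B x 1 x Z x X
    B, C, Z, X = energy_map.shape
    # treat the energy map as unnormalized log-probs over the Z*X grid
    logprob_map = F.log_softmax(energy_map.reshape(B, Z * X), dim=1).reshape(B, 1, Z, X)
    x = xyz_mem[:, :, 0]
    z = xyz_mem[:, :, 2]
    # grid_sample expects coords in [-1, 1], ordered (x, z) for a Z x X map
    grid = torch.stack([x / (X - 1) * 2 - 1, z / (Z - 1) * 2 - 1], dim=2).unsqueeze(2)
    loglike = F.grid_sample(logprob_map, grid, align_corners=True)  # B x 1 x T x 1
    return loglike.squeeze(3).squeeze(1)  # B x T, per-timestep log-likelihood

energy = torch.randn(2, 1, 32, 32)
traj = torch.rand(2, 5, 3) * 31
print(traj_loglike_sketch(traj, energy).shape)  # torch.Size([2, 5])
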
Example no. 2
    def forward(self, feat_mem, clist_cam, summ_writer, suffix=''):
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = list(feat_mem.shape)
        B2, S, D = list(clist_cam.shape)
        assert (B == B2)
        assert (D == 3)

        clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
        # this is (still) B x S x 3

        feat_ = feat_mem.permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
        mask_ = 1.0 - (feat_ == 0).all(dim=1, keepdim=True).float().cuda()
        grid_ = utils_basic.meshgrid2D(B, Z, X, stack=True,
                                       norm=True).permute(0, 3, 1, 2)
        halfgrid_ = utils_basic.meshgrid2D(B,
                                           int(Z / 2),
                                           int(X / 2),
                                           stack=True,
                                           norm=True).permute(0, 3, 1, 2)
        feat_ = torch.cat([feat_, grid_], dim=1)
        energy_map, mask = self.net(feat_, mask_, halfgrid_)
        # energy_map = self.net(feat_)
        # energy_map is B x 1 x Z x X
        # don't do this: # energy_map = energy_map + (1.0-mask) * (torch.min(torch.min(energy_map, dim=2)[0], dim=2)[0]).reshape(B, 1, 1, 1)
        summ_writer.summ_feat('pri/energy_input', feat_)
        summ_writer.summ_oned('pri/energy_map', energy_map)
        summ_writer.summ_oned('pri/mask', mask, norm=False)
        summ_writer.summ_histogram('pri/energy_map_hist', energy_map)

        loglike_per_traj = utils_misc.get_traj_loglike(
            clist_mem * 0.5, energy_map)  # 0.5 since it's half res
        # loglike_per_traj = self.get_traj_loglike(clist_mem*0.25, energy_map) # 0.25 since it's quarter res
        # this is B x K
        ce_loss = -1.0 * torch.mean(loglike_per_traj)
        # this is []

        total_loss = utils_misc.add_loss('pri/ce_loss', total_loss, ce_loss,
                                         hyp.pri2D_ce_coeff, summ_writer)

        reg_loss = torch.sum(torch.abs(energy_map))
        total_loss = utils_misc.add_loss('pri/reg_loss', total_loss, reg_loss,
                                         hyp.pri2D_reg_coeff, summ_writer)

        # smooth loss
        dz, dx = utils_basic.gradient2D(energy_map, absolute=True)
        smooth_vox = torch.mean(dz + dx, dim=1, keepdim=True)
        summ_writer.summ_oned('pri/smooth_loss', smooth_vox)
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('pri/smooth_loss', total_loss,
                                         smooth_loss, hyp.pri2D_smooth_coeff,
                                         summ_writer)

        return total_loss, energy_map
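
The smoothness term above relies on utils_basic.gradient2D; a minimal sketch of what such a helper could look like, assuming forward differences that are zero-padded at the border so the outputs keep the input shape.

import torch

def gradient2D_sketch(t, absolute=False):
    # t: B x C x Z x X
    dz = torch.zeros_like(t)
    dx = torch.zeros_like(t)
    dz[:, :, :-1, :] = t[:, :, 1:, :] - t[:, :, :-1, :]
    dx[:, :, :, :-1] = t[:, :, :, 1:] - t[:, :, :, :-1]
    if absolute:
        dz, dx = dz.abs(), dx.abs()
    return dz, dx

# total-variation style smoothness penalty on a toy energy map
energy = torch.randn(2, 1, 64, 64)
dz, dx = gradient2D_sketch(energy, absolute=True)
smooth_loss = torch.mean(dz + dx)
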
Example no. 3
    def prep_dyn_feat(self, pred_map, traj_past):
        B, C, Z, X = list(pred_map.shape)
        B, T, D = list(traj_past.shape)

        # as the "dynamic" input, we will sample from pred_map at each loc in traj_past

        traj_past_mem = utils_vox.Ref2Mem(self.add_fake_y(traj_past), Z, 10, X)
        x = traj_past_mem[:, :, 0]
        z = traj_past_mem[:, :, 2]

        feats = utils_samp.bilinear_sample2D(pred_map, x, z)
        # print('sampled these:', feats.shape)
        # this is B x T x C
        dyn_feat = feats.reshape(B, -1)
        # also, we will concat the actual traj_past
        dyn_feat = torch.cat([dyn_feat, traj_past.reshape(B, -1)], dim=1)
        # print('cat the traj itself, and got:', dyn_feat.shape)
        return dyn_feat
    def summ_traj_on_occ(self, name, traj, occ_mem, already_mem=False, sigma=2):
        # traj is B x S x 3
        B, C, Z, Y, X = list(occ_mem.shape)
        B2, S, D = list(traj.shape)
        assert(D==3)
        assert(B==B2)
        
        if self.save_this:
            if already_mem:
                traj_mem = traj
            else:
                traj_mem = utils_vox.Ref2Mem(traj, Z, Y, X)

            height_mem = convert_occ_to_height(occ_mem, reduce_axis=3)
            # this is B x C x Z x X
            
            occ_vis = normalize(height_mem)
            occ_vis = oned2inferno(occ_vis, norm=False)
            # print(vis.shape)

            x, y, z = torch.unbind(traj_mem, dim=2)
            xz = torch.stack([x,z], dim=2)
            heats = draw_circles_at_xy(xz, Z, X, sigma=sigma)
            # this is B x S x 1 x Z x X
            heats = torch.squeeze(heats, dim=2)
            heat = seq2color(heats)


            # make black 0
            heat = back2color(heat)

            # print(heat.shape)
            # vis[heat > 0] = heat
            
            # replace black with occ vis
            heat[heat==0] = (occ_vis[heat==0].float()*0.5).byte() # darken the bkg a bit
            heat = preprocess_color(heat)
            
            self.summ_rgb(('%s' % (name)), heat)
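
Several of these examples read features or costs off a 2D map at real-valued trajectory locations via utils_samp.bilinear_sample2D. A minimal sketch of pixel-coordinate bilinear sampling in that spirit follows (the real helper may differ, e.g. in output layout; coordinates are assumed in-bounds).

import torch

def bilinear_sample2D_sketch(im, x, y):
    # im: B x C x H x W; x, y: B x N pixel coordinates (x indexes W, y indexes H)
    B, C, H, W = im.shape
    x0 = x.floor().clamp(0, W - 2)
    y0 = y.floor().clamp(0, H - 2)
    x1, y1 = x0 + 1, y0 + 1
    wx = (x - x0).unsqueeze(1)  # B x 1 x N
    wy = (y - y0).unsqueeze(1)  # B x 1 x N

    def gather(ix, iy):
        # look up im at integer pixel (ix, iy) for every channel
        idx = (iy * W + ix).long().unsqueeze(1).expand(-1, C, -1)  # B x C x N
        return im.reshape(B, C, H * W).gather(2, idx)

    top = gather(x0, y0) * (1 - wx) + gather(x1, y0) * wx
    bot = gather(x0, y1) * (1 - wx) + gather(x1, y1) * wx
    return top * (1 - wy) + bot * wy  # B x C x N

# toy usage, mirroring prep_dyn_feat: sample a B x C x Z x X map at T (x, z) points
pred_map = torch.randn(2, 8, 32, 32)
x = torch.rand(2, 5) * 31
z = torch.rand(2, 5) * 31
feats = bilinear_sample2D_sketch(pred_map, x, z)  # 2 x 8 x 5
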
Example no. 5
    def __init__(self):
        super(ForecastNet, self).__init__()

        print('ForecastNet...')

        self.use_cost_vols = False
        if self.use_cost_vols:

            if hyp.do_feat:
                in_dim = 66
            else:
                in_dim = 10

            hidden_dim = 32
            out_dim = hyp.S

            self.cost_forecaster = archs.encoder3D.Net3D(
                in_channel=in_dim, pred_dim=out_dim).cuda()
            # self.cost_forecaster = archs.encoder3D.ResNet3D(
            #     in_channel=in_dim, pred_dim=out_dim).cuda()

            # self.cost_forecaster = nn.Sequential(
            #     nn.ReplicationPad3d(1),
            #     nn.Conv3d(in_channels=in_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=0),
            #     nn.BatchNorm3d(num_features=hidden_dim),
            #     nn.LeakyReLU(),
            #     nn.ReplicationPad3d(1),
            #     nn.Conv3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=0),
            #     nn.BatchNorm3d(num_features=hidden_dim),
            #     nn.LeakyReLU(),
            #     nn.ConvTranspose3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=1),
            #     nn.BatchNorm3d(num_features=hidden_dim),
            #     nn.LeakyReLU(),
            #     nn.ConvTranspose3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=1),
            #     nn.BatchNorm3d(num_features=hidden_dim),
            #     nn.LeakyReLU(),
            #     nn.Conv3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, stride=1, padding=1),
            #     nn.BatchNorm3d(num_features=hidden_dim),
            #     nn.LeakyReLU(),
            #     nn.Conv3d(in_channels=hidden_dim, out_channels=out_dim, kernel_size=1, stride=1),
            # ).cuda()

            library_mod = 'ab'
            traj_library_cam = np.load('../intphys_data/all_trajs_%s.npy' %
                                       library_mod)
            # traj_library_cam is L x 100 x 3, where L is huge

            # let's make it a bit more huge
            traj_library_cam = np.concatenate([
                traj_library_cam * 0.8, traj_library_cam * 1.0,
                traj_library_cam * 1.2
            ],
                                              axis=0)

            # print('traj_library_cam', traj_library_cam.shape)
            self.L = traj_library_cam.shape[0]

            self.frame_stride = 2
            traj_library_cam = traj_library_cam[:, ::self.frame_stride]
            # traj_library_cam is L x M x 3
            self.M = traj_library_cam.shape[1]

            Z, Y, X = hyp.Z, hyp.Y, hyp.X
            Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
            traj_library_cam = torch.from_numpy(
                traj_library_cam).float().cuda()
            traj_library_cam_ = traj_library_cam.reshape(1, self.L * self.M, 3)
            traj_library_mem_ = utils_vox.Ref2Mem(traj_library_cam_, Z2, Y2,
                                                  X2)
            traj_library_mem = traj_library_mem_.reshape(self.L, self.M, 3)
            # print('traj_library_mem', traj_library_mem.shape)
            # self.traj_library_mem = traj_library_mem
            self.traj_library_mem = traj_library_mem.detach().cpu().numpy()

            # self.traj_library_mem is L x M x 3
        else:
            self.num_given = 3
            self.num_need = hyp.S - self.num_given

            in_dim = 3
            out_dim = self.num_need * 3

            # self.regressor = archs.bottle3D.Bottle3D(
            #     in_channel=in_dim, pred_dim=out_dim).cuda()
            self.regressor = archs.sparse_invar_bottle3D.Bottle3D(
                in_channel=in_dim, pred_dim=out_dim).cuda()
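
The sparse-invariant regressor chosen here is later fed a completeness mask alongside its features (see the forward pass in the next example). A small sketch, on assumed toy shapes, of how such a mask is derived from all-zero cells, following the same pattern used throughout these snippets:

import torch

# toy feature volume, B x C x Z x Y x X, with zeros wherever nothing was observed
feat_input = torch.zeros(2, 3, 8, 8, 8)
feat_input[:, :, 2:6, 2:6, 2:6] = torch.randn(2, 3, 4, 4, 4)
# a cell counts as valid if any channel is nonzero
comp_mask = 1.0 - (feat_input == 0).all(dim=1, keepdim=True).float()  # B x 1 x Z x Y x X
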
Example no. 6
    def forward(self,
                feats,
                xyzlist_cam,
                scorelist,
                vislist,
                occs,
                summ_writer,
                suffix=''):
        total_loss = torch.tensor(0.0).cuda()
        B, S, C, Z2, Y2, X2 = list(feats.shape)
        B, S, C, Z, Y, X = list(occs.shape)
        B2, S2, D = list(xyzlist_cam.shape)
        assert B == B2 and S == S2
        assert (D == 3)

        xyzlist_mem = utils_vox.Ref2Mem(xyzlist_cam, Z, Y, X)
        # these are B x S x 3
        scorelist = scorelist.unsqueeze(2)
        # this is B x S x 1
        vislist = vislist[:, 0].reshape(B, 1, 1)
        # we only care that the object was visible in frame0
        scorelist = scorelist * vislist

        if self.use_cost_vols:
            if summ_writer.save_this:
                summ_writer.summ_traj_on_occ('forecast/actual_traj',
                                             xyzlist_mem * scorelist,
                                             torch.max(occs, dim=1)[0],
                                             already_mem=True,
                                             sigma=2)

            Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
            Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
            occ_hint0 = utils_vox.voxelize_xyz(xyzlist_cam[:, 0:1], Z4, Y4, X4)
            occ_hint1 = utils_vox.voxelize_xyz(xyzlist_cam[:, 1:2], Z4, Y4, X4)
            occ_hint0 = occ_hint0 * scorelist[:, 0].reshape(B, 1, 1, 1, 1)
            occ_hint1 = occ_hint1 * scorelist[:, 1].reshape(B, 1, 1, 1, 1)
            occ_hint = torch.cat([occ_hint0, occ_hint1], dim=1)
            occ_hint = F.interpolate(occ_hint, scale_factor=4, mode='nearest')
            # this is B x 1 x Z x Y x X
            summ_writer.summ_occ('forecast/occ_hint',
                                 (occ_hint0 + occ_hint1).clamp(0, 1))

            crops = []
            for s in list(range(S)):
                crop = utils_vox.center_mem_on_xyz(occs[:, s], xyzlist_cam[:, s],
                                                   Z2, Y2, X2)
                crops.append(crop)
            crops = torch.stack(crops, dim=0)
            summ_writer.summ_occs('forecast/crops', crops)

            # condition on the occ_hint
            feat = torch.cat([feat, occ_hint], dim=1)

            N = hyp.forecast_num_negs
            sampled_trajs_mem = self.sample_trajs_from_library(N, xyzlist_mem)

            if summ_writer.save_this:
                for n in list(range(np.min([N, 10]))):
                    xyzlist_mem = sampled_trajs_mem[0, n].unsqueeze(0)
                    # this is 1 x S x 3
                    summ_writer.summ_traj_on_occ(
                        'forecast/lib%d_xyzlist' % n,
                        xyzlist_mem,
                        torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                        already_mem=True)

            cost_vols = self.cost_forecaster(feat)
            # cost_vols = F.sigmoid(cost_vols)
            cost_vols = F.interpolate(cost_vols,
                                      scale_factor=2,
                                      mode='trilinear')

            # cost_vols is B x S x Z x Y x X
            summ_writer.summ_histogram('forecast/cost_vols_hist', cost_vols)
            cost_vols = cost_vols.clamp(
                -1000, 1000)  # raquel says this adds stability
            summ_writer.summ_histogram('forecast/cost_vols_clamped_hist',
                                       cost_vols)

            cost_vols_vis = torch.mean(cost_vols, dim=3).unsqueeze(2)
            # cost_vols_vis is B x S x 1 x Z x X
            summ_writer.summ_oneds('forecast/cost_vols_vis',
                                   torch.unbind(cost_vols_vis, dim=1))

            # smooth loss
            cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X)
            dz, dy, dx = gradient3D(cost_vols_, absolute=True)
            dt = torch.abs(cost_vols[:, 1:] - cost_vols[:, 0:-1])
            smooth_vox_spatial = torch.mean(dx + dy + dz, dim=1, keepdim=True)
            smooth_vox_time = torch.mean(dt, dim=1, keepdim=True)
            summ_writer.summ_oned('forecast/smooth_loss_spatial',
                                  torch.mean(smooth_vox_spatial, dim=3))
            summ_writer.summ_oned('forecast/smooth_loss_time',
                                  torch.mean(smooth_vox_time, dim=3))
            smooth_loss = torch.mean(smooth_vox_spatial) + torch.mean(
                smooth_vox_time)
            total_loss = utils_misc.add_loss('forecast/smooth_loss',
                                             total_loss, smooth_loss,
                                             hyp.forecast_smooth_coeff,
                                             summ_writer)

            def clamp_xyz(xyz, X, Y, Z):
                x, y, z = torch.unbind(xyz, dim=-1)
                x = x.clamp(0, X)
                y = y.clamp(0, Y)
                z = z.clamp(0, Z)
                xyz = torch.stack([x, y, z], dim=-1)
                return xyz

            # obj_xyzlist_mem is K x B x S x 3
            # xyzlist_mem is B x S x 3
            # sampled_trajs_mem is B x N x S x 3
            xyz_pos_ = xyzlist_mem.reshape(B * S, 1, 3)
            xyz_neg_ = sampled_trajs_mem.permute(0, 2, 1,
                                                 3).reshape(B * S, N, 3)
            # xyz_pos_ = clamp_xyz(xyz_pos_, X, Y, Z)
            # xyz_neg_ = clamp_xyz(xyz_neg_, X, Y, Z)
            xyz_ = torch.cat([xyz_pos_, xyz_neg_], dim=1)
            xyz_ = clamp_xyz(xyz_, X, Y, Z)
            cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X)
            x, y, z = torch.unbind(xyz_, dim=2)
            # x = x.clamp(0, X)
            # y = x.clamp(0, Y)
            # z = x.clamp(0, Z)
            cost_ = utils_samp.bilinear_sample3D(cost_vols_, x, y,
                                                 z).squeeze(1)
            # cost is B*S x 1+N
            cost_pos = cost_[:, 0:1]  # B*S x 1
            cost_neg = cost_[:, 1:]  # B*S x N

            cost_pos = cost_pos.unsqueeze(2)  # B*S x 1 x 1
            cost_neg = cost_neg.unsqueeze(1)  # B*S x 1 x N

            utils_misc.add_loss('forecast/mean_cost_pos', 0,
                                torch.mean(cost_pos), 0, summ_writer)
            utils_misc.add_loss('forecast/mean_cost_neg', 0,
                                torch.mean(cost_neg), 0, summ_writer)
            utils_misc.add_loss('forecast/mean_margin', 0,
                                torch.mean(cost_neg - cost_pos), 0,
                                summ_writer)

            xyz_pos = xyz_pos_.unsqueeze(2)  # B*S x 1 x 1 x 3
            xyz_neg = xyz_neg_.unsqueeze(1)  # B*S x 1 x N x 3
            dist = torch.norm(xyz_pos - xyz_neg, dim=3)  # B*S x 1 x N
            dist = dist / float(
                Z) * 5.0  # normalize for resolution, but upweight it a bit
            margin = F.relu(cost_pos - cost_neg + dist)
            margin = margin.reshape(B, S, N)
            # mean over time (in the paper this is a sum)
            margin = utils_basic.reduce_masked_mean(margin,
                                                    scorelist.repeat(1, 1, N),
                                                    dim=1)
            # max over the negatives
            maxmargin = torch.max(margin, dim=1)[0]  # B
            maxmargin_loss = torch.mean(maxmargin)
            total_loss = utils_misc.add_loss('forecast/maxmargin_loss',
                                             total_loss, maxmargin_loss,
                                             hyp.forecast_maxmargin_coeff,
                                             summ_writer)

            cost_neg = cost_neg.reshape(B, S, N)[0].detach().cpu().numpy()
            sampled_trajs_mem = sampled_trajs_mem.reshape(B, N, S, 3)[0:1]
            cost_neg = np.reshape(cost_neg, [S, N])
            cost_neg = np.sum(cost_neg, axis=0)
            inds = np.argsort(cost_neg, axis=0)

            for n in list(range(2)):

                xyzlist_e_mem = sampled_trajs_mem[0:1, inds[n]]
                xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X)
                # this is B x S x 3

                # if summ_writer.save_this and n==0:
                #     print('xyzlist_e_cam', xyzlist_e_cam[0:1])
                #     print('xyzlist_g_cam', xyzlist_cam[0:1])
                #     print('scorelist', scorelist[0:1])

                dist = torch.norm(xyzlist_cam[0:1] - xyzlist_e_cam[0:1], dim=2)
                # this is B x S
                meandist = utils_basic.reduce_masked_mean(
                    dist, scorelist[0:1].squeeze(2))
                utils_misc.add_loss('forecast/xyz_dist_%d' % n, 0, meandist, 0,
                                    summ_writer)
                # dist = torch.mean(torch.sum(torch.norm(xyzlist_cam[0:1] - xyzlist_e_cam[0:1], dim=2), dim=1))

                # mpe = torch.mean(torch.norm(xyzlist_cam[0:1,int(S/2)] - xyzlist_e_cam[0:1,int(S/2)], dim=1))
                # mpe = utils_basic.reduce_masked_mean(dist, scorelist[0:1])
                # utils_misc.add_loss('forecast/xyz_mpe_%d' % n, 0, dist, 0, summ_writer)

                # epe = torch.mean(torch.norm(xyzlist_cam[0:1,-1] - xyzlist_e_cam[0:1,-1], dim=1))
                # utils_misc.add_loss('forecast/xyz_epe_%d' % n, 0, dist, 0, summ_writer)

            if summ_writer.save_this:
                # plot the best and worst trajs
                # print('sorted costs:', cost_neg[inds])
                for n in list(range(2)):
                    ind = inds[n]
                    # print('plotting good traj with cost %.2f' % (cost_neg[ind]))
                    xyzlist_e_mem = sampled_trajs_mem[:, ind]
                    # this is 1 x S x 3
                    summ_writer.summ_traj_on_occ(
                        'forecast/best_sampled_traj%d' % n,
                        xyzlist_e_mem,
                        torch.max(occs[0:1], dim=1)[0],
                        # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                        already_mem=True,
                        sigma=1)

                for n in list(range(2)):
                    ind = inds[-(n + 1)]
                    # print('plotting bad traj with cost %.2f' % (cost_neg[ind]))
                    xyzlist_e_mem = sampled_trajs_mem[:, ind]
                    # this is 1 x S x 3
                    summ_writer.summ_traj_on_occ(
                        'forecast/worst_sampled_traj%d' % n,
                        xyzlist_e_mem,
                        torch.max(occs[0:1], dim=1)[0],
                        # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                        already_mem=True,
                        sigma=1)
        else:

            # use some timesteps as input
            feat_input = feats[:, :self.num_given].squeeze(2)
            # feat_input is B x self.num_given x ZZ x ZY x ZX

            ## regular bottle3D
            # vel_e = self.regressor(feat_input)

            ## sparse-invar bottle3D
            comp_mask = 1.0 - (feat_input == 0).all(dim=1,
                                                    keepdim=True).float()
            summ_writer.summ_feat('forecast/feat_input', feat_input, pca=False)
            summ_writer.summ_feat('forecast/feat_comp_mask',
                                  comp_mask,
                                  pca=False)
            vel_e = self.regressor(feat_input, comp_mask)

            vel_e = vel_e.reshape(B, self.num_need, 3)
            vel_g = xyzlist_cam[:, self.num_given:] - xyzlist_cam[:, self.num_given - 1:-1]

            xyzlist_e = torch.zeros_like(xyzlist_cam)
            xyzlist_g = torch.zeros_like(xyzlist_cam)
            for s in list(range(S)):
                # print('s = %d' % s)
                if s < self.num_given:
                    # print('grabbing from gt ind %s' % s)
                    xyzlist_e[:, s] = xyzlist_cam[:, s]
                    xyzlist_g[:, s] = xyzlist_cam[:, s]
                else:
                    # print('grabbing from s-self.num_given, which is ind %d' % (s-self.num_given))
                    xyzlist_e[:, s] = xyzlist_e[:, s - 1] + vel_e[:, s - self.num_given]
                    xyzlist_g[:, s] = xyzlist_g[:, s - 1] + vel_g[:, s - self.num_given]

        xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X)
        xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X)
        summ_writer.summ_traj_on_occ('forecast/traj_e',
                                     xyzlist_e_mem,
                                     torch.max(occs, dim=1)[0],
                                     already_mem=True,
                                     sigma=2)
        summ_writer.summ_traj_on_occ('forecast/traj_g',
                                     xyzlist_g_mem,
                                     torch.max(occs, dim=1)[0],
                                     already_mem=True,
                                     sigma=2)

        scorelist_here = scorelist[:, self.num_given:, 0]
        sql2 = torch.sum((vel_g - vel_e)**2, dim=2)

        ## yes weightmask
        weightmask = torch.arange(0,
                                  self.num_need,
                                  dtype=torch.float32,
                                  device=torch.device('cuda'))
        weightmask = torch.exp(-weightmask**(1. / 4))
        # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860,
        #         0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353
        weightmask = weightmask.reshape(1, self.num_need)
        l2_loss = utils_basic.reduce_masked_mean(sql2,
                                                 scorelist_here * weightmask)

        utils_misc.add_loss('forecast/l2_loss', 0, l2_loss, 0, summ_writer)

        # # no weightmask:
        # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here)

        # total_loss = utils_misc.add_loss('forecast/l2_loss', total_loss, l2_loss, hyp.forecast_l2_coeff, summ_writer)

        dist = torch.norm(xyzlist_e - xyzlist_g, dim=2)
        meandist = utils_basic.reduce_masked_mean(dist, scorelist[:, :, 0])
        utils_misc.add_loss('forecast/xyz_dist_0', 0, meandist, 0, summ_writer)

        l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here)
        # utils_misc.add_loss('forecast/vel_dist_noexp', 0, l2_loss, 0, summ_writer)
        total_loss = utils_misc.add_loss('forecast/l2_loss_noexp', total_loss,
                                         l2_loss_noexp, hyp.forecast_l2_coeff,
                                         summ_writer)

        return total_loss
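
To make the weighting above concrete, here is a small standalone sketch (assuming utils_basic.reduce_masked_mean is a mask-weighted mean) that reproduces the exponential time-decay weights quoted in the comment and applies them to toy per-step errors.

import torch

def reduce_masked_mean_sketch(x, mask):
    # mean of x weighted by mask (broadcastable to x), guarded against empty masks
    return (x * mask).sum() / (mask.sum() + 1e-6)

num_need = 17
t = torch.arange(0, num_need, dtype=torch.float32)
weightmask = torch.exp(-t ** (1. / 4))
print(weightmask[:4])  # tensor([1.0000, 0.3679, 0.3045, 0.2682])

# toy masked velocity loss: B x num_need squared errors, early steps weighted most
sql2 = torch.rand(2, num_need)
scorelist_here = torch.ones(2, num_need)
l2_loss = reduce_masked_mean_sketch(sql2, scorelist_here * weightmask.reshape(1, num_need))
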
Example no. 7
    def forward(self, clist_cam, energy_map, occ_mems, summ_writer):

        total_loss = torch.tensor(0.0).cuda()

        B, S, C, Z, Y, X = list(occ_mems.shape)
        B2, S, D = list(clist_cam.shape)
        assert (B == B2)

        traj_past = clist_cam[:, :self.T_past]
        traj_futu = clist_cam[:, self.T_past:]

        # just xz
        traj_past = torch.stack([traj_past[:, :, 0], traj_past[:, :, 2]],
                                dim=2)  # xz
        traj_futu = torch.stack([traj_futu[:, :, 0], traj_futu[:, :, 2]],
                                dim=2)  # xz

        feat = occ_mems[:, 0].permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
        mask = 1.0 - (feat == 0).all(dim=1, keepdim=True).float().cuda()
        halfgrid = utils_basic.meshgrid2D(B,
                                          int(Z / 2),
                                          int(X / 2),
                                          stack=True,
                                          norm=True).permute(0, 3, 1, 2)
        feat_map, _ = self.compressor(feat, mask, halfgrid)
        pred_map = self.conv2d(feat_map)
        # these are B x C x Z x X

        K = 12  # number of samples
        traj_past = traj_past.unsqueeze(0).repeat(K, 1, 1, 1)
        feat_map = feat_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)
        pred_map = pred_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)
        # to sample the K trajectories in parallel, we'll pack K onto the batch dim
        __p = lambda x: utils_basic.pack_seqdim(x, K)
        __u = lambda x: utils_basic.unpack_seqdim(x, K)
        traj_past_ = __p(traj_past)
        feat_map_ = __p(feat_map)
        pred_map_ = __p(pred_map)
        base_sample_ = torch.randn(K * B, self.T_futu, 2).cuda()
        traj_futu_e_ = self.compute_forward_mapping(feat_map_, pred_map_,
                                                    base_sample_, traj_past_)
        traj_futu_e = __u(traj_futu_e_)
        # this is K x B x T x 2

        # print('traj_futu_e', traj_futu_e.shape, traj_futu_e[0,0])
        if summ_writer.save_this:
            o = []
            for k in list(range(K)):
                o.append(
                    utils_improc.preprocess_color(
                        summ_writer.summ_traj_on_occ(
                            '',
                            utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]),
                                              Z, Y, X),
                            occ_mems[:, 0],
                            already_mem=True,
                            only_return=True)))
                summ_writer.summ_traj_on_occ(
                    'rponet/traj_futu_sample_%d' % k,
                    utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]), Z, Y,
                                      X),
                    occ_mems[:, 0],
                    already_mem=True)

            mean_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
            summ_writer.summ_rgb('rponet/traj_futu_e_mean', mean_vis)

            summ_writer.summ_traj_on_occ('rponet/traj_futu_g',
                                         utils_vox.Ref2Mem(
                                             self.add_fake_y(traj_futu), Z, Y,
                                             X),
                                         occ_mems[:, 0],
                                         already_mem=True)

        # forward loss: neg logprob of GT samples under the model
        # reverse loss: neg logprob of estim samples under the (approx) GT (i.e., spatial prior)
        forward_loss, reverse_loss = self.compute_loss(feat_map[0],
                                                       pred_map[0],
                                                       traj_past[0], traj_futu,
                                                       traj_futu_e, energy_map)
        total_loss = utils_misc.add_loss('rpo/forward_loss', total_loss,
                                         forward_loss, hyp.rpo2D_forward_coeff,
                                         summ_writer)
        total_loss = utils_misc.add_loss('rpo/reverse_loss', total_loss,
                                         reverse_loss, hyp.rpo2D_reverse_coeff,
                                         summ_writer)

        return total_loss
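
The K parallel trajectory samples above are handled by folding K onto the batch dimension. A minimal sketch of what pack_seqdim / unpack_seqdim plausibly do (the real utils_basic helpers may differ in argument conventions):

import torch

def pack_seqdim_sketch(x, N):
    # merge the two leading dims: N x B x ... -> (N*B) x ...
    s = list(x.shape)
    assert s[0] == N
    return x.reshape([s[0] * s[1]] + s[2:])

def unpack_seqdim_sketch(x, N):
    # inverse of the above: (N*B) x ... -> N x B x ...
    s = list(x.shape)
    return x.reshape([N, s[0] // N] + s[1:])

# push K trajectory samples through a net that only understands a batch dim, then unpack
K, B, T = 12, 2, 8
traj = torch.randn(K, B, T, 2)
traj_ = pack_seqdim_sketch(traj, K)  # (K*B) x T x 2
assert torch.equal(unpack_seqdim_sketch(traj_, K), traj)
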
Example no. 8
    def forward(self, clist_cam, occs, summ_writer, vox_util, suffix=''):
        total_loss = torch.tensor(0.0).cuda()
        B, S, C, Z, Y, X = list(occs.shape)
        B2, S2, D = list(clist_cam.shape)
        assert B == B2 and S == S2
        assert (D == 3)

        if summ_writer.save_this:
            summ_writer.summ_traj_on_occ('motioncost/actual_traj',
                                         clist_cam,
                                         occs[:, self.T_past],
                                         vox_util,
                                         sigma=2)

        __p = lambda x: utils_basic.pack_seqdim(x, B)
        __u = lambda x: utils_basic.unpack_seqdim(x, B)

        # occs_ = occs.reshape(B*S, C, Z, Y, X)
        occs_ = __p(occs)
        feats_ = occs_.permute(0, 1, 3, 2, 4).reshape(B * S, C * Y, Z, X)
        masks_ = 1.0 - (feats_ == 0).all(dim=1, keepdim=True).float().cuda()
        halfgrids_ = utils_basic.meshgrid2D(B * S,
                                            int(Z / 2),
                                            int(X / 2),
                                            stack=True,
                                            norm=True).permute(0, 3, 1, 2)
        # feats_ = torch.cat([feats_, grids_], dim=1)
        feats = __u(feats_)
        masks = __u(masks_)
        halfgrids = __u(halfgrids_)
        input_feats = feats[:, :self.T_past]
        input_masks = masks[:, :self.T_past]
        input_halfgrids = halfgrids[:, :self.T_past]
        dense_feats_, _ = self.densifier(__p(input_feats), __p(input_masks),
                                         __p(input_halfgrids))
        dense_feats = __u(dense_feats_)
        super_feat = dense_feats.reshape(B, self.T_past * self.dense_dim,
                                         int(Z / 2), int(X / 2))
        cost_maps = self.motioncoster(super_feat)
        cost_maps = F.interpolate(cost_maps, scale_factor=4, mode='bilinear')
        # this is B x T_futu x Z x X
        cost_maps = cost_maps.clamp(-1000,
                                    1000)  # raquel says this adds stability
        summ_writer.summ_histogram('motioncost/cost_maps_hist', cost_maps)
        summ_writer.summ_oneds('motioncost/cost_maps',
                               torch.unbind(cost_maps.unsqueeze(2), dim=1))

        # next i need to sample some trajectories

        N = hyp.motioncost_num_negs
        sampled_trajs_cam = self.sample_trajs(N, clist_cam)
        # this is B x N x S x 3

        if summ_writer.save_this:
            # for n in list(range(np.min([N, 3]))):
            #     # this is 1 x S x 3
            #     summ_writer.summ_traj_on_occ('motioncost/sample%d_clist' % n,
            #                                  sampled_trajs_cam[0, n].unsqueeze(0),
            #                                  occs[:,self.T_past],
            #                                  # torch.max(occs, dim=1)[0],
            #                                  # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
            #                                  already_mem=False)
            o = []
            for n in list(range(N)):
                o.append(
                    utils_improc.preprocess_color(
                        summ_writer.summ_traj_on_occ(
                            '',
                            sampled_trajs_cam[0, n].unsqueeze(0),
                            occs[0:1, self.T_past],
                            vox_util,
                            only_return=True,
                            sigma=0.5)))
            summ_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
            summ_writer.summ_rgb('motioncost/all_sampled_trajs', summ_vis)

        # smooth loss
        cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
        dz, dx = gradient2D(cost_maps_, absolute=True)
        dt = torch.abs(cost_maps[:, 1:] - cost_maps[:, 0:-1])
        smooth_spatial = torch.mean(dx + dz, dim=1, keepdim=True)
        smooth_time = torch.mean(dt, dim=1, keepdim=True)
        summ_writer.summ_oned('motioncost/smooth_loss_spatial', smooth_spatial)
        summ_writer.summ_oned('motioncost/smooth_loss_time', smooth_time)
        smooth_loss = torch.mean(smooth_spatial) + torch.mean(smooth_time)
        total_loss = utils_misc.add_loss('motioncost/smooth_loss', total_loss,
                                         smooth_loss,
                                         hyp.motioncost_smooth_coeff,
                                         summ_writer)

        # def clamp_xyz(xyz, X, Y, Z):
        #     x, y, z = torch.unbind(xyz, dim=-1)
        #     x = x.clamp(0, X)
        #     y = x.clamp(0, Y)
        #     z = x.clamp(0, Z)
        #     # if zero_y:
        #     #     y = torch.zeros_like(y)
        #     xyz = torch.stack([x,y,z], dim=-1)
        #     return xyz
        def clamp_xz(xz, X, Z):
            x, z = torch.unbind(xz, dim=-1)
            x = x.clamp(0, X)
            z = z.clamp(0, Z)
            xz = torch.stack([x, z], dim=-1)
            return xz

        clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
        # this is B x S x 3

        # sampled_trajs_cam is B x N x S x 3
        sampled_trajs_cam_ = sampled_trajs_cam.reshape(B, N * S, 3)
        sampled_trajs_mem_ = utils_vox.Ref2Mem(sampled_trajs_cam_, Z, Y, X)
        sampled_trajs_mem = sampled_trajs_mem_.reshape(B, N, S, 3)
        # this is B x N x S x 3

        xyz_pos_ = clist_mem[:, self.T_past:].reshape(B * self.T_futu, 1, 3)
        xyz_neg_ = sampled_trajs_mem[:, :,
                                     self.T_past:].permute(0, 2, 1, 3).reshape(
                                         B * self.T_futu, N, 3)
        # get rid of y
        xz_pos_ = torch.stack([xyz_pos_[:, :, 0], xyz_pos_[:, :, 2]], dim=2)
        xz_neg_ = torch.stack([xyz_neg_[:, :, 0], xyz_neg_[:, :, 2]], dim=2)
        xz_ = torch.cat([xz_pos_, xz_neg_], dim=1)
        xz_ = clamp_xz(xz_, X, Z)
        cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
        cost_ = utils_samp.bilinear_sample2D(cost_maps_, xz_[:, :, 0],
                                             xz_[:, :, 1]).squeeze(1)
        # cost is B*T_futu x 1+N
        cost_pos = cost_[:, 0:1]  # B*T_futu x 1
        cost_neg = cost_[:, 1:]  # B*T_futu x N

        cost_pos = cost_pos.unsqueeze(2)  # B*T_futu x 1 x 1
        cost_neg = cost_neg.unsqueeze(1)  # B*T_futu x 1 x N

        utils_misc.add_loss('motioncost/mean_cost_pos', 0,
                            torch.mean(cost_pos), 0, summ_writer)
        utils_misc.add_loss('motioncost/mean_cost_neg', 0,
                            torch.mean(cost_neg), 0, summ_writer)
        utils_misc.add_loss('motioncost/mean_margin', 0,
                            torch.mean(cost_neg - cost_pos), 0, summ_writer)

        xz_pos = xz_pos_.unsqueeze(2)  # B*T_futu x 1 x 1 x 2
        xz_neg = xz_neg_.unsqueeze(1)  # B*T_futu x 1 x N x 2
        dist = torch.norm(xz_pos - xz_neg, dim=3)  # B*T_futu x 1 x N
        dist = dist / float(
            Z) * 5.0  # normalize for resolution, but upweight it a bit
        margin = F.relu(cost_pos - cost_neg + dist)
        margin = margin.reshape(B, self.T_futu, N)
        # mean over time (in the paper this is a sum)
        margin = torch.mean(margin, dim=1)
        # max over the negatives
        maxmargin = torch.max(margin, dim=1)[0]  # B
        maxmargin_loss = torch.mean(maxmargin)
        total_loss = utils_misc.add_loss('motioncost/maxmargin_loss',
                                         total_loss, maxmargin_loss,
                                         hyp.motioncost_maxmargin_coeff,
                                         summ_writer)

        # now let's see some top k
        # we'll do this for the first el of the batch
        cost_neg = cost_neg.reshape(B, self.T_futu,
                                    N)[0].detach().cpu().numpy()
        futu_mem = sampled_trajs_mem[:, :, self.T_past:].reshape(
            B, N, self.T_futu, 3)[0:1]
        cost_neg = np.reshape(cost_neg, [self.T_futu, N])
        cost_neg = np.sum(cost_neg, axis=0)
        inds = np.argsort(cost_neg, axis=0)

        for n in list(range(2)):
            xyzlist_e_mem = futu_mem[0:1, inds[n]]
            xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X)
            # this is B x S x 3

            if summ_writer.save_this and n == 0:
                print('xyzlist_e_cam', xyzlist_e_cam[0:1])
                print('xyzlist_g_cam', clist_cam[0:1, self.T_past:])

            dist = torch.norm(clist_cam[0:1, self.T_past:] -
                              xyzlist_e_cam[0:1],
                              dim=2)
            # this is B x T_futu
            meandist = torch.mean(dist)
            utils_misc.add_loss('motioncost/xyz_dist_%d' % n, 0, meandist, 0,
                                summ_writer)

        if summ_writer.save_this:
            # plot the best and worst trajs
            # print('sorted costs:', cost_neg[inds])
            for n in list(range(2)):
                ind = inds[n]
                print('plotting good traj with cost %.2f' % (cost_neg[ind]))
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ('motioncost/best_sampled_traj%d' %
                                             n,
                                             xyzlist_e_mem[0:1],
                                             occs[0:1, self.T_past],
                                             vox_util,
                                             already_mem=True,
                                             sigma=2)

            for n in list(range(2)):
                ind = inds[-(n + 1)]
                print('plotting bad traj with cost %.2f' % (cost_neg[ind]))
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ(
                    'motioncost/worst_sampled_traj%d' % n,
                    xyzlist_e_mem[0:1],
                    occs[0:1, self.T_past],
                    vox_util,
                    already_mem=True,
                    sigma=2)

        # xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X)
        # xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X)
        # summ_writer.summ_traj_on_occ('motioncost/traj_e',
        #                              xyzlist_e_mem,
        #                              torch.max(occs, dim=1)[0],
        #                              already_mem=True,
        #                              sigma=2)
        # summ_writer.summ_traj_on_occ('motioncost/traj_g',
        #                              xyzlist_g_mem,
        #                              torch.max(occs, dim=1)[0],
        #                              already_mem=True,
        #                              sigma=2)

        # scorelist_here = scorelist[:,self.num_given:,0]
        # sql2 = torch.sum((vel_g-vel_e)**2, dim=2)

        # ## yes weightmask
        # weightmask = torch.arange(0, self.num_need, dtype=torch.float32, device=torch.device('cuda'))
        # weightmask = torch.exp(-weightmask**(1./4))
        # # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860,
        # #         0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353
        # weightmask = weightmask.reshape(1, self.num_need)
        # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here * weightmask)

        # utils_misc.add_loss('motioncost/l2_loss', 0, l2_loss, 0, summ_writer)

        # # # no weightmask:
        # # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here)

        # # total_loss = utils_misc.add_loss('motioncost/l2_loss', total_loss, l2_loss, hyp.motioncost_l2_coeff, summ_writer)

        # dist = torch.norm(xyzlist_e - xyzlist_g, dim=2)
        # meandist = utils_basic.reduce_masked_mean(dist, scorelist[:,:,0])
        # utils_misc.add_loss('motioncost/xyz_dist_0', 0, meandist, 0, summ_writer)

        # l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here)
        # # utils_misc.add_loss('motioncost/vel_dist_noexp', 0, l2_loss, 0, summ_writer)
        # total_loss = utils_misc.add_loss('motioncost/l2_loss_noexp', total_loss, l2_loss_noexp, hyp.motioncost_l2_coeff, summ_writer)

        return total_loss
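
For clarity, the distance-scaled max-margin trajectory objective used in Examples 6 and 8, restated as a self-contained sketch on toy tensors (names and sizes here are illustrative, not from the original code): the ground-truth point must be cheaper than each negative by a margin that grows with their spatial distance.

import torch
import torch.nn.functional as F

B, T_futu, N, Z_res = 2, 3, 50, 128
BT = B * T_futu  # time packed onto the batch dim, as in the code above
cost_pos = torch.randn(BT, 1, 1)          # cost of the GT trajectory point
cost_neg = torch.randn(BT, 1, N)          # costs of N sampled negative points
xz_pos = torch.rand(BT, 1, 1, 2) * Z_res
xz_neg = torch.rand(BT, 1, N, 2) * Z_res

dist = torch.norm(xz_pos - xz_neg, dim=3)           # BT x 1 x N
dist = dist / float(Z_res) * 5.0                    # normalize by resolution, upweight a bit
margin = F.relu(cost_pos - cost_neg + dist)         # hinge on the cost gap
margin = margin.reshape(B, T_futu, N).mean(dim=1)   # mean over time
maxmargin = margin.max(dim=1)[0]                    # hardest negative per example
maxmargin_loss = maxmargin.mean()
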
Example no. 9
    def forward(self, feed, moc_init_done=False, debug=False):
        summ_writer = utils_improc.Summ_writer(
            writer=feed['writer'],
            global_step=feed['global_step'],
            set_name=feed['set_name'],
            fps=8)

        writer = feed['writer']
        global_step = feed['global_step']
        total_loss = torch.tensor(0.0).cuda()

        ### ... All things sensor ... ###
        sensor_rgbs = feed['sensor_imgs']
        sensor_depths = feed['sensor_depths']
        center_sensor_H, center_sensor_W = sensor_depths[0][0].shape[-2] // 2, sensor_depths[0][0].shape[-1] // 2
        ### ... All things sensor end ... ###

        # 1. Form the memory tensor using the feat net and visual images.
        # check what all do you need for this and create only those things

        ##  .... Input images ....  ##
        rgb_camRs = feed['rgb_camRs']
        rgb_camXs = feed['rgb_camXs']
        ##  .... Input images end ....  ##

        ## ... Hyperparams ... ##
        B, H, W, V, S = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S
        __p = lambda x: pack_seqdim(x, B)
        __u = lambda x: unpack_seqdim(x, B)
        PH, PW = hyp.PH, hyp.PW
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z/2), int(Y/2), int(X/2)
        ## ... Hyperparams end ... ##

        ## .... VISUAL TRANSFORMS BEGIN .... ##
        pix_T_cams = feed['pix_T_cams']
        pix_T_cams_ = __p(pix_T_cams)
        origin_T_camRs = feed['origin_T_camRs']
        origin_T_camRs_ = __p(origin_T_camRs)
        origin_T_camXs = feed['origin_T_camXs']
        origin_T_camXs_ = __p(origin_T_camXs)
        camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse(
            origin_T_camRs_), origin_T_camXs_)
        camXs_T_camRs_ = utils_geom.safe_inverse(camRs_T_camXs_)
        camRs_T_camXs = __u(camRs_T_camXs_)
        camXs_T_camRs = __u(camXs_T_camRs_)
        pix_T_cams_ = utils_geom.pack_intrinsics(pix_T_cams_[:, 0, 0], pix_T_cams_[:, 1, 1], pix_T_cams_[:, 0, 2],
            pix_T_cams_[:, 1, 2])
        pix_T_camRs_ = torch.matmul(pix_T_cams_, camXs_T_camRs_)
        pix_T_camRs = __u(pix_T_camRs_)
        ## ... VISUAL TRANSFORMS END ... ##

        ## ... SENSOR TRANSFORMS BEGIN ... ##
        sensor_origin_T_camXs = feed['sensor_extrinsics']
        sensor_origin_T_camXs_ = __p(sensor_origin_T_camXs)
        sensor_origin_T_camRs = feed['sensor_origin_T_camRs']
        sensor_origin_T_camRs_ = __p(sensor_origin_T_camRs)
        sensor_camRs_T_origin_ = utils_geom.safe_inverse(sensor_origin_T_camRs_)

        sensor_camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse(
            sensor_origin_T_camRs_), sensor_origin_T_camXs_)
        sensor_camXs_T_camRs_ = utils_geom.safe_inverse(sensor_camRs_T_camXs_)

        sensor_camRs_T_camXs = __u(sensor_camRs_T_camXs_)
        sensor_camXs_T_camRs = __u(sensor_camXs_T_camRs_)

        sensor_pix_T_cams = feed['sensor_intrinsics']
        sensor_pix_T_cams_ = __p(sensor_pix_T_cams)
        sensor_pix_T_cams_ = utils_geom.pack_intrinsics(sensor_pix_T_cams_[:, 0, 0], sensor_pix_T_cams_[:, 1, 1],
            sensor_pix_T_cams_[:, 0, 2], sensor_pix_T_cams_[:, 1, 2])
        sensor_pix_T_camRs_ = torch.matmul(sensor_pix_T_cams_, sensor_camXs_T_camRs_)
        sensor_pix_T_camRs = __u(sensor_pix_T_camRs_)
        ## .... SENSOR TRANSFORMS END .... ##

        ## .... Visual Input point clouds .... ##
        xyz_camXs = feed['xyz_camXs']
        xyz_camXs_ = __p(xyz_camXs)
        xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, xyz_camXs_)  # (40, 4, 4) (B*S, N, 3)
        xyz_camRs = __u(xyz_camRs_)
        assert all([torch.allclose(xyz_camR, inp_xyz_camR) for xyz_camR, inp_xyz_camR in zip(
            xyz_camRs, feed['xyz_camRs']
        )]), "computation of xyz_camR here and those computed in input do not match"
        ## .... Visual Input point clouds end .... ##

        ## ... Sensor input point clouds ... ##
        sensor_xyz_camXs = feed['sensor_xyz_camXs']
        sensor_xyz_camXs_ = __p(sensor_xyz_camXs)
        sensor_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_xyz_camXs_)
        sensor_xyz_camRs = __u(sensor_xyz_camRs_)
        assert all([torch.allclose(sensor_xyz, inp_sensor_xyz) for sensor_xyz, inp_sensor_xyz in zip(
            sensor_xyz_camRs, feed['sensor_xyz_camRs']
        )]), "the sensor_xyz_camRs computed in forward do not match those computed in input"

        ## ... visual occupancy computation voxelize the pointcloud from above ... ##
        occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X)
        occXs_ = utils_vox.voxelize_xyz(xyz_camXs_, Z, Y, X)
        occRs_half_ = utils_vox.voxelize_xyz(xyz_camRs_, Z2, Y2, X2)
        occXs_half_ = utils_vox.voxelize_xyz(xyz_camXs_, Z2, Y2, X2)
        ## ... visual occupancy computation end ... NOTE: no unpacking ##

        ## .. visual occupancy computation for sensor inputs .. ##
        sensor_occRs_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z, Y, X)
        sensor_occXs_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z, Y, X)
        sensor_occRs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z2, Y2, X2)
        sensor_occXs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z2, Y2, X2)

        ## ... unproject rgb images ... ##
        unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_camRs_)
        unpXs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_cams_)
        ## ... unproject rgb finish ... NOTE: no unpacking ##

        ## ... Make depth images ... ##
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(pix_T_cams_, xyz_camXs_, H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, pix_T_cams_)
        dense_xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, dense_xyz_camXs_)
        inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y, X).float()
        inbound_camXs_ = torch.reshape(inbound_camXs_, [B*S, 1, H, W])
        valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_)
        ## ... Make depth images ... ##

        ## ... Make sensor depth images ... ##
        sensor_depth_camXs_, sensor_valid_camXs_ = utils_geom.create_depth_image(sensor_pix_T_cams_,
            sensor_xyz_camXs_, H, W)
        sensor_dense_xyz_camXs_ = utils_geom.depth2pointcloud(sensor_depth_camXs_, sensor_pix_T_cams_)
        sensor_dense_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_dense_xyz_camXs_)
        sensor_inbound_camXs_ = utils_vox.get_inbounds(sensor_dense_xyz_camRs_, Z, Y, X).float()
        sensor_inbound_camXs_ = torch.reshape(sensor_inbound_camXs_, [B*hyp.sensor_S, 1, H, W])
        sensor_valid_camXs = __u(sensor_valid_camXs_) * __u(sensor_inbound_camXs_)
        ### .. Done making sensor depth images .. ##

        ### ... Sanity check ... Write to tensorboard ... ###
        summ_writer.summ_oneds('2D_inputs/depth_camXs', torch.unbind(__u(depth_camXs_), dim=1))
        summ_writer.summ_oneds('2D_inputs/valid_camXs', torch.unbind(valid_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camXs', torch.unbind(rgb_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camRs', torch.unbind(rgb_camRs, dim=1))
        summ_writer.summ_occs('3d_inputs/occXs', torch.unbind(__u(occXs_), dim=1), reduce_axes=[2])
        summ_writer.summ_unps('3d_inputs/unpXs', torch.unbind(__u(unpXs_), dim=1),\
            torch.unbind(__u(occXs_), dim=1))

        # A different approach for viewing occRs of sensors
        sensor_occRs = __u(sensor_occRs_)
        vis_sensor_occRs = torch.max(sensor_occRs, dim=1, keepdim=True)[0]
        # summ_writer.summ_occs('3d_inputs/sensor_occXs', torch.unbind(__u(sensor_occXs_), dim=1),
        #     reduce_axes=[2])
        summ_writer.summ_occs('3d_inputs/sensor_occRs', torch.unbind(vis_sensor_occRs, dim=1), reduce_axes=[2])

        ### ... code for visualizing sensor depths and sensor rgbs ... ###
        # summ_writer.summ_oneds('2D_inputs/depths_sensor', torch.unbind(sensor_depths, dim=1))
        # summ_writer.summ_rgbs('2D_inputs/rgbs_sensor', torch.unbind(sensor_rgbs, dim=1))
        # summ_writer.summ_oneds('2D_inputs/validXs_sensor', torch.unbind(sensor_valid_camXs, dim=1))

        if summ_writer.save_this:
            unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, matmul2(pix_T_cams_, camXs_T_camRs_))
            unpRs = __u(unpRs_)
            occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X)
            summ_writer.summ_occs('3d_inputs/occRs', torch.unbind(__u(occRs_), dim=1), reduce_axes=[2])
            summ_writer.summ_unps('3d_inputs/unpRs', torch.unbind(unpRs, dim=1),\
                torch.unbind(__u(occRs_), dim=1))
        ### ... Sanity check ... Writing to tensoboard complete ... ###
        results = list()

        mask_ = None
        ### ... Visual featnet part .... ###
        if hyp.do_feat:
            featXs_input = torch.cat([__u(occXs_), __u(occXs_)*__u(unpXs_)], dim=2)  # B, S, 4, H, W, D
            featXs_input_ = __p(featXs_input)

            freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), occXs_half_)
            freeXs = __u(freeXs_)
            visXs = torch.clamp(__u(occXs_half_) + freeXs, 0.0, 1.0)

            if mask_ is not None:
                assert(list(mask_.shape)[2:5] == list(featXs_input.shape)[2:5])
            featXs_, validXs_, _ = self.featnet(featXs_input_, summ_writer, mask=occXs_)
            # total_loss += feat_loss  # Note no need of loss

            validXs, featXs = __u(validXs_), __u(featXs_) # unpacked into B, S, C, D, H, W
            # bring everything to ref_frame
            validRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, validXs)
            visRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, visXs)
            featRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, featXs)  # This is now in memory coordinates

            emb3D_e = torch.mean(featRs[:, 1:], dim=1)  # context, or the features of the scene
            emb3D_g = featRs[:, 0]  # this is to predict, basically I will pass emb3D_e as input and hope to predict emb3D_g
            vis3D_e = torch.max(validRs[:, 1:], dim=1)[0] * torch.max(visRs[:, 1:], dim=1)[0]
            vis3D_g = validRs[:, 0] * visRs[:, 0]

            #### ... I do not think I need this ... ####
            results = {}
        #     # if hyp.do_eval_recall:
        #     #     results['emb3D_e'] = emb3D_e
        #     #     results['emb3D_g'] = emb3D_g
        #     #### ... Check if you need the above

            summ_writer.summ_feats('3D_feats/featXs_input', torch.unbind(featXs_input, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/featXs_output', torch.unbind(featXs, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/featRs_output', torch.unbind(featRs, dim=1), pca=True)
            summ_writer.summ_feats('3D_feats/validRs', torch.unbind(validRs, dim=1), pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_e', vis3D_e, pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_g', vis3D_g, pca=False)

            # I need to aggregate the features and detach to prevent the backward pass on featnet
            featRs = torch.mean(featRs, dim=1)
            featRs = featRs.detach()
            #  ... HERE I HAVE THE VISUAL FEATURE TENSOR ... WHICH IS MADE USING 5 EVENLY SPACED VIEWS #

        # FOR THE TOUCH PART, I HAVE THE OCC and THE AIM IS TO PREDICT FEATURES FROM THEM #
        if hyp.do_touch_feat:
            # 1. Pass all the sensor depth images through the backbone network
            input_sensor_depths = __p(sensor_depths)
            sensor_features_ = self.backbone_2D(input_sensor_depths)

            # should normalize these feature tensors
            sensor_features_ = l2_normalize(sensor_features_, dim=1)

            sensor_features = __u(sensor_features_)
            assert torch.allclose(torch.norm(sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\
                "normalization has no effect on you huh."

            if hyp.do_eval_recall:
                results['sensor_features'] = sensor_features_
                results['sensor_depths'] = input_sensor_depths
                results['object_img'] = rgb_camRs
                results['sensor_imgs'] = __p(sensor_rgbs)

            # if moco is used do the same procedure as above but with a different network #
            if hyp.do_moc or hyp.do_eval_recall:
                # 1. Pass all the sensor depth images through the key network
                key_input_sensor_depths = copy.deepcopy(__p(sensor_depths)) # bx1024x1x16x16->(2048x1x16x16)
                self.key_touch_featnet.eval()
                with torch.no_grad():
                    key_sensor_features_ = self.key_touch_featnet(key_input_sensor_depths)

                key_sensor_features_ = l2_normalize(key_sensor_features_, dim=1)
                key_sensor_features = __u(key_sensor_features_)
                assert torch.allclose(torch.norm(key_sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\
                    "normalization has no effect on you huh."

        # doing the same procedure for moco but with a different network end #

        # do you want to do metric learning voxel point based using visual features and sensor features
        if hyp.do_touch_embML and not hyp.do_touch_forward:
            # trial 1: I do not pass the above obtained features through some encoder decoder in 3d
            # So compute the location is ref_frame which the center of these depth images will occupy
            # at all of these locations I will sample the from the visual tensor. It forms the positive pairs
            # negatives are simply everything except the positive
            sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_x = sensor_depths_centers_x.cuda()
            sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_y = sensor_depths_centers_y.cuda()
            sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W]

            # Next use Pixels2Camera to unproject all of these together.
            # merge the batch and the sequence dimension
            sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1)  # BxHxW as required by Pixels2Camera
            sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1)
            sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1)

            fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_)
            sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y,
                sensor_depths_centers_z, fx, fy, x0, y0)

            # finally use apply4x4 to get the locations in ref_cam
            sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_)

            # NOTE: convert them to memory coordinates, the name is xyz so I presume it returns xyz but talk to ADAM
            sensor_depths_centers_in_mem_ = utils_vox.Ref2Mem(sensor_depths_centers_in_ref_cam_, Z2, Y2, X2)
            sensor_depths_centers_in_mem = sensor_depths_centers_in_mem_.reshape(hyp.B, hyp.sensor_S, -1)

            if debug:
                print('debug branch: this should not normally be entered')
                from IPython import embed; embed()
                # form a (0, 1) volume here at these locations and see if it resembles a cup
                dim1 = X2 * Y2 * Z2
                dim2 = X2 * Y2
                dim3 = X2
                binary_voxel_grid = torch.zeros((hyp.B, X2, Y2, Z2))
                # NOTE: Z is the leading dimension
                rounded_idxs = torch.round(sensor_depths_centers_in_mem)
                flat_idxs = dim2 * rounded_idxs[0, :, 0] + dim3 * rounded_idxs[0, :, 1] + rounded_idxs[0, :, 2]
                flat_idxs1 = dim2 * rounded_idxs[1, :, 0] + dim3 * rounded_idxs[1, :, 1] + rounded_idxs[1, :, 2]
                flat_idxs1 = flat_idxs1 + dim1
                flat_idxs1 = flat_idxs1.long()
                flat_idxs = flat_idxs.long()
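                # Sketch of the flattening convention assumed here: with dim2 = X2*Y2 and
                # dim3 = X2, the flat index is idx0*(X2*Y2) + idx1*X2 + idx2, i.e. the
                # row-major flattening of a (Z2, Y2, X2)-ordered grid. The grid above is
                # allocated as (B, X2, Y2, Z2), so this only lines up when the grid is
                # cubic (X2 == Y2 == Z2), which holds for this debug visualization.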

                flattened_grid = binary_voxel_grid.flatten()
                flattened_grid[flat_idxs] = 1.
                flattened_grid[flat_idxs1] = 1.

                binary_voxel_grid = flattened_grid.view(B, X2, Y2, Z2)

                assert binary_voxel_grid[0].sum() == len(torch.unique(flat_idxs)), "some indexes are missed here"
                assert binary_voxel_grid[1].sum() == len(torch.unique(flat_idxs1)), "some indexes are missed here"

                # o3d.io.write_voxel_grid("forward_pass_save/grid0.ply", binary_voxel_grid[0])
                # o3d.io.write_voxel_grid("forward_pass_save/grid1.ply", binary_voxel_grid[0])
                # need to save these voxels
                save_voxel(binary_voxel_grid[0].cpu().numpy(), "forward_pass_save/grid0.binvox")
                save_voxel(binary_voxel_grid[1].cpu().numpy(), "forward_pass_save/grid1.binvox")
                from IPython import embed; embed()

            # use grid sample to get the visual touch tensor at these locations, NOTE: visual tensor features shape is (B, C, N)
            visual_tensor_features = utils_samp.bilinear_sample3D(featRs, sensor_depths_centers_in_mem[:, :, 0],
                sensor_depths_centers_in_mem[:, :, 1], sensor_depths_centers_in_mem[:, :, 2])
            visual_feature_tensor = visual_tensor_features.permute(0, 2, 1)
            # pack it
            visual_feature_tensor_ = __p(visual_feature_tensor)
            C = list(visual_feature_tensor.shape)[-1]
            print('C=', C)

            # do the metric learning; this is the same as before.
            # the code is largely copied from embnet3d.py, with some very minor changes
            emb_vec = torch.stack((sensor_features_, visual_feature_tensor_), dim=1).view(B*self.num_samples*self.batch_k, C)
            y = torch.stack([torch.arange(0, self.num_samples*B), torch.arange(0, self.num_samples*B)], dim=1).view(self.num_samples*B*self.batch_k)
            a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec)

            # a custom version of margin loss is needed here, since the negatives and anchors may not have the same dims
            d_ap = torch.sqrt(torch.sum((positives - anchors)**2, dim=1) + 1e-8)
            pos_loss = torch.clamp(d_ap - beta + self._margin, min=0.0)

            # TODO: expand the dims of anchors, tile them, and compute the negative loss
            # (a hypothetical sketch of these steps follows below)

            # do the pair count, averaging over contributing pairs only

            # this is your total loss
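            # A minimal sketch of the missing steps (hypothetical, not the author's final
            # code; it assumes the sampler returns one row of negatives per anchor and
            # that beta / self._margin are the same quantities used for pos_loss above):
            #   d_an = torch.sqrt(torch.sum((negatives - anchors.unsqueeze(1))**2, dim=2) + 1e-8)
            #   neg_loss = torch.clamp(beta + self._margin - d_an, min=0.0)
            #   pair_cnt = torch.sum((pos_loss > 0.0).float()) + torch.sum((neg_loss > 0.0).float())
            #   ml_loss = (torch.sum(pos_loss) + torch.sum(neg_loss)) / torch.clamp(pair_cnt, min=1.0)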


            # A further idea: check which volumetric locations each depth image corresponds to,
            # by unprojecting the entire depth image, converting it to ref coords, and then sampling.

        if hyp.do_touch_forward:
            ## ... Begin code for getting crops from visual memory ... ##
            sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_x = sensor_depths_centers_x.cuda()
            sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S))
            sensor_depths_centers_y = sensor_depths_centers_y.cuda()
            sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W]

            # Next use Pixels2Camera to unproject all of these together.
            # merge the batch and the sequence dimension
            sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1)
            sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1)
            sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1)

            fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_)
            sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y,
                sensor_depths_centers_z, fx, fy, x0, y0)
            sensor_depths_centers_in_world_ = utils_geom.apply_4x4(sensor_origin_T_camXs_, sensor_depths_centers_in_camXs_)  # not used by the algorithm
            ## this will later be used for visualization, hence it is saved here for now
            sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_)  # not used by the algorithm

            sensor_depths_centers_in_camXs = __u(sensor_depths_centers_in_camXs_).squeeze(2)

            # There has to be a better way to do this: for each camera in the batch we want a box of size (ch, cw, cd)
            # TODO: rotation is the deviation of the box from axis-aligned; decide whether this is wanted
            tB, tN, _ = list(sensor_depths_centers_in_camXs.shape)  # 2, 512, _
            boxlist = torch.zeros(tB, tN, 9)  # 2, 512, 9
            boxlist[:, :, :3] = sensor_depths_centers_in_camXs  # this lies on the object
            boxlist[:, :, 3:6] = torch.FloatTensor([hyp.contextW, hyp.contextH, hyp.contextD])

            # convert the boxlist to lrtlist and to cuda
            # the rt here transforms from box coordinates to camera coordinates
            box_lrtlist = utils_geom.convert_boxlist_to_lrtlist(boxlist)
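            # box_lrtlist presumably follows the 19-D lrt convention used elsewhere in this
            # codebase: 3 box lengths plus a flattened 4x4 box-to-camera transform.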

            # Now I will use crop_zoom_from_mem functionality to get the features in each of the boxes
            # I will do it for each of the box separately as required by the api
            context_grid_list = list()
            for m in range(box_lrtlist.shape[1]):
                curr_box = box_lrtlist[:, m, :]
                context_grid = utils_vox.crop_zoom_from_mem(featRs, curr_box, 8, 8, 8,
                    sensor_camRs_T_camXs[:, m, :, :])
                context_grid_list.append(context_grid)
            context_grid_list = torch.stack(context_grid_list, dim=1)
            context_grid_list_ = __p(context_grid_list)
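            # context_grid_list is presumably B x S x C x 8 x 8 x 8 (one resampled crop of
            # featRs per sensor location); packing merges the batch and sequence dims so the
            # 3D context net below sees B*S independent crops.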
            ## ... up to here I believe no randomness has been introduced, so the points are still in order
            ## ... End code for getting crops around this center of certain height, width and depth ... ##

            ## ... Begin code for passing the context grid through 3D CNN to obtain a vector ... ##
            sensor_cam_locs = feed['sensor_locs']  # these are in origin coordinates
            sensor_cam_quats = feed['sensor_quats'] # these are in world coordinates too
            sensor_cam_locs_ = __p(sensor_cam_locs)
            sensor_cam_quats_ = __p(sensor_cam_quats)
            sensor_cam_locs_in_R_ = utils_geom.apply_4x4(sensor_camRs_T_origin_, sensor_cam_locs_.unsqueeze(1)).squeeze(1)
            # TODO: confirm that this transformation is correct
            get_r_mat = lambda cam_quat: transformations.quaternion_matrix_py(cam_quat)
            rot_mat_Xs_ = torch.from_numpy(np.stack(list(map(get_r_mat, sensor_cam_quats_.cpu().numpy())))).to(sensor_cam_locs_.device).float()
            rot_mat_Rs_ = torch.bmm(sensor_camRs_T_origin_, rot_mat_Xs_)
            get_quat = lambda r_mat: transformations.quaternion_from_matrix_py(r_mat)
            sensor_quats_in_R_ = torch.from_numpy(np.stack(list(map(get_quat, rot_mat_Rs_.cpu().numpy())))).to(sensor_cam_locs_.device).float()
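            # The sensor pose is given as a quaternion in origin/world coordinates, so it is
            # converted to a rotation matrix, composed with camRs_T_origin to express it in
            # the reference camera frame, and then converted back to a quaternion.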

            pred_features_ = self.context_net(context_grid_list_,\
                sensor_cam_locs_in_R_, sensor_quats_in_R_)

            # normalize
            pred_features_ = l2_normalize(pred_features_, dim=1)
            pred_features = __u(pred_features_)

            # if doing moco, the inputs must also be passed through the key (slow) network #
            if hyp.do_moc or hyp.do_eval_recall:
                key_context_grid_list_ = copy.deepcopy(context_grid_list_)
                key_sensor_cam_locs_in_R_ = copy.deepcopy(sensor_cam_locs_in_R_)
                key_sensor_quats_in_R_ = copy.deepcopy(sensor_quats_in_R_)
                self.key_context_net.eval()
                with torch.no_grad():
                    key_pred_features_ = self.key_context_net(key_context_grid_list_,\
                        key_sensor_cam_locs_in_R_, key_sensor_quats_in_R_)

                # normalize; normalization is very important here (though it is not obvious why)
                key_pred_features_ = l2_normalize(key_pred_features_, dim=1)
                key_pred_features = __u(key_pred_features_)
            # end passing the input through the slow network; this is necessary for moco #
            ## ... End code for passing the context grid through 3D CNN to obtain a vector ... ##

        ## ... Begin code for doing metric learning between pred_features and sensor features ... ##
        # 1. Subsample both based on the number of positive samples
        if hyp.do_touch_embML:
            assert(hyp.do_touch_forward)
            assert(hyp.do_touch_feat)
            perm = torch.randperm(len(pred_features_))  ## 1024
            chosen_sensor_feats_ = sensor_features_[perm[:self.num_pos_samples*hyp.B]]
            chosen_pred_feats_ = pred_features_[perm[:self.num_pos_samples*B]]

            # 2. form the emb_vec and get pos and negative samples for the batch
            emb_vec = torch.stack((chosen_sensor_feats_, chosen_pred_feats_), dim=1).view(hyp.B*self.num_pos_samples*self.batch_k, -1)
            y = torch.stack([torch.arange(0, self.num_pos_samples*B), torch.arange(0, self.num_pos_samples*B)],\
                dim=1).view(B*self.num_pos_samples*self.batch_k) # (0, 0, 1, 1, ..., 255, 255)

            a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec)

            # 3. Compute the loss: the ML loss and the l2 distance between the embeddings
            margin_loss, _ = self.criterion(anchors, positives, negatives, self.beta, y[a_indices])
            total_loss = utils_misc.add_loss('embtouch/emb_touch_ml_loss', total_loss, margin_loss,
                hyp.emb_3D_ml_coeff, summ_writer)

            # the l2 loss between the embeddings
            l2_loss = torch.nn.functional.mse_loss(chosen_sensor_feats_, chosen_pred_feats_)
            total_loss = utils_misc.add_loss('embtouch/emb_l2_loss', total_loss, l2_loss,
                hyp.emb_3D_l2_coeff, summ_writer)
        ## ... End code for doing metric learning between pred_features and sensor_features ... ##

        ## ... Begin code for doing moc inspired ML between pred_features and sensor_features ... ##
        if hyp.do_moc and moc_init_done:
            moc_loss = self.moc_ml_net(sensor_features_, key_sensor_features_,\
                pred_features_, key_pred_features_, summ_writer)
            total_loss += moc_loss
        ## ... End code for doing moc inspired ML between pred_features and sensor_features ... ##

        ## ... add code for filling up results needed for eval recall ... ##
        if hyp.do_eval_recall and moc_init_done:
            results['context_features'] = pred_features_
            results['sensor_depth_centers_in_world'] = sensor_depths_centers_in_world_
            results['sensor_depths_centers_in_ref_cam'] = sensor_depths_centers_in_ref_cam_
            results['object_name'] = feed['object_name']

            # I will do precision recall here at different recall values and summarize it using tensorboard
            recalls = [1, 5, 10, 50, 100, 200]
            # also, this should not carry any gradients, hence the detached copies below
            # fast_sensor_emb_e = sensor_features_
            # fast_context_emb_e = pred_features_
            # slow_sensor_emb_g = key_sensor_features_
            # slow_context_emb_g = key_context_features_
            fast_sensor_emb_e = sensor_features_.clone().detach()
            fast_context_emb_e = pred_features_.clone().detach()

            # I will do multiple eval recalls here
            slow_sensor_emb_g = key_sensor_features_.clone().detach()
            slow_context_emb_g = key_pred_features_.clone().detach()

            # assuming the above went well, move the embeddings to numpy
            fast_sensor_emb_e = fast_sensor_emb_e.cpu().numpy()
            fast_context_emb_e = fast_context_emb_e.cpu().numpy()
            slow_sensor_emb_g = slow_sensor_emb_g.cpu().numpy()
            slow_context_emb_g = slow_context_emb_g.cpu().numpy()

            # now also move the vis to numpy and plot it using matplotlib
            vis_e = __p(sensor_rgbs)
            vis_g = __p(sensor_rgbs)
            np_vis_e = vis_e.cpu().detach().numpy()
            np_vis_e = np.transpose(np_vis_e, [0, 2, 3, 1])
            np_vis_g = vis_g.cpu().detach().numpy()
            np_vis_g = np.transpose(np_vis_g, [0, 2, 3, 1])

            # bring it back to original color
            np_vis_g = ((np_vis_g+0.5) * 255).astype(np.uint8)
            np_vis_e = ((np_vis_e+0.5) * 255).astype(np.uint8)
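            # this assumes the rgb tensors were normalized to [-0.5, 0.5]; adding 0.5
            # and scaling by 255 maps them back to displayable [0, 255] uint8 images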

            # now compare fast_sensor_emb_e with slow_context_emb_g,
            # since these form the positive pairs
            fast_sensor_emb_e_list = [fast_sensor_emb_e, np_vis_e]
            slow_context_emb_g_list = [slow_context_emb_g, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_sensor_emb_e_list, slow_context_emb_g_list, recalls=recalls
            )

            # finally plot the nearest neighbour retrieval and move ahead
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_sensor_slow_context')

            # plot the precisions at different recalls
            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_sensor_slow_context/recall@{re}',\
                    prec[pr])

            # now compare fast_context_emb_e with slow_sensor_emb_g
            fast_context_emb_e_list = [fast_context_emb_e, np_vis_e]
            slow_sensor_emb_g_list = [slow_sensor_emb_g, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_context_emb_e_list, slow_sensor_emb_g_list, recalls=recalls
            )
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_context_slow_sensor')

            # plot the precisions at different recalls
            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_context_slow_sensor/recall@{re}',\
                    prec[pr])


            # finally compare the two fast embeddings; presumably we want these to move closer together too
            fast_sensor_list = [fast_sensor_emb_e, np_vis_e]
            fast_context_list = [fast_context_emb_e, np_vis_g]

            prec, vis, chosen_inds_and_neighbors_inds = compute_precision(
                fast_sensor_list, fast_context_list, recalls=recalls
            )
            if feed['global_step'] % 1 == 0:
                plot_nearest_neighbours(vis, step=feed['global_step'],
                                        save_dir='/home/gauravp/eval_results',
                                        name='fast_sensor_fast_context')

            for pr, re in enumerate(recalls):
                summ_writer.summ_scalar(f'evrefast_sensor_fast_context/recall@{re}',\
                    prec[pr])

        ## ... done code for filling up results needed for eval recall ... ##
        summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, [key_sensor_features_, key_pred_features_]
Esempio n. 10
0
    def forward(self, feed):
        results = dict()

        if 'log_freq' not in feed.keys():
            feed['log_freq'] = None
        start_time = time.time()

        summ_writer = utils_improc.Summ_writer(writer=feed['writer'],
                                               global_step=feed['global_step'],
                                               set_name=feed['set_name'],
                                               log_freq=feed['log_freq'],
                                               fps=8)
        writer = feed['writer']
        global_step = feed['global_step']

        total_loss = torch.tensor(0.0).cuda()
        __p = lambda x: utils_basic.pack_seqdim(x, B)
        __u = lambda x: utils_basic.unpack_seqdim(x, B)

        __pb = lambda x: utils_basic.pack_boxdim(x, hyp.N)
        __ub = lambda x: utils_basic.unpack_boxdim(x, hyp.N)
        if hyp.aug_object_ent_dis:
            __pb_a = lambda x: utils_basic.pack_boxdim(
                x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
            __ub_a = lambda x: utils_basic.unpack_boxdim(
                x, hyp.max_obj_aug + hyp.max_obj_aug_dis)
        else:
            __pb_a = lambda x: utils_basic.pack_boxdim(x, hyp.max_obj_aug)
            __ub_a = lambda x: utils_basic.unpack_boxdim(x, hyp.max_obj_aug)

        B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N
        PH, PW = hyp.PH, hyp.PW
        K = hyp.K
        BOX_SIZE = hyp.BOX_SIZE
        Z, Y, X = hyp.Z, hyp.Y, hyp.X
        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
        D = 9

        tids = torch.from_numpy(np.reshape(np.arange(B * N), [B, N]))

        rgb_camXs = feed["rgb_camXs_raw"]
        pix_T_cams = feed["pix_T_cams_raw"]
        camRs_T_origin = feed["camR_T_origin_raw"]
        origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin)))
        origin_T_camXs = feed["origin_T_camXs_raw"]
        camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0)
        camRs_T_camXs = __u(
            torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)),
                         __p(origin_T_camXs)))
        camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs)))
        camX0_T_camRs = camXs_T_camRs[:, 0]
        camX1_T_camRs = camXs_T_camRs[:, 1]

        camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs)
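        # Naming convention (as it appears throughout this codebase): A_T_B maps points
        # from frame B into frame A, i.e. p_A = A_T_B @ p_B. So camRs_T_camXs above is
        # inverse(origin_T_camRs) @ origin_T_camXs, and camR_T_camX0 maps camX0 points
        # into the reference camera R.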

        xyz_camXs = feed["xyz_camXs_raw"]
        depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(
            __p(pix_T_cams), __p(xyz_camXs), H, W)
        dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_,
                                                       __p(pix_T_cams))

        xyz_camRs = __u(
            utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs)))
        xyz_camX0s = __u(
            utils_geom.apply_4x4(__p(camX0_T_camXs), __p(xyz_camXs)))

        occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X))

        occXs_to_Rs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, occXs)
        occXs_to_Rs_45 = cross_corr.rotate_tensor_along_y_axis(occXs_to_Rs, 45)
        occXs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z2, Y2, X2))
        occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2))
        occX0s_half = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z2, Y2, X2))

        unpXs = __u(
            utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X,
                                           __p(pix_T_cams)))

        unpXs_half = __u(
            utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z2, Y2, X2,
                                           __p(pix_T_cams)))

        unpX0s_half = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z2, Y2, X2,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camX0_T_camXs)))))
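        # For unpX0s_half, the projection matrix is pix_T_cams composed with
        # camXs_T_camX0 (the inverse of camX0_T_camXs), so each image is unprojected
        # into a half-resolution volume aligned with camera X0 rather than its own frame.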

        unpRs = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z, Y, X,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camRs_T_camXs)))))

        unpRs_half = __u(
            utils_vox.unproject_rgb_to_mem(
                __p(rgb_camXs), Z2, Y2, X2,
                utils_basic.matmul2(
                    __p(pix_T_cams),
                    utils_geom.safe_inverse(__p(camRs_T_camXs)))))

        dense_xyz_camRs_ = utils_geom.apply_4x4(__p(camRs_T_camXs),
                                                dense_xyz_camXs_)
        inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y,
                                                X).float()
        inbound_camXs_ = torch.reshape(inbound_camXs_, [B * S, 1, H, W])

        depth_camXs = __u(depth_camXs_)
        valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_)

        summ_writer.summ_oneds('2D_inputs/depth_camXs',
                               torch.unbind(depth_camXs, dim=1),
                               maxdepth=21.0)
        summ_writer.summ_oneds('2D_inputs/valid_camXs',
                               torch.unbind(valid_camXs, dim=1))
        summ_writer.summ_rgbs('2D_inputs/rgb_camXs',
                              torch.unbind(rgb_camXs, dim=1))
        summ_writer.summ_occs('3D_inputs/occXs', torch.unbind(occXs, dim=1))
        summ_writer.summ_unps('3D_inputs/unpXs', torch.unbind(unpXs, dim=1),
                              torch.unbind(occXs, dim=1))

        occRs = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z, Y, X))

        if hyp.do_eval_boxes:
            if hyp.dataset_name == "clevr_vqa":
                gt_boxes_origin_corners = feed['gt_box']
                gt_scores_origin = feed['gt_scores'].detach().cpu().numpy()
                classes = feed['classes']
                scores = gt_scores_origin
                tree_seq_filename = feed['tree_seq_filename']
                gt_boxes_origin = nlu.get_ends_of_corner(
                    gt_boxes_origin_corners)
                gt_boxes_origin_end = torch.reshape(gt_boxes_origin,
                                                    [hyp.B, hyp.N, 2, 3])
                gt_boxes_origin_theta = nlu.get_alignedboxes2thetaformat(
                    gt_boxes_origin_end)
                gt_boxes_origin_corners = utils_geom.transform_boxes_to_corners(
                    gt_boxes_origin_theta)
                gt_boxesR_corners = __ub(
                    utils_geom.apply_4x4(camRs_T_origin[:, 0],
                                         __pb(gt_boxes_origin_corners)))
                gt_boxesR_theta = utils_geom.transform_corners_to_boxes(
                    gt_boxesR_corners)
                gt_boxesR_end = nlu.get_ends_of_corner(gt_boxesR_corners)

            else:
                tree_seq_filename = feed['tree_seq_filename']
                tree_filenames = [
                    join(hyp.root_dataset, i) for i in tree_seq_filename
                    if i != "invalid_tree"
                ]
                invalid_tree_filenames = [
                    join(hyp.root_dataset, i) for i in tree_seq_filename
                    if i == "invalid_tree"
                ]
                num_empty = len(invalid_tree_filenames)
                trees = [pickle.load(open(i, "rb")) for i in tree_filenames]

                len_valid = len(trees)
                if len_valid > 0:
                    gt_boxesR, scores, classes = nlu.trees_rearrange(trees)

                if num_empty > 0:
                    gt_boxesR = np.concatenate([
                        gt_boxesR, empty_gt_boxesR
                    ]) if len_valid > 0 else empty_gt_boxesR
                    scores = np.concatenate([
                        scores, empty_scores
                    ]) if len_valid > 0 else empty_scores
                    classes = np.concatenate([
                        classes, empty_classes
                    ]) if len_valid > 0 else empty_classes

                gt_boxesR = torch.from_numpy(
                    gt_boxesR).cuda().float()  # torch.Size([2, 3, 6])
                gt_boxesR_end = torch.reshape(gt_boxesR, [hyp.B, hyp.N, 2, 3])
                gt_boxesR_theta = nlu.get_alignedboxes2thetaformat(
                    gt_boxesR_end)  #torch.Size([2, 3, 9])
                gt_boxesR_corners = utils_geom.transform_boxes_to_corners(
                    gt_boxesR_theta)

            class_names_ex_1 = "_".join(classes[0])
            summ_writer.summ_text('eval_boxes/class_names', class_names_ex_1)

            gt_boxesRMem_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z2, Y2, X2))
            gt_boxesRMem_end = nlu.get_ends_of_corner(gt_boxesRMem_corners)

            gt_boxesRMem_theta = utils_geom.transform_corners_to_boxes(
                gt_boxesRMem_corners)
            gt_boxesRUnp_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z, Y, X))
            gt_boxesRUnp_end = nlu.get_ends_of_corner(gt_boxesRUnp_corners)

            gt_boxesX0_corners = __ub(
                utils_geom.apply_4x4(camX0_T_camRs, __pb(gt_boxesR_corners)))
            gt_boxesX0Mem_corners = __ub(
                utils_vox.Ref2Mem(__pb(gt_boxesX0_corners), Z2, Y2, X2))

            gt_boxesX0Mem_theta = utils_geom.transform_corners_to_boxes(
                gt_boxesX0Mem_corners)

            gt_boxesX0Mem_end = nlu.get_ends_of_corner(gt_boxesX0Mem_corners)
            gt_boxesX0_end = nlu.get_ends_of_corner(gt_boxesX0_corners)

            gt_cornersX0_pix = __ub(
                utils_geom.apply_pix_T_cam(pix_T_cams[:, 0],
                                           __pb(gt_boxesX0_corners)))

            rgb_camX0 = rgb_camXs[:, 0]
            rgb_camX1 = rgb_camXs[:, 1]

            summ_writer.summ_box_by_corners('eval_boxes/gt_boxescamX0',
                                            rgb_camX0, gt_boxesX0_corners,
                                            torch.from_numpy(scores), tids,
                                            pix_T_cams[:, 0])
            unps_vis = utils_improc.get_unps_vis(unpX0s_half, occX0s_half)
            unp_vis = torch.mean(unps_vis, dim=1)
            unps_visRs = utils_improc.get_unps_vis(unpRs_half, occRs_half)
            unp_visRs = torch.mean(unps_visRs, dim=1)
            unps_visRs_full = utils_improc.get_unps_vis(unpRs, occRs)
            unp_visRs_full = torch.mean(unps_visRs_full, dim=1)
            summ_writer.summ_box_mem_on_unp('eval_boxes/gt_boxesR_mem',
                                            unp_visRs, gt_boxesRMem_end,
                                            scores, tids)

            unpX0s_half = torch.mean(unpX0s_half, dim=1)
            unpX0s_half = nlu.zero_out(unpX0s_half, gt_boxesX0Mem_end, scores)

            occX0s_half = torch.mean(occX0s_half, dim=1)
            occX0s_half = nlu.zero_out(occX0s_half, gt_boxesX0Mem_end, scores)

            summ_writer.summ_unp('3D_inputs/unpX0s', unpX0s_half, occX0s_half)

        if hyp.do_feat:
            featXs_input = torch.cat([occXs, occXs * unpXs], dim=2)
            featXs_input_ = __p(featXs_input)

            freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), __p(occXs_half))
            freeXs = __u(freeXs_)
            visXs = torch.clamp(occXs_half + freeXs, 0.0, 1.0)
            mask_ = None

            if mask_ is not None:
                assert (list(mask_.shape)[2:5] == list(
                    featXs_input_.shape)[2:5])

            featXs_, feat_loss = self.featnet(featXs_input_,
                                              summ_writer,
                                              mask=__p(occXs))  #mask_)
            total_loss += feat_loss

            validXs = torch.ones_like(visXs)
            _validX00 = validXs[:, 0:1]
            _validX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                     validXs[:, 1:])
            validX0s = torch.cat([_validX00, _validX01], dim=1)
            validRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, validXs)
            visRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, visXs)

            featXs = __u(featXs_)
            _featX00 = featXs[:, 0:1]
            _featX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:],
                                                    featXs[:, 1:])
            featX0s = torch.cat([_featX00, _featX01], dim=1)

            emb3D_e = torch.mean(featX0s[:, 1:], dim=1)
            vis3D_e_R = torch.max(visRs[:, 1:], dim=1)[0]
            emb3D_g = featX0s[:, 0]
            vis3D_g_R = visRs[:, 0]
            validR_combo = torch.min(validRs, dim=1).values

            summ_writer.summ_feats('3D_feats/featXs_input',
                                   torch.unbind(featXs_input, dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/featXs_output',
                                   torch.unbind(featXs, dim=1),
                                   valids=torch.unbind(validXs, dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/featX0s_output',
                                   torch.unbind(featX0s, dim=1),
                                   valids=torch.unbind(
                                       torch.ones_like(validRs), dim=1),
                                   pca=True)
            summ_writer.summ_feats('3D_feats/validRs',
                                   torch.unbind(validRs, dim=1),
                                   pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_e_R', vis3D_e_R, pca=False)
            summ_writer.summ_feat('3D_feats/vis3D_g_R', vis3D_g_R, pca=False)

        if hyp.do_munit:
            object_classes, filenames = nlu.create_object_classes(
                classes, [tree_seq_filename, tree_seq_filename], scores)
            if hyp.do_munit_fewshot:
                emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
                emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
                emb3D_R = emb3D_e_R
                emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                    [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                    scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
                emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2
                content, style = self.munitnet.net.gen_a.encode(emb3D_R_object)
                objects_taken, _ = self.munitnet.net.gen_a.decode(
                    content, style)
                styles = style
                contents = content
            elif hyp.do_3d_style_munit:
                emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e)
                emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g)
                emb3D_R = emb3D_e_R
                # st()
                emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors(
                    [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end,
                    scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE])
                emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2

                camX1_T_R = camXs_T_camRs[:, 1]
                camX0_T_R = camXs_T_camRs[:, 0]
                assert hyp.B == 2
                assert emb3D_e_R_object.shape[0] == 2
                munit_loss, sudo_input_0, sudo_input_1, recon_input_0, recon_input_1, sudo_input_0_cycle, sudo_input_1_cycle, styles, contents, adin = self.munitnet(
                    emb3D_R_object[0:1], emb3D_R_object[1:2])

                if hyp.store_content_style_range:
                    if self.max_content is None:
                        self.max_content = torch.zeros_like(
                            contents[0][0]).cuda() - 100000000
                    if self.min_content is None:
                        self.min_content = torch.zeros_like(
                            contents[0][0]).cuda() + 100000000
                    if self.max_style is None:
                        self.max_style = torch.zeros_like(
                            styles[0][0]).cuda() - 100000000
                    if self.min_style is None:
                        self.min_style = torch.zeros_like(
                            styles[0][0]).cuda() + 100000000
                    self.max_content = torch.max(
                        torch.max(self.max_content, contents[0][0]),
                        contents[1][0])
                    self.min_content = torch.min(
                        torch.min(self.min_content, contents[0][0]),
                        contents[1][0])
                    self.max_style = torch.max(
                        torch.max(self.max_style, styles[0][0]), styles[1][0])
                    self.min_style = torch.min(
                        torch.min(self.min_style, styles[0][0]), styles[1][0])

                    data_to_save = {
                        'max_content': self.max_content.cpu().numpy(),
                        'min_content': self.min_content.cpu().numpy(),
                        'max_style': self.max_style.cpu().numpy(),
                        'min_style': self.min_style.cpu().numpy()
                    }
                    with open('content_style_range.p', 'wb') as f:
                        pickle.dump(data_to_save, f)
                elif hyp.is_contrastive_examples:
                    if hyp.normalize_contrast:
                        content0 = (contents[0] - self.min_content) / (
                            self.max_content - self.min_content + 1e-5)
                        content1 = (contents[1] - self.min_content) / (
                            self.max_content - self.min_content + 1e-5)
                        style0 = (styles[0] - self.min_style) / (
                            self.max_style - self.min_style + 1e-5)
                        style1 = (styles[1] - self.min_style) / (
                            self.max_style - self.min_style + 1e-5)
                    else:
                        content0 = contents[0]
                        content1 = contents[1]
                        style0 = styles[0]
                        style1 = styles[1]

                    # euclid_dist_content = torch.sum(torch.sqrt((content0 - content1)**2))/torch.prod(torch.tensor(content0.shape))
                    # euclid_dist_style = torch.sum(torch.sqrt((style0-style1)**2))/torch.prod(torch.tensor(style0.shape))
                    euclid_dist_content = (content0 - content1).norm(2) / (
                        content0.numel())
                    euclid_dist_style = (style0 -
                                         style1).norm(2) / (style0.numel())

                    content_0_pooled = torch.mean(
                        content0.reshape(list(content0.shape[:2]) + [-1]),
                        dim=-1)
                    content_1_pooled = torch.mean(
                        content1.reshape(list(content1.shape[:2]) + [-1]),
                        dim=-1)

                    euclid_dist_content_pooled = (content_0_pooled -
                                                  content_1_pooled).norm(2) / (
                                                      content_0_pooled.numel())

                    content_0_normalized = content0 / content0.norm()
                    content_1_normalized = content1 / content1.norm()

                    style_0_normalized = style0 / style0.norm()
                    style_1_normalized = style1 / style1.norm()

                    content_0_pooled_normalized = content_0_pooled / content_0_pooled.norm(
                    )
                    content_1_pooled_normalized = content_1_pooled / content_1_pooled.norm(
                    )

                    cosine_dist_content = torch.sum(content_0_normalized *
                                                    content_1_normalized)
                    cosine_dist_style = torch.sum(style_0_normalized *
                                                  style_1_normalized)
                    cosine_dist_content_pooled = torch.sum(
                        content_0_pooled_normalized *
                        content_1_pooled_normalized)
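                    # note: despite the "dist" names, these are cosine similarities
                    # (dot products of l2-normalized vectors), so higher means more similar,
                    # whereas the euclidean values above are true distances.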

                    print("euclid dist [content, pooled-content, style]: ",
                          euclid_dist_content, euclid_dist_content_pooled,
                          euclid_dist_style)
                    print("cosine sim [content, pooled-content, style]: ",
                          cosine_dist_content, cosine_dist_content_pooled,
                          cosine_dist_style)

            if hyp.run_few_shot_on_munit:
                if (global_step % 300) == 1 or (global_step % 300) == 0:
                    wrong = False
                    try:
                        precision_style = float(self.tp_style) / self.all_style
                        precision_content = float(
                            self.tp_content) / self.all_content
                    except ZeroDivisionError:
                        wrong = True

                    if not wrong:
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_style',
                            precision_style)
                        summ_writer.summ_scalar(
                            'precision/unsupervised_precision_content',
                            precision_content)
                        # st()
                    self.embed_list_style = defaultdict(lambda: [])
                    self.embed_list_content = defaultdict(lambda: [])
                    self.tp_style = 0
                    self.all_style = 0
                    self.tp_content = 0
                    self.all_content = 0
                    self.check = False
                elif not self.check and not nlu.check_fill_dict(
                        self.embed_list_content, self.embed_list_style):
                    print("Filling \n")
                    for index, class_val in enumerate(object_classes):

                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split(
                                "/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]

                        print(len(self.embed_list_style.keys()), "style class",
                              len(self.embed_list_content), "content class",
                              self.embed_list_content.keys())
                        if len(self.embed_list_style[class_val_style]
                               ) < hyp.few_shot_nums:
                            self.embed_list_style[class_val_style].append(
                                styles[index].squeeze())
                        if len(self.embed_list_content[class_val_content]
                               ) < hyp.few_shot_nums:
                            if hyp.avg_3d:
                                content_val = contents[index]
                                content_val = torch.mean(content_val.reshape(
                                    [content_val.shape[1], -1]),
                                                         dim=-1)
                                # st()
                                self.embed_list_content[
                                    class_val_content].append(content_val)
                            else:
                                self.embed_list_content[
                                    class_val_content].append(
                                        contents[index].reshape([-1]))
                else:
                    self.check = True
                    try:
                        print(float(self.tp_content) / self.all_content)
                        print(float(self.tp_style) / self.all_style)
                    except Exception as e:
                        pass
                    average = True
                    if average:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)

                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.mean(
                                    torch.stack(val, dim=0), dim=0)
                    else:
                        for key, val in self.embed_list_style.items():
                            if isinstance(val, type([])):
                                self.embed_list_style[key] = torch.stack(val,
                                                                         dim=0)

                        for key, val in self.embed_list_content.items():
                            if isinstance(val, type([])):
                                self.embed_list_content[key] = torch.stack(
                                    val, dim=0)
                    for index, class_val in enumerate(object_classes):
                        class_val = class_val
                        if hyp.dataset_name == "clevr_vqa":
                            class_val_content, class_val_style = class_val.split(
                                "/")
                        else:
                            class_val_content, class_val_style = [
                                class_val.split("/")[0],
                                class_val.split("/")[0]
                            ]

                        style_val = styles[index].squeeze().unsqueeze(0)
                        if not average:
                            embed_list_val_style = torch.cat(list(
                                self.embed_list_style.values()),
                                                             dim=0)
                            embed_list_key_style = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_style.keys()), 1),
                                    hyp.few_shot_nums, 1).reshape([-1]))
                        else:
                            embed_list_val_style = torch.stack(list(
                                self.embed_list_style.values()),
                                                               dim=0)
                            embed_list_key_style = list(
                                self.embed_list_style.keys())
                        embed_list_val_style = utils_basic.l2_normalize(
                            embed_list_val_style, dim=1).permute(1, 0)
                        style_val = utils_basic.l2_normalize(style_val, dim=1)
                        scores_styles = torch.matmul(style_val,
                                                     embed_list_val_style)
                        index_key = torch.argmax(scores_styles,
                                                 dim=1).squeeze()
                        selected_class_style = embed_list_key_style[index_key]
                        self.styles_prediction[class_val_style].append(
                            selected_class_style)
                        if class_val_style == selected_class_style:
                            self.tp_style += 1
                        self.all_style += 1

                        if hyp.avg_3d:
                            content_val = contents[index]
                            content_val = torch.mean(content_val.reshape(
                                [content_val.shape[1], -1]),
                                                     dim=-1).unsqueeze(0)
                        else:
                            content_val = contents[index].reshape(
                                [-1]).unsqueeze(0)
                        if not average:
                            embed_list_val_content = torch.cat(list(
                                self.embed_list_content.values()),
                                                               dim=0)
                            embed_list_key_content = list(
                                np.repeat(
                                    np.expand_dims(
                                        list(self.embed_list_content.keys()),
                                        1), hyp.few_shot_nums,
                                    1).reshape([-1]))
                        else:
                            embed_list_val_content = torch.stack(list(
                                self.embed_list_content.values()),
                                                                 dim=0)
                            embed_list_key_content = list(
                                self.embed_list_content.keys())
                        embed_list_val_content = utils_basic.l2_normalize(
                            embed_list_val_content, dim=1).permute(1, 0)
                        content_val = utils_basic.l2_normalize(content_val,
                                                               dim=1)
                        scores_content = torch.matmul(content_val,
                                                      embed_list_val_content)
                        index_key = torch.argmax(scores_content,
                                                 dim=1).squeeze()
                        selected_class_content = embed_list_key_content[
                            index_key]
                        self.content_prediction[class_val_content].append(
                            selected_class_content)
                        if class_val_content == selected_class_content:
                            self.tp_content += 1

                        self.all_content += 1
            # st()
            munit_loss = hyp.munit_loss_weight * munit_loss

            recon_input_obj = torch.cat([recon_input_0, recon_input_1], dim=0)
            recon_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, recon_input_obj, gt_boxesRMem_end, scores)

            sudo_input_obj = torch.cat([sudo_input_0, sudo_input_1], dim=0)
            styled_emb3D_R = nlu.update_scene_with_objects(
                emb3D_R, sudo_input_obj, gt_boxesRMem_end, scores)

            styled_emb3D_e_X1 = utils_vox.apply_4x4_to_vox(
                camX1_T_R, styled_emb3D_R)
            styled_emb3D_e_X0 = utils_vox.apply_4x4_to_vox(
                camX0_T_R, styled_emb3D_R)

            emb3D_e_X1 = utils_vox.apply_4x4_to_vox(camX1_T_R, recon_emb3D_R)
            emb3D_e_X0 = utils_vox.apply_4x4_to_vox(camX0_T_R, recon_emb3D_R)

            emb3D_e_X1_og = utils_vox.apply_4x4_to_vox(camX1_T_R, emb3D_R)
            emb3D_e_X0_og = utils_vox.apply_4x4_to_vox(camX0_T_R, emb3D_R)

            emb3D_R_aug_diff = torch.abs(emb3D_R - recon_emb3D_R)

            summ_writer.summ_feat(f'aug_feat/og', emb3D_R)
            summ_writer.summ_feat(f'aug_feat/og_gen', recon_emb3D_R)
            summ_writer.summ_feat(f'aug_feat/og_aug_diff', emb3D_R_aug_diff)

            if hyp.cycle_style_view_loss:
                sudo_input_obj_cycle = torch.cat(
                    [sudo_input_0_cycle, sudo_input_1_cycle], dim=0)
                styled_emb3D_R_cycle = nlu.update_scene_with_objects(
                    emb3D_R, sudo_input_obj_cycle, gt_boxesRMem_end, scores)

                styled_emb3D_e_X0_cycle = utils_vox.apply_4x4_to_vox(
                    camX0_T_R, styled_emb3D_R_cycle)
                styled_emb3D_e_X1_cycle = utils_vox.apply_4x4_to_vox(
                    camX1_T_R, styled_emb3D_R_cycle)
            summ_writer.summ_scalar('munit_loss', munit_loss.cpu().item())
            total_loss += munit_loss

        if hyp.do_occ and hyp.occ_do_cheap:
            occX0_sup, freeX0_sup, _, freeXs = utils_vox.prep_occs_supervision(
                camX0_T_camXs, xyz_camXs, Z2, Y2, X2, agg=True)

            summ_writer.summ_occ('occ_sup/occ_sup', occX0_sup)
            summ_writer.summ_occ('occ_sup/free_sup', freeX0_sup)
            summ_writer.summ_occs('occ_sup/freeXs_sup',
                                  torch.unbind(freeXs, dim=1))
            summ_writer.summ_occs('occ_sup/occXs_sup',
                                  torch.unbind(occXs_half, dim=1))

            occ_loss, occX0s_pred_ = self.occnet(
                torch.mean(featX0s[:, 1:], dim=1), occX0_sup, freeX0_sup,
                torch.max(validX0s[:, 1:], dim=1)[0], summ_writer)
            occX0s_pred = __u(occX0s_pred_)
            total_loss += occ_loss

        if hyp.do_view:
            assert (hyp.do_feat)
            PH, PW = hyp.PH, hyp.PW
            sy = float(PH) / float(hyp.H)
            sx = float(PW) / float(hyp.W)
            assert (sx == 0.5)  # else we need a fancier downsampler
            assert (sy == 0.5)
            projpix_T_cams = __u(
                utils_geom.scale_intrinsics(__p(pix_T_cams), sx, sy))
            # st()

            if hyp.do_munit:
                feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    emb3D_e_X1,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                feat_projX00_og = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    emb3D_e_X1_og,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                # only for checking the style
                styled_feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    styled_emb3D_e_X1,  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)

                if hyp.cycle_style_view_loss:
                    styled_feat_projX00_cycle = utils_vox.apply_pixX_T_memR_to_voxR(
                        projpix_T_cams[:, 0],
                        camX0_T_camXs[:, 1],
                        styled_emb3D_e_X1_cycle,  # use feat1 to predict rgb0
                        hyp.view_depth,
                        PH,
                        PW)

            else:
                feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR(
                    projpix_T_cams[:, 0],
                    camX0_T_camXs[:, 1],
                    featXs[:, 1],  # use feat1 to predict rgb0
                    hyp.view_depth,
                    PH,
                    PW)
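            # apply_pixX_T_memR_to_voxR presumably renders the 3D feature volume into a
            # PH x PW x view_depth frustum aligned with camera X0, so the viewnet below
            # can decode an rgb image for view 0 from features that came from view 1.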
            rgb_X00 = utils_basic.downsample(rgb_camXs[:, 0], 2)
            rgb_X01 = utils_basic.downsample(rgb_camXs[:, 1], 2)
            valid_X00 = utils_basic.downsample(valid_camXs[:, 0], 2)

            view_loss, rgb_e, emb2D_e = self.viewnet(feat_projX00, rgb_X00,
                                                     valid_X00, summ_writer,
                                                     "rgb")

            if hyp.do_munit:
                _, rgb_e, emb2D_e = self.viewnet(feat_projX00_og, rgb_X00,
                                                 valid_X00, summ_writer,
                                                 "rgb_og")
            if hyp.do_munit:
                styled_view_loss, styled_rgb_e, styled_emb2D_e = self.viewnet(
                    styled_feat_projX00, rgb_X00, valid_X00, summ_writer,
                    "recon_style")
                if hyp.cycle_style_view_loss:
                    styled_view_loss_cycle, styled_rgb_e_cycle, styled_emb2D_e_cycle = self.viewnet(
                        styled_feat_projX00_cycle, rgb_X00, valid_X00,
                        summ_writer, "recon_style_cycle")

                rgb_input_1 = torch.cat(
                    [rgb_X01[1], rgb_X01[0], styled_rgb_e[0]], dim=2)
                rgb_input_2 = torch.cat(
                    [rgb_X01[0], rgb_X01[1], styled_rgb_e[1]], dim=2)
                complete_vis = torch.cat([rgb_input_1, rgb_input_2], dim=1)
                summ_writer.summ_rgb('munit/munit_recons_vis',
                                     complete_vis.unsqueeze(0))

            if not hyp.do_munit:
                total_loss += view_loss
            else:
                if hyp.basic_view_loss:
                    total_loss += view_loss
                if hyp.style_view_loss:
                    total_loss += styled_view_loss
                if hyp.cycle_style_view_loss:
                    total_loss += styled_view_loss_cycle

        summ_writer.summ_scalar('loss', total_loss.cpu().item())

        if hyp.save_embed_tsne:
            for index, class_val in enumerate(object_classes):
                class_val_content, class_val_style = class_val.split("/")
                style_val = styles[index].squeeze().unsqueeze(0)
                self.cluster_pool.update(style_val, [class_val_style])
                print(self.cluster_pool.num)

            if self.cluster_pool.is_full():
                embeds, classes = self.cluster_pool.fetch()
                with open("offline_cluster" + '/%st.txt' % 'classes',
                          'w') as f:
                    for index, embed in enumerate(classes):
                        class_val = classes[index]
                        f.write("%s\n" % class_val)
                f.close()
                with open("offline_cluster" + '/%st.txt' % 'embeddings',
                          'w') as f:
                    for index, embed in enumerate(embeds):
                        # embed = utils_basic.l2_normalize(embed,dim=0)
                        print("writing {} embed".format(index))
                        embed_l_s = [str(i) for i in embed.tolist()]
                        embed_str = '\t'.join(embed_l_s)
                        f.write("%s\n" % embed_str)
                f.close()
                st()

        return total_loss, results
Esempio n. 11
0
    def forward(self,
                feat,
                obj_lrtlist_cams,
                obj_scorelist_s,
                summ_writer,
                suffix=''):
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = list(feat.shape)
        K, B2, S, D = list(obj_lrtlist_cams.shape)
        assert (B == B2)

        # obj_scorelist_s is K x B x S

        # __p = lambda x: utils_basic.pack_seqdim(x, B)
        # __u = lambda x: utils_basic.unpack_seqdim(x, B)

        # obj_lrtlist_cams is K x B x S x 19
        obj_lrtlist_cams_ = obj_lrtlist_cams.reshape(K * B, S, 19)
        obj_clist_cam_ = utils_geom.get_clist_from_lrtlist(obj_lrtlist_cams_)
        obj_clist_cam = obj_clist_cam_.reshape(K, B, S, 1, 3)
        # obj_clist_cam is K x B x S x 1 x 3
        obj_clist_cam = obj_clist_cam.squeeze(3)
        # obj_clist_cam is K x B x S x 3
        clist_cam = obj_clist_cam.reshape(K * B, S, 3)
        clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
        # this is K*B x S x 3
        clist_mem = clist_mem.reshape(K, B, S, 3)

        energy_vol = self.conv3d(feat)
        # energy_vol is B x 1 x Z x Y x X
        summ_writer.summ_oned('pri/energy_vol', torch.mean(energy_vol, dim=3))
        summ_writer.summ_histogram('pri/energy_vol_hist', energy_vol)

        # for k in range(K):
        # let's start with the first object
        # loglike_per_traj = self.get_traj_loglike(clist_mem[0], energy_vol)
        # # this is B
        # ce_loss = -1.0*torch.mean(loglike_per_traj)
        # # this is []

        loglike_per_traj = self.get_trajs_loglike(clist_mem, obj_scorelist_s,
                                                  energy_vol)
        # this is B x K
        valid = torch.max(obj_scorelist_s.permute(1, 0, 2), dim=2)[0]
        ce_loss = -1.0 * utils_basic.reduce_masked_mean(
            loglike_per_traj, valid)
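        # reduce_masked_mean presumably computes sum(x * mask) / clamp(sum(mask), min=eps),
        # so the cross-entropy is averaged only over trajectories with a valid score.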
        # this is []

        total_loss = utils_misc.add_loss('pri/ce_loss', total_loss, ce_loss,
                                         hyp.pri_ce_coeff, summ_writer)

        reg_loss = torch.sum(torch.abs(energy_vol))
        total_loss = utils_misc.add_loss('pri/reg_loss', total_loss, reg_loss,
                                         hyp.pri_reg_coeff, summ_writer)

        # smooth loss
        dz, dy, dx = gradient3D(energy_vol, absolute=True)
        smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        summ_writer.summ_oned('pri/smooth_loss', torch.mean(smooth_vox, dim=3))
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('pri/smooth_loss', total_loss,
                                         smooth_loss, hyp.pri_smooth_coeff,
                                         summ_writer)

        # pri_e = F.sigmoid(energy_vol)
        # energy_volbinary = torch.round(pri_e)

        # # collect some accuracy stats
        # pri_match = pri_g*torch.eq(energy_volbinary, pri_g).float()
        # free_match = free_g*torch.eq(1.0-energy_volbinary, free_g).float()
        # either_match = torch.clamp(pri_match+free_match, 0.0, 1.0)
        # either_have = torch.clamp(pri_g+free_g, 0.0, 1.0)
        # acc_pri = reduce_masked_mean(pri_match, pri_g*valid)
        # acc_free = reduce_masked_mean(free_match, free_g*valid)
        # acc_total = reduce_masked_mean(either_match, either_have*valid)

        # summ_writer.summ_scalar('pri/acc_pri%s' % suffix, acc_pri.cpu().item())
        # summ_writer.summ_scalar('pri/acc_free%s' % suffix, acc_free.cpu().item())
        # summ_writer.summ_scalar('pri/acc_total%s' % suffix, acc_total.cpu().item())

        # # vis
        # summ_writer.summ_pri('pri/pri_g%s' % suffix, pri_g, reduce_axes=[2,3])
        # summ_writer.summ_pri('pri/free_g%s' % suffix, free_g, reduce_axes=[2,3])
        # summ_writer.summ_pri('pri/pri_e%s' % suffix, pri_e, reduce_axes=[2,3])
        # summ_writer.summ_pri('pri/valid%s' % suffix, valid, reduce_axes=[2,3])

        # prob_loss = self.compute_loss(energy_vol, pri_g, free_g, valid, summ_writer)
        # total_loss = utils_misc.add_loss('pri/prob_loss%s' % suffix, total_loss, prob_loss, hyp.pri_coeff, summ_writer)

        return total_loss  #, pri_e