Example #1
    def forward(self, feat_mem, clist_cam, summ_writer, suffix=''):
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = list(feat_mem.shape)
        B2, S, D = list(clist_cam.shape)
        assert (B == B2)
        assert (D == 3)

        clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
        # this is (still) B x S x 3

        feat_ = feat_mem.permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
        mask_ = 1.0 - (feat_ == 0).all(dim=1, keepdim=True).float().cuda()
        grid_ = utils_basic.meshgrid2D(B, Z, X, stack=True,
                                       norm=True).permute(0, 3, 1, 2)
        halfgrid_ = utils_basic.meshgrid2D(B,
                                           int(Z / 2),
                                           int(X / 2),
                                           stack=True,
                                           norm=True).permute(0, 3, 1, 2)
        feat_ = torch.cat([feat_, grid_], dim=1)
        energy_map, mask = self.net(feat_, mask_, halfgrid_)
        # energy_map = self.net(feat_)
        # energy_map is B x 1 x Z x X
        # don't do this: # energy_map = energy_map + (1.0-mask) * (torch.min(torch.min(energy_map, dim=2)[0], dim=2)[0]).reshape(B, 1, 1, 1)
        summ_writer.summ_feat('pri/energy_input', feat_)
        summ_writer.summ_oned('pri/energy_map', energy_map)
        summ_writer.summ_oned('pri/mask', mask, norm=False)
        summ_writer.summ_histogram('pri/energy_map_hist', energy_map)

        loglike_per_traj = utils_misc.get_traj_loglike(
            clist_mem * 0.5, energy_map)  # 0.5 since it's half res
        # loglike_per_traj = self.get_traj_loglike(clist_mem*0.25, energy_map) # 0.25 since it's quarter res
        # this is B x K
        ce_loss = -1.0 * torch.mean(loglike_per_traj)
        # this is []

        total_loss = utils_misc.add_loss('pri/ce_loss', total_loss, ce_loss,
                                         hyp.pri2D_ce_coeff, summ_writer)

        reg_loss = torch.sum(torch.abs(energy_map))
        total_loss = utils_misc.add_loss('pri/reg_loss', total_loss, reg_loss,
                                         hyp.pri2D_reg_coeff, summ_writer)

        # smooth loss
        dz, dx = utils_basic.gradient2D(energy_map, absolute=True)
        smooth_vox = torch.mean(dz + dx, dim=1, keepdim=True)
        summ_writer.summ_oned('pri/smooth_loss', smooth_vox)
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('pri/smooth_loss', total_loss,
                                         smooth_loss, hyp.pri2D_smooth_coeff,
                                         summ_writer)

        return total_loss, energy_map
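Every example on this page calls utils_basic.meshgrid2D. Its source is not shown here, but the call sites pin its behavior down fairly well; the following is a minimal sketch consistent with them (the y-before-x stacking order and the exact normalization are assumptions, not the repo's actual code):

import torch

def meshgrid2D(B, Y, X, stack=False, norm=False):
    # grid_y[b, y, x] = y and grid_x[b, y, x] = x; each is B x Y x X
    grid_y = torch.arange(Y, dtype=torch.float32).reshape(1, Y, 1).repeat(B, 1, X)
    grid_x = torch.arange(X, dtype=torch.float32).reshape(1, 1, X).repeat(B, Y, 1)
    if norm:
        # map into [-1, 1], matching the "now the range is [-1,1]" comments below
        grid_y = 2.0 * grid_y / max(Y - 1, 1) - 1.0
        grid_x = 2.0 * grid_x / max(X - 1, 1) - 1.0
    if stack:
        # B x Y x X x 2, so .permute(0, 3, 1, 2) yields B x 2 x Y x X as used above
        return torch.stack([grid_y, grid_x], dim=3)
    return grid_y, grid_x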
Example #2
def crop_and_resize_box2D(im, box2D, Y, X):
    B, C, H, W = list(im.shape)
    B2, D = list(box2D.shape)
    assert (B == B2)
    assert (D == 4)
    grid_y, grid_x = utils_basic.meshgrid2D(B, Y, X, stack=False, norm=True)
    # now the range is [-1,1]

    grid_y = (grid_y + 1.0) / 2.0
    grid_x = (grid_x + 1.0) / 2.0
    # now the range is [0,1]

    h, w = utils_geom.get_size_from_box2D(box2D)
    ymin, xmin, ymax, xmax = torch.unbind(box2D, dim=1)
    grid_y = grid_y * h + ymin
    grid_x = grid_x * w + xmin
    # now the range is (0,1)

    grid_y = (grid_y * 2.0) - 1.0
    grid_x = (grid_x * 2.0) - 1.0
    # now the range is (-1,1)

    xy = torch.stack([grid_x, grid_y], dim=3)
    samp = F.grid_sample(im, xy)
    return samp
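A quick usage sketch for crop_and_resize_box2D; the box format (ymin, xmin, ymax, xmax, normalized to [0, 1]) follows from the unbind and the range comments above, and the example values are made up:

import torch

im = torch.randn(2, 3, 128, 128)                  # B x C x H x W
box2D = torch.tensor([[0.10, 0.10, 0.60, 0.50],   # ymin, xmin, ymax, xmax in [0, 1]
                      [0.20, 0.30, 0.90, 0.80]])
crops = crop_and_resize_box2D(im, box2D, Y=64, X=64)  # B x C x 64 x 64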
Example #3
def depth2pointcloud(z, pix_T_cam):
    B, C, H, W = list(z.shape)
    y, x = utils_basic.meshgrid2D(B, H, W)
    z = torch.reshape(z, [B, H, W])
    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
    xyz = Pixels2Camera(x, y, z, fx, fy, x0, y0)
    return xyz
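depth2pointcloud leans on split_intrinsics and Pixels2Camera, which are not shown. Under a standard pinhole model, Pixels2Camera would back-project like the following sketch (an assumption about the repo's helper, which likely also flattens to B x N x 3 given how xyz1 is consumed in pointcloud2flow below):

import torch

def Pixels2Camera(x, y, z, fx, fy, x0, y0):
    # x, y, z are B x H x W; fx, fy, x0, y0 are B-length intrinsics
    B, H, W = list(z.shape)
    fx = fx.reshape(B, 1, 1)
    fy = fy.reshape(B, 1, 1)
    x0 = x0.reshape(B, 1, 1)
    y0 = y0.reshape(B, 1, 1)
    # invert the pinhole projection: pixel coords -> camera coords
    X = (x - x0) / fx * z
    Y = (y - y0) / fy * z
    xyz = torch.stack([X, Y, z], dim=3).reshape(B, H * W, 3)
    return xyz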
Example #4
    def embed(self, discrete_image):
        B, H, W = list(discrete_image.shape)
        # utils_py.print_stats('discrete_image', discrete_image.cpu().detach().numpy())
        emb = self.embed_dict(discrete_image.view(-1)).view(
            B, H, W, self.emb_dim)
        emb = emb.permute(0, 3, 1, 2)  # (B, self.emb_dim, H, W)
        if self.use_grid:
            grid = utils_basic.meshgrid2D(B, H, W, stack=True,
                                          norm=True).permute(0, 3, 1, 2)
            emb = torch.cat([emb, grid], dim=1)
        return emb
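A standalone shape check of what embed produces, with hypothetical sizes (a 16-entry codebook, emb_dim=8) and use_grid enabled; it reuses the meshgrid2D sketch from Example #1 and shows that the grid concat adds 2 coordinate channels:

import torch

embed_dict = torch.nn.Embedding(num_embeddings=16, embedding_dim=8)
discrete_image = torch.randint(0, 16, (2, 24, 24))             # B x H x W
emb = embed_dict(discrete_image.view(-1)).view(2, 24, 24, 8)
emb = emb.permute(0, 3, 1, 2)                                  # B x 8 x H x W
grid = meshgrid2D(2, 24, 24, stack=True, norm=True).permute(0, 3, 1, 2)
emb = torch.cat([emb, grid], dim=1)                            # B x 10 x H x W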
Example #5
def xy2heatmaps(xy, Y, X, sigma=30.0):
    # xy is B x N x 2

    B, N, D = list(xy.shape)
    assert(D==2)
    
    grid_y, grid_x = utils_basic.meshgrid2D(B, Y, X)
    # grid_x and grid_y are B x Y x X
    grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1)
    grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1)
    heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=True)
    return heat
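xy2heatmap itself is not shown; presumably it drops an isotropic Gaussian of the given sigma at each keypoint location. A sketch under that assumption (the norm=True behavior, here peak-normalization, is also a guess):

import torch

def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False):
    # xy is B x N x 2; grid_xs and grid_ys are B x N x Y x X
    mu_x = xy[:, :, 0].unsqueeze(2).unsqueeze(3)  # B x N x 1 x 1
    mu_y = xy[:, :, 1].unsqueeze(2).unsqueeze(3)
    d2 = (grid_xs - mu_x) ** 2 + (grid_ys - mu_y) ** 2
    heat = torch.exp(-d2 / (2.0 * sigma ** 2))
    if norm:
        # normalize each map to peak at 1
        heat = heat / heat.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
    return heat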
Example #6
def pointcloud2flow(xyz1, pix_T_cam, H, W):
    # project xyz1 down, so that we get the 2D location of all of these pixels,
    # then subtract these 2D locations from the original ones to get optical flow
    
    B, N, C = list(xyz1.shape)
    assert(N==H*W)
    assert(C==3)
    
    # we assume xyz1 is the unprojection of the regular grid
    grid_y0, grid_x0 = utils_basic.meshgrid2D(B, H, W)

    xy1 = Camera2Pixels(xyz1, pix_T_cam)
    x1, y1 = torch.unbind(xy1, dim=2)
    x1 = x1.reshape(B, H, W)
    y1 = y1.reshape(B, H, W)

    flow_x = x1 - grid_x0
    flow_y = y1 - grid_y0
    flow = torch.stack([flow_x, flow_y], dim=1)
    # flow is B x 2 x H x W
    return flow
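Camera2Pixels is the projection counterpart of the Pixels2Camera sketch above; again a pinhole-model assumption rather than the repo's actual code:

import torch

def Camera2Pixels(xyz, pix_T_cam):
    # xyz is B x N x 3 in camera coords; returns B x N x 2 pixel coords
    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
    X, Y, Z = torch.unbind(xyz, dim=2)
    EPS = 1e-6
    x = fx.unsqueeze(1) * (X / Z.clamp(min=EPS)) + x0.unsqueeze(1)
    y = fy.unsqueeze(1) * (Y / Z.clamp(min=EPS)) + y0.unsqueeze(1)
    return torch.stack([x, y], dim=2)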
Example #7
    def forward(self, clist_cam, energy_map, occ_mems, summ_writer):

        total_loss = torch.tensor(0.0).cuda()

        B, S, C, Z, Y, X = list(occ_mems.shape)
        B2, S2, D = list(clist_cam.shape)
        assert (B == B2)
        assert (S == S2)

        traj_past = clist_cam[:, :self.T_past]
        traj_futu = clist_cam[:, self.T_past:]

        # just xz
        traj_past = torch.stack([traj_past[:, :, 0], traj_past[:, :, 2]],
                                dim=2)  # xz
        traj_futu = torch.stack([traj_futu[:, :, 0], traj_futu[:, :, 2]],
                                dim=2)  # xz

        feat = occ_mems[:, 0].permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
        mask = 1.0 - (feat == 0).all(dim=1, keepdim=True).float().cuda()
        halfgrid = utils_basic.meshgrid2D(B,
                                          int(Z / 2),
                                          int(X / 2),
                                          stack=True,
                                          norm=True).permute(0, 3, 1, 2)
        feat_map, _ = self.compressor(feat, mask, halfgrid)
        pred_map = self.conv2d(feat_map)
        # these are B x C x Z x X

        K = 12  # number of samples
        traj_past = traj_past.unsqueeze(0).repeat(K, 1, 1, 1)
        feat_map = feat_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)
        pred_map = pred_map.unsqueeze(0).repeat(K, 1, 1, 1, 1)
        # to sample the K trajectories in parallel, we'll pack K onto the batch dim
        __p = lambda x: utils_basic.pack_seqdim(x, K)
        __u = lambda x: utils_basic.unpack_seqdim(x, K)
        traj_past_ = __p(traj_past)
        feat_map_ = __p(feat_map)
        pred_map_ = __p(pred_map)
        base_sample_ = torch.randn(K * B, self.T_futu, 2).cuda()
        traj_futu_e_ = self.compute_forward_mapping(feat_map_, pred_map_,
                                                    base_sample_, traj_past_)
        traj_futu_e = __u(traj_futu_e_)
        # this is K x B x T x 2

        # print('traj_futu_e', traj_futu_e.shape, traj_futu_e[0,0])
        if summ_writer.save_this:
            o = []
            for k in list(range(K)):
                o.append(
                    utils_improc.preprocess_color(
                        summ_writer.summ_traj_on_occ(
                            '',
                            utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]),
                                              Z, Y, X),
                            occ_mems[:, 0],
                            already_mem=True,
                            only_return=True)))
                summ_writer.summ_traj_on_occ(
                    'rponet/traj_futu_sample_%d' % k,
                    utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]), Z, Y,
                                      X),
                    occ_mems[:, 0],
                    already_mem=True)

            mean_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
            summ_writer.summ_rgb('rponet/traj_futu_e_mean', mean_vis)

            summ_writer.summ_traj_on_occ('rponet/traj_futu_g',
                                         utils_vox.Ref2Mem(
                                             self.add_fake_y(traj_futu), Z, Y,
                                             X),
                                         occ_mems[:, 0],
                                         already_mem=True)

        # forward loss: neg logprob of GT samples under the model
        # reverse loss: neg logprob of estim samples under the (approx) GT (i.e., spatial prior)
        forward_loss, reverse_loss = self.compute_loss(feat_map[0],
                                                       pred_map[0],
                                                       traj_past[0], traj_futu,
                                                       traj_futu_e, energy_map)
        total_loss = utils_misc.add_loss('rpo/forward_loss', total_loss,
                                         forward_loss, hyp.rpo2D_forward_coeff,
                                         summ_writer)
        total_loss = utils_misc.add_loss('rpo/reverse_loss', total_loss,
                                         reverse_loss, hyp.rpo2D_reverse_coeff,
                                         summ_writer)

        return total_loss
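The __p/__u lambdas in Example #7 (and in #8 below) fold a leading sample or sequence dimension into the batch dimension and back out, so the network can process all K trajectories in one pass. A minimal sketch of what utils_basic.pack_seqdim / unpack_seqdim plausibly do:

def pack_seqdim(x, B):
    # B x S x ... -> (B*S) x ...
    shape = list(x.shape)
    return x.reshape([shape[0] * shape[1]] + shape[2:])

def unpack_seqdim(x, B):
    # (B*S) x ... -> B x S x ...
    shape = list(x.shape)
    return x.reshape([B, shape[0] // B] + shape[1:])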
Example #8
    def forward(self, clist_cam, occs, summ_writer, vox_util, suffix=''):
        total_loss = torch.tensor(0.0).cuda()
        B, S, C, Z, Y, X = list(occs.shape)
        B2, S2, D = list(clist_cam.shape)
        assert (B == B2)
        assert (S == S2)
        assert (D == 3)

        if summ_writer.save_this:
            summ_writer.summ_traj_on_occ('motioncost/actual_traj',
                                         clist_cam,
                                         occs[:, self.T_past],
                                         vox_util,
                                         sigma=2)

        __p = lambda x: utils_basic.pack_seqdim(x, B)
        __u = lambda x: utils_basic.unpack_seqdim(x, B)

        # occs_ = occs.reshape(B*S, C, Z, Y, X)
        occs_ = __p(occs)
        feats_ = occs_.permute(0, 1, 3, 2, 4).reshape(B * S, C * Y, Z, X)
        masks_ = 1.0 - (feats_ == 0).all(dim=1, keepdim=True).float().cuda()
        halfgrids_ = utils_basic.meshgrid2D(B * S,
                                            int(Z / 2),
                                            int(X / 2),
                                            stack=True,
                                            norm=True).permute(0, 3, 1, 2)
        # feats_ = torch.cat([feats_, grids_], dim=1)
        feats = __u(feats_)
        masks = __u(masks_)
        halfgrids = __u(halfgrids_)
        input_feats = feats[:, :self.T_past]
        input_masks = masks[:, :self.T_past]
        input_halfgrids = halfgrids[:, :self.T_past]
        dense_feats_, _ = self.densifier(__p(input_feats), __p(input_masks),
                                         __p(input_halfgrids))
        dense_feats = __u(dense_feats_)
        super_feat = dense_feats.reshape(B, self.T_past * self.dense_dim,
                                         int(Z / 2), int(X / 2))
        cost_maps = self.motioncoster(super_feat)
        cost_maps = F.interpolate(cost_maps, scale_factor=4, mode='bilinear')
        # this is B x T_futu x Z x X
        cost_maps = cost_maps.clamp(-1000,
                                    1000)  # raquel says this adds stability
        summ_writer.summ_histogram('motioncost/cost_maps_hist', cost_maps)
        summ_writer.summ_oneds('motioncost/cost_maps',
                               torch.unbind(cost_maps.unsqueeze(2), dim=1))

        # next i need to sample some trajectories

        N = hyp.motioncost_num_negs
        sampled_trajs_cam = self.sample_trajs(N, clist_cam)
        # this is B x N x S x 3

        if summ_writer.save_this:
            # for n in list(range(np.min([N, 3]))):
            #     # this is 1 x S x 3
            #     summ_writer.summ_traj_on_occ('motioncost/sample%d_clist' % n,
            #                                  sampled_trajs_cam[0, n].unsqueeze(0),
            #                                  occs[:,self.T_past],
            #                                  # torch.max(occs, dim=1)[0],
            #                                  # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
            #                                  already_mem=False)
            o = []
            for n in list(range(N)):
                o.append(
                    utils_improc.preprocess_color(
                        summ_writer.summ_traj_on_occ(
                            '',
                            sampled_trajs_cam[0, n].unsqueeze(0),
                            occs[0:1, self.T_past],
                            vox_util,
                            only_return=True,
                            sigma=0.5)))
            summ_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
            summ_writer.summ_rgb('motioncost/all_sampled_trajs', summ_vis)

        # smooth loss
        cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
        dz, dx = gradient2D(cost_maps_, absolute=True)
        dt = torch.abs(cost_maps[:, 1:] - cost_maps[:, 0:-1])
        smooth_spatial = torch.mean(dx + dz, dim=1, keepdim=True)
        smooth_time = torch.mean(dt, dim=1, keepdim=True)
        summ_writer.summ_oned('motioncost/smooth_loss_spatial', smooth_spatial)
        summ_writer.summ_oned('motioncost/smooth_loss_time', smooth_time)
        smooth_loss = torch.mean(smooth_spatial) + torch.mean(smooth_time)
        total_loss = utils_misc.add_loss('motioncost/smooth_loss', total_loss,
                                         smooth_loss,
                                         hyp.motioncost_smooth_coeff,
                                         summ_writer)

        # def clamp_xyz(xyz, X, Y, Z):
        #     x, y, z = torch.unbind(xyz, dim=-1)
        #     x = x.clamp(0, X)
        #     y = y.clamp(0, Y)
        #     z = z.clamp(0, Z)
        #     # if zero_y:
        #     #     y = torch.zeros_like(y)
        #     xyz = torch.stack([x,y,z], dim=-1)
        #     return xyz
        def clamp_xz(xz, X, Z):
            x, z = torch.unbind(xz, dim=-1)
            x = x.clamp(0, X)
            z = z.clamp(0, Z)
            xz = torch.stack([x, z], dim=-1)
            return xz

        clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
        # this is B x S x 3

        # sampled_trajs_cam is B x N x S x 3
        sampled_trajs_cam_ = sampled_trajs_cam.reshape(B, N * S, 3)
        sampled_trajs_mem_ = utils_vox.Ref2Mem(sampled_trajs_cam_, Z, Y, X)
        sampled_trajs_mem = sampled_trajs_mem_.reshape(B, N, S, 3)
        # this is B x N x S x 3

        xyz_pos_ = clist_mem[:, self.T_past:].reshape(B * self.T_futu, 1, 3)
        xyz_neg_ = sampled_trajs_mem[:, :,
                                     self.T_past:].permute(0, 2, 1, 3).reshape(
                                         B * self.T_futu, N, 3)
        # get rid of y
        xz_pos_ = torch.stack([xyz_pos_[:, :, 0], xyz_pos_[:, :, 2]], dim=2)
        xz_neg_ = torch.stack([xyz_neg_[:, :, 0], xyz_neg_[:, :, 2]], dim=2)
        xz_ = torch.cat([xz_pos_, xz_neg_], dim=1)
        xz_ = clamp_xz(xz_, X, Z)
        cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
        cost_ = utils_samp.bilinear_sample2D(cost_maps_, xz_[:, :, 0],
                                             xz_[:, :, 1]).squeeze(1)
        # cost is B*T_futu x 1+N
        cost_pos = cost_[:, 0:1]  # B*T_futu x 1
        cost_neg = cost_[:, 1:]  # B*T_futu x N

        cost_pos = cost_pos.unsqueeze(2)  # B*T_futu x 1 x 1
        cost_neg = cost_neg.unsqueeze(1)  # B*T_futu x 1 x N

        utils_misc.add_loss('motioncost/mean_cost_pos', 0,
                            torch.mean(cost_pos), 0, summ_writer)
        utils_misc.add_loss('motioncost/mean_cost_neg', 0,
                            torch.mean(cost_neg), 0, summ_writer)
        utils_misc.add_loss('motioncost/mean_margin', 0,
                            torch.mean(cost_neg - cost_pos), 0, summ_writer)

        xz_pos = xz_pos_.unsqueeze(2)  # B*T_futu x 1 x 1 x 2
        xz_neg = xz_neg_.unsqueeze(1)  # B*T_futu x 1 x N x 2
        dist = torch.norm(xz_pos - xz_neg, dim=3)  # B*T_futu x 1 x N
        dist = dist / float(
            Z) * 5.0  # normalize for resolution, but upweight it a bit
        margin = F.relu(cost_pos - cost_neg + dist)
        margin = margin.reshape(B, self.T_futu, N)
        # mean over time (in the paper this is a sum)
        margin = torch.mean(margin, dim=1)
        # max over the negatives
        maxmargin = torch.max(margin, dim=1)[0]  # B
        maxmargin_loss = torch.mean(maxmargin)
        total_loss = utils_misc.add_loss('motioncost/maxmargin_loss',
                                         total_loss, maxmargin_loss,
                                         hyp.motioncost_maxmargin_coeff,
                                         summ_writer)

        # now let's see some top k
        # we'll do this for the first el of the batch
        cost_neg = cost_neg.reshape(B, self.T_futu,
                                    N)[0].detach().cpu().numpy()
        futu_mem = sampled_trajs_mem[:, :, self.T_past:].reshape(
            B, N, self.T_futu, 3)[0:1]
        cost_neg = np.reshape(cost_neg, [self.T_futu, N])
        cost_neg = np.sum(cost_neg, axis=0)
        inds = np.argsort(cost_neg, axis=0)

        for n in list(range(2)):
            xyzlist_e_mem = futu_mem[0:1, inds[n]]
            xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X)
            # this is B x S x 3

            if summ_writer.save_this and n == 0:
                print('xyzlist_e_cam', xyzlist_e_cam[0:1])
                print('xyzlist_g_cam', clist_cam[0:1, self.T_past:])

            dist = torch.norm(clist_cam[0:1, self.T_past:] -
                              xyzlist_e_cam[0:1],
                              dim=2)
            # this is B x T_futu
            meandist = torch.mean(dist)
            utils_misc.add_loss('motioncost/xyz_dist_%d' % n, 0, meandist, 0,
                                summ_writer)

        if summ_writer.save_this:
            # plot the best and worst trajs
            # print('sorted costs:', cost_neg[inds])
            for n in list(range(2)):
                ind = inds[n]
                print('plotting good traj with cost %.2f' % (cost_neg[ind]))
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ('motioncost/best_sampled_traj%d' %
                                             n,
                                             xyzlist_e_mem[0:1],
                                             occs[0:1, self.T_past],
                                             vox_util,
                                             already_mem=True,
                                             sigma=2)

            for n in list(range(2)):
                ind = inds[-(n + 1)]
                print('plotting bad traj with cost %.2f' % (cost_neg[ind]))
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ(
                    'motioncost/worst_sampled_traj%d' % n,
                    xyzlist_e_mem[0:1],
                    occs[0:1, self.T_past],
                    vox_util,
                    already_mem=True,
                    sigma=2)

        # xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X)
        # xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X)
        # summ_writer.summ_traj_on_occ('motioncost/traj_e',
        #                              xyzlist_e_mem,
        #                              torch.max(occs, dim=1)[0],
        #                              already_mem=True,
        #                              sigma=2)
        # summ_writer.summ_traj_on_occ('motioncost/traj_g',
        #                              xyzlist_g_mem,
        #                              torch.max(occs, dim=1)[0],
        #                              already_mem=True,
        #                              sigma=2)

        # scorelist_here = scorelist[:,self.num_given:,0]
        # sql2 = torch.sum((vel_g-vel_e)**2, dim=2)

        # ## yes weightmask
        # weightmask = torch.arange(0, self.num_need, dtype=torch.float32, device=torch.device('cuda'))
        # weightmask = torch.exp(-weightmask**(1./4))
        # # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860,
        # #         0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353
        # weightmask = weightmask.reshape(1, self.num_need)
        # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here * weightmask)

        # utils_misc.add_loss('motioncost/l2_loss', 0, l2_loss, 0, summ_writer)

        # # # no weightmask:
        # # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here)

        # # total_loss = utils_misc.add_loss('motioncost/l2_loss', total_loss, l2_loss, hyp.motioncost_l2_coeff, summ_writer)

        # dist = torch.norm(xyzlist_e - xyzlist_g, dim=2)
        # meandist = utils_basic.reduce_masked_mean(dist, scorelist[:,:,0])
        # utils_misc.add_loss('motioncost/xyz_dist_0', 0, meandist, 0, summ_writer)

        # l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here)
        # # utils_misc.add_loss('motioncost/vel_dist_noexp', 0, l2_loss, 0, summ_writer)
        # total_loss = utils_misc.add_loss('motioncost/l2_loss_noexp', total_loss, l2_loss_noexp, hyp.motioncost_l2_coeff, summ_writer)

        return total_loss
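The cost lookup in Example #8 goes through utils_samp.bilinear_sample2D, which is not shown; here is a minimal sketch built on F.grid_sample, assuming x and y arrive as pixel coordinates:

import torch
import torch.nn.functional as F

def bilinear_sample2D(im, x, y):
    # im is B x C x H x W; x and y are B x N pixel coordinates
    B, C, H, W = list(im.shape)
    # rescale pixel coords into grid_sample's [-1, 1] range
    x = 2.0 * x / max(W - 1, 1) - 1.0
    y = 2.0 * y / max(H - 1, 1) - 1.0
    grid = torch.stack([x, y], dim=2).unsqueeze(1)       # B x 1 x N x 2
    samp = F.grid_sample(im, grid, align_corners=True)   # B x C x 1 x N
    return samp.squeeze(2)                               # B x C x N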