def compute_loss(self, feat_map, pred_map, traj_past, traj_futu, traj_futu_e, energy_map):
    # traj_past is B x T+1 x 2
    # traj_futu is B x T x 2
    # traj_futu_e is K x B x T x 2
    # energy_map is B x 1 x Z x X
    K, B, T_futu, D = list(traj_futu_e.shape)

    # ------------
    # forward loss
    # ------------
    # run the actual traj_futu through the net, to get its log prob
    CE_pqs = -1.0 * self.log_prob(feat_map, pred_map, traj_past, traj_futu)
    # this is B
    CE_pq = torch.mean(CE_pqs)
    # summ_writer.summ_scalar('rpo/CE_pq', CE_pq.cpu().item())

    # ------------
    # reverse loss
    # ------------
    B, C, Z, X = list(energy_map.shape)
    xyz_cam = traj_futu_e.reshape([B * K, T_futu, 2])
    xyz_mem = utils_vox.Ref2Mem(self.add_fake_y(xyz_cam), Z, 10, X)
    # since we have multiple samples per image, we need to tile up energy_map
    energy_map = energy_map.unsqueeze(0).repeat(K, 1, 1, 1, 1).reshape(K * B, C, Z, X)
    CE_qphat_all = -1.0 * utils_misc.get_traj_loglike(xyz_mem, energy_map)
    # this is B*K x T_futu
    CE_qphat = torch.mean(CE_qphat_all)

    _forward_loss = CE_pq
    _reverse_loss = CE_qphat
    return _forward_loss, _reverse_loss
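# Note: both the forward and reverse terms above reduce to "log-likelihood of a
# trajectory under an energy map". utils_misc.get_traj_loglike is defined
# elsewhere in the repo; the sketch below is only an assumption about its
# behavior (energy map treated as unnormalized log-probs over the Z*X grid,
# sampled bilinearly at each timestep), using torch / F / utils_samp as they
# are imported elsewhere in this file.
def get_traj_loglike_sketch(xyz_mem, energy_map):
    # xyz_mem: B x T x 3 (memory coords); energy_map: B x 1 x Z x X
    B, C, Z, X = list(energy_map.shape)
    assert C == 1
    logprob_map = F.log_softmax(energy_map.reshape(B, Z * X), dim=1).reshape(B, 1, Z, X)
    x = xyz_mem[:, :, 0]  # B x T
    z = xyz_mem[:, :, 2]  # B x T
    # bilinearly sample the per-cell log-probs at each trajectory location
    loglike_per_step = utils_samp.bilinear_sample2D(logprob_map, x, z).squeeze(1)  # B x T
    return loglike_per_step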
def forward(self, feat_mem, clist_cam, summ_writer, suffix=''):
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat_mem.shape)
    B2, S, D = list(clist_cam.shape)
    assert (B == B2)
    assert (D == 3)
    clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
    # this is (still) B x S x 3
    feat_ = feat_mem.permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X)
    mask_ = 1.0 - (feat_ == 0).all(dim=1, keepdim=True).float().cuda()
    grid_ = utils_basic.meshgrid2D(B, Z, X, stack=True, norm=True).permute(0, 3, 1, 2)
    halfgrid_ = utils_basic.meshgrid2D(B, int(Z / 2), int(X / 2), stack=True, norm=True).permute(0, 3, 1, 2)
    feat_ = torch.cat([feat_, grid_], dim=1)
    energy_map, mask = self.net(feat_, mask_, halfgrid_)
    # energy_map = self.net(feat_)
    # energy_map is B x 1 x Z x X
    # don't do this:
    # energy_map = energy_map + (1.0-mask) * (torch.min(torch.min(energy_map, dim=2)[0], dim=2)[0]).reshape(B, 1, 1, 1)
    summ_writer.summ_feat('pri/energy_input', feat_)
    summ_writer.summ_oned('pri/energy_map', energy_map)
    summ_writer.summ_oned('pri/mask', mask, norm=False)
    summ_writer.summ_histogram('pri/energy_map_hist', energy_map)
    loglike_per_traj = utils_misc.get_traj_loglike(clist_mem * 0.5, energy_map)  # 0.5 since it's half res
    # loglike_per_traj = self.get_traj_loglike(clist_mem*0.25, energy_map)  # 0.25 since it's quarter res
    # this is B x K
    ce_loss = -1.0 * torch.mean(loglike_per_traj)
    # this is []
    total_loss = utils_misc.add_loss('pri/ce_loss', total_loss, ce_loss, hyp.pri2D_ce_coeff, summ_writer)
    reg_loss = torch.sum(torch.abs(energy_map))
    total_loss = utils_misc.add_loss('pri/reg_loss', total_loss, reg_loss, hyp.pri2D_reg_coeff, summ_writer)
    # smooth loss
    dz, dx = utils_basic.gradient2D(energy_map, absolute=True)
    smooth_vox = torch.mean(dz + dx, dim=1, keepdim=True)
    summ_writer.summ_oned('pri/smooth_loss', smooth_vox)
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('pri/smooth_loss', total_loss, smooth_loss, hyp.pri2D_smooth_coeff, summ_writer)
    return total_loss, energy_map
def prep_dyn_feat(self, pred_map, traj_past):
    B, C, Z, X = list(pred_map.shape)
    B, T, D = list(traj_past.shape)
    # as the "dynamic" input, we will sample from pred_map at each loc in traj_past
    traj_past_mem = utils_vox.Ref2Mem(self.add_fake_y(traj_past), Z, 10, X)
    x = traj_past_mem[:, :, 0]
    z = traj_past_mem[:, :, 2]
    feats = utils_samp.bilinear_sample2D(pred_map, x, z)
    # print('sampled these:', feats.shape)
    # this is B x T x C
    dyn_feat = feats.reshape(B, -1)
    # also, we will concat the actual traj_past
    dyn_feat = torch.cat([dyn_feat, traj_past.reshape(B, -1)], dim=1)
    # print('cat the traj itself, and got:', dyn_feat.shape)
    return dyn_feat
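# utils_samp.bilinear_sample2D (used above) is a repo helper; the sketch below
# is a rough equivalent built on F.grid_sample, assuming x and z are pixel
# coordinates into the last and second-to-last dims of im. It is meant to
# illustrate the sampling step, not to reproduce the actual helper.
def bilinear_sample2D_sketch(im, x, z):
    # im: B x C x Z x X; x, z: B x N pixel coordinates; returns B x C x N
    B, C, Z, X = list(im.shape)
    x_norm = 2.0 * x / (X - 1) - 1.0  # grid_sample expects coords in [-1, 1]
    z_norm = 2.0 * z / (Z - 1) - 1.0
    grid = torch.stack([x_norm, z_norm], dim=2).unsqueeze(1)  # B x 1 x N x 2
    samples = F.grid_sample(im, grid, align_corners=True)     # B x C x 1 x N
    return samples.squeeze(2)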
def summ_traj_on_occ(self, name, traj, occ_mem, already_mem=False, sigma=2):
    # traj is B x S x 3
    B, C, Z, Y, X = list(occ_mem.shape)
    B2, S, D = list(traj.shape)
    assert(D==3)
    assert(B==B2)
    if self.save_this:
        if already_mem:
            traj_mem = traj
        else:
            traj_mem = utils_vox.Ref2Mem(traj, Z, Y, X)
        height_mem = convert_occ_to_height(occ_mem, reduce_axis=3)
        # this is B x C x Z x X
        occ_vis = normalize(height_mem)
        occ_vis = oned2inferno(occ_vis, norm=False)
        # print(vis.shape)
        x, y, z = torch.unbind(traj_mem, dim=2)
        xz = torch.stack([x, z], dim=2)
        heats = draw_circles_at_xy(xz, Z, X, sigma=sigma)
        # this is B x S x 1 x Z x X
        heats = torch.squeeze(heats, dim=2)
        heat = seq2color(heats)
        # make black 0
        heat = back2color(heat)
        # print(heat.shape)
        # vis[heat > 0] = heat
        # replace black with occ vis
        heat[heat==0] = (occ_vis[heat==0].float()*0.5).byte()  # darken the bkg a bit
        heat = preprocess_color(heat)
        self.summ_rgb(('%s' % (name)), heat)
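# draw_circles_at_xy (used above) comes from the same visualization utilities;
# the sketch below shows the assumed behavior -- an isotropic Gaussian bump per
# trajectory point -- which is what the seq2color overlay consumes. The name
# and details here are assumptions, not the repo implementation.
def draw_circles_at_xy_sketch(xz, Z, X, sigma=2.0):
    # xz: B x S x 2 (x, z) in memory coords; returns B x S x 1 x Z x X
    B, S, _ = list(xz.shape)
    grid_z = torch.arange(Z, dtype=torch.float32, device=xz.device).reshape(1, 1, Z, 1)
    grid_x = torch.arange(X, dtype=torch.float32, device=xz.device).reshape(1, 1, 1, X)
    x = xz[:, :, 0].reshape(B, S, 1, 1)
    z = xz[:, :, 1].reshape(B, S, 1, 1)
    dist2 = (grid_x - x) ** 2 + (grid_z - z) ** 2  # B x S x Z x X
    heat = torch.exp(-dist2 / (2.0 * sigma ** 2))
    return heat.unsqueeze(2)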
def __init__(self): super(ForecastNet, self).__init__() print('ForecastNet...') self.use_cost_vols = False if self.use_cost_vols: if hyp.do_feat: in_dim = 66 else: in_dim = 10 hidden_dim = 32 out_dim = hyp.S self.cost_forecaster = archs.encoder3D.Net3D( in_channel=in_dim, pred_dim=out_dim).cuda() # self.cost_forecaster = archs.encoder3D.ResNet3D( # in_channel=in_dim, pred_dim=out_dim).cuda() # self.cost_forecaster = nn.Sequential( # nn.ReplicationPad3d(1), # nn.Conv3d(in_channels=in_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=0), # nn.BatchNorm3d(num_features=hidden_dim), # nn.LeakyReLU(), # nn.ReplicationPad3d(1), # nn.Conv3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=0), # nn.BatchNorm3d(num_features=hidden_dim), # nn.LeakyReLU(), # nn.ConvTranspose3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=1), # nn.BatchNorm3d(num_features=hidden_dim), # nn.LeakyReLU(), # nn.ConvTranspose3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=1), # nn.BatchNorm3d(num_features=hidden_dim), # nn.LeakyReLU(), # nn.Conv3d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, stride=1, padding=1), # nn.BatchNorm3d(num_features=hidden_dim), # nn.LeakyReLU(), # nn.Conv3d(in_channels=hidden_dim, out_channels=out_dim, kernel_size=1, stride=1), # ).cuda() library_mod = 'ab' traj_library_cam = np.load('../intphys_data/all_trajs_%s.npy' % library_mod) # traj_library_cam is L x 100, where L is huge # let's make it a bit more huge traj_library_cam = np.concatenate([ traj_library_cam * 0.8, traj_library_cam * 1.0, traj_library_cam * 1.2 ], axis=0) # print('traj_library_cam', traj_library_cam.shape) self.L = traj_library_cam.shape[0] self.frame_stride = 2 traj_library_cam = traj_library_cam[:, ::self.frame_stride] # traj_library_cam is L x M x 3 self.M = traj_library_cam.shape[1] Z, Y, X = hyp.Z, hyp.Y, hyp.X Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2) traj_library_cam = torch.from_numpy( traj_library_cam).float().cuda() traj_library_cam_ = traj_library_cam.reshape(1, self.L * self.M, 3) traj_library_mem_ = utils_vox.Ref2Mem(traj_library_cam_, Z2, Y2, X2) traj_library_mem = traj_library_mem_.reshape(self.L, self.M, 3) # print('traj_library_mem', traj_library_mem.shape) # self.traj_library_mem = traj_library_mem self.traj_library_mem = traj_library_mem.detach().cpu().numpy() # self.traj_library_mem is L x M x 3 else: self.num_given = 3 self.num_need = hyp.S - self.num_given in_dim = 3 out_dim = self.num_need * 3 # self.regressor = archs.bottle3D.Bottle3D( # in_channel=in_dim, pred_dim=out_dim).cuda() self.regressor = archs.sparse_invar_bottle3D.Bottle3D( in_channel=in_dim, pred_dim=out_dim).cuda()
def forward(self, feats, xyzlist_cam, scorelist, vislist, occs, summ_writer, suffix=''):
    total_loss = torch.tensor(0.0).cuda()
    B, S, C, Z2, Y2, X2 = list(feats.shape)
    B, S, C, Z, Y, X = list(occs.shape)
    B2, S2, D = list(xyzlist_cam.shape)
    assert (B == B2) and (S == S2)
    assert (D == 3)
    xyzlist_mem = utils_vox.Ref2Mem(xyzlist_cam, Z, Y, X)
    # these are B x S x 3
    scorelist = scorelist.unsqueeze(2)
    # this is B x S x 1
    vislist = vislist[:, 0].reshape(B, 1, 1)
    # we only care that the object was visible in frame0
    scorelist = scorelist * vislist
    if self.use_cost_vols:
        if summ_writer.save_this:
            summ_writer.summ_traj_on_occ('forecast/actual_traj',
                                         xyzlist_mem * scorelist,
                                         torch.max(occs, dim=1)[0],
                                         already_mem=True,
                                         sigma=2)
        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)
        occ_hint0 = utils_vox.voxelize_xyz(xyzlist_cam[:, 0:1], Z4, Y4, X4)
        occ_hint1 = utils_vox.voxelize_xyz(xyzlist_cam[:, 1:2], Z4, Y4, X4)
        occ_hint0 = occ_hint0 * scorelist[:, 0].reshape(B, 1, 1, 1, 1)
        occ_hint1 = occ_hint1 * scorelist[:, 1].reshape(B, 1, 1, 1, 1)
        occ_hint = torch.cat([occ_hint0, occ_hint1], dim=1)
        occ_hint = F.interpolate(occ_hint, scale_factor=4, mode='nearest')
        # this is B x 1 x Z x Y x X
        summ_writer.summ_occ('forecast/occ_hint', (occ_hint0 + occ_hint1).clamp(0, 1))
        crops = []
        for s in list(range(S)):
            crop = utils_vox.center_mem_on_xyz(occs_highres[:, s], xyzlist_cam[:, s], Z2, Y2, X2)
            crops.append(crop)
        crops = torch.stack(crops, dim=0)
        summ_writer.summ_occs('forecast/crops', crops)
        # condition on the occ_hint
        feat = torch.cat([feat, occ_hint], dim=1)
        N = hyp.forecast_num_negs
        sampled_trajs_mem = self.sample_trajs_from_library(N, xyzlist_mem)
        if summ_writer.save_this:
            for n in list(range(np.min([N, 10]))):
                xyzlist_mem = sampled_trajs_mem[0, n].unsqueeze(0)
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ(
                    'forecast/lib%d_xyzlist' % n,
                    xyzlist_mem,
                    torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                    already_mem=True)
        cost_vols = self.cost_forecaster(feat)
        # cost_vols = F.sigmoid(cost_vols)
        cost_vols = F.interpolate(cost_vols, scale_factor=2, mode='trilinear')
        # cost_vols is B x S x Z x Y x X
        summ_writer.summ_histogram('forecast/cost_vols_hist', cost_vols)
        cost_vols = cost_vols.clamp(-1000, 1000)  # raquel says this adds stability
        summ_writer.summ_histogram('forecast/cost_vols_clamped_hist', cost_vols)
        cost_vols_vis = torch.mean(cost_vols, dim=3).unsqueeze(2)
        # cost_vols_vis is B x S x 1 x Z x X
        summ_writer.summ_oneds('forecast/cost_vols_vis', torch.unbind(cost_vols_vis, dim=1))
        # smooth loss
        cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X)
        dz, dy, dx = gradient3D(cost_vols_, absolute=True)
        dt = torch.abs(cost_vols[:, 1:] - cost_vols[:, 0:-1])
        smooth_vox_spatial = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        smooth_vox_time = torch.mean(dt, dim=1, keepdim=True)
        summ_writer.summ_oned('forecast/smooth_loss_spatial', torch.mean(smooth_vox_spatial, dim=3))
        summ_writer.summ_oned('forecast/smooth_loss_time', torch.mean(smooth_vox_time, dim=3))
        smooth_loss = torch.mean(smooth_vox_spatial) + torch.mean(smooth_vox_time)
        total_loss = utils_misc.add_loss('forecast/smooth_loss', total_loss, smooth_loss,
                                         hyp.forecast_smooth_coeff, summ_writer)

        def clamp_xyz(xyz, X, Y, Z):
            x, y, z = torch.unbind(xyz, dim=-1)
            x = x.clamp(0, X)
            y = y.clamp(0, Y)
            z = z.clamp(0, Z)
            xyz = torch.stack([x, y, z], dim=-1)
            return xyz
        # obj_xyzlist_mem is K x B x S x 3
        # xyzlist_mem is B x S x 3
        # sampled_trajs_mem is B x N x S x 3
        xyz_pos_ = xyzlist_mem.reshape(B * S, 1, 3)
        xyz_neg_ = sampled_trajs_mem.permute(0, 2,
1, 3).reshape(B * S, N, 3) # xyz_pos_ = clamp_xyz(xyz_pos_, X, Y, Z) # xyz_neg_ = clamp_xyz(xyz_neg_, X, Y, Z) xyz_ = torch.cat([xyz_pos_, xyz_neg_], dim=1) xyz_ = clamp_xyz(xyz_, X, Y, Z) cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X) x, y, z = torch.unbind(xyz_, dim=2) # x = x.clamp(0, X) # y = x.clamp(0, Y) # z = x.clamp(0, Z) cost_ = utils_samp.bilinear_sample3D(cost_vols_, x, y, z).squeeze(1) # cost is B*S x 1+N cost_pos = cost_[:, 0:1] # B*S x 1 cost_neg = cost_[:, 1:] # B*S x N cost_pos = cost_pos.unsqueeze(2) # B*S x 1 x 1 cost_neg = cost_neg.unsqueeze(1) # B*S x 1 x N utils_misc.add_loss('forecast/mean_cost_pos', 0, torch.mean(cost_pos), 0, summ_writer) utils_misc.add_loss('forecast/mean_cost_neg', 0, torch.mean(cost_neg), 0, summ_writer) utils_misc.add_loss('forecast/mean_margin', 0, torch.mean(cost_neg - cost_pos), 0, summ_writer) xyz_pos = xyz_pos_.unsqueeze(2) # B*S x 1 x 1 x 3 xyz_neg = xyz_neg_.unsqueeze(1) # B*S x 1 x N x 3 dist = torch.norm(xyz_pos - xyz_neg, dim=3) # B*S x 1 x N dist = dist / float( Z) * 5.0 # normalize for resolution, but upweight it a bit margin = F.relu(cost_pos - cost_neg + dist) margin = margin.reshape(B, S, N) # mean over time (in the paper this is a sum) margin = utils_basic.reduce_masked_mean(margin, scorelist.repeat(1, 1, N), dim=1) # max over the negatives maxmargin = torch.max(margin, dim=1)[0] # B maxmargin_loss = torch.mean(maxmargin) total_loss = utils_misc.add_loss('forecast/maxmargin_loss', total_loss, maxmargin_loss, hyp.forecast_maxmargin_coeff, summ_writer) cost_neg = cost_neg.reshape(B, S, N)[0].detach().cpu().numpy() sampled_trajs_mem = sampled_trajs_mem.reshape(B, N, S, 3)[0:1] cost_neg = np.reshape(cost_neg, [S, N]) cost_neg = np.sum(cost_neg, axis=0) inds = np.argsort(cost_neg, axis=0) for n in list(range(2)): xyzlist_e_mem = sampled_trajs_mem[0:1, inds[n]] xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X) # this is B x S x 3 # if summ_writer.save_this and n==0: # print('xyzlist_e_cam', xyzlist_e_cam[0:1]) # print('xyzlist_g_cam', xyzlist_cam[0:1]) # print('scorelist', scorelist[0:1]) dist = torch.norm(xyzlist_cam[0:1] - xyzlist_e_cam[0:1], dim=2) # this is B x S meandist = utils_basic.reduce_masked_mean( dist, scorelist[0:1].squeeze(2)) utils_misc.add_loss('forecast/xyz_dist_%d' % n, 0, meandist, 0, summ_writer) # dist = torch.mean(torch.sum(torch.norm(xyzlist_cam[0:1] - xyzlist_e_cam[0:1], dim=2), dim=1)) # mpe = torch.mean(torch.norm(xyzlist_cam[0:1,int(S/2)] - xyzlist_e_cam[0:1,int(S/2)], dim=1)) # mpe = utils_basic.reduce_masked_mean(dist, scorelist[0:1]) # utils_misc.add_loss('forecast/xyz_mpe_%d' % n, 0, dist, 0, summ_writer) # epe = torch.mean(torch.norm(xyzlist_cam[0:1,-1] - xyzlist_e_cam[0:1,-1], dim=1)) # utils_misc.add_loss('forecast/xyz_epe_%d' % n, 0, dist, 0, summ_writer) if summ_writer.save_this: # plot the best and worst trajs # print('sorted costs:', cost_neg[inds]) for n in list(range(2)): ind = inds[n] # print('plotting good traj with cost %.2f' % (cost_neg[ind])) xyzlist_e_mem = sampled_trajs_mem[:, ind] # this is 1 x S x 3 summ_writer.summ_traj_on_occ( 'forecast/best_sampled_traj%d' % n, xyzlist_e_mem, torch.max(occs[0:1], dim=1)[0], # torch.zeros([1, 1, Z, Y, X]).float().cuda(), already_mem=True, sigma=1) for n in list(range(2)): ind = inds[-(n + 1)] # print('plotting bad traj with cost %.2f' % (cost_neg[ind])) xyzlist_e_mem = sampled_trajs_mem[:, ind] # this is 1 x S x 3 summ_writer.summ_traj_on_occ( 'forecast/worst_sampled_traj%d' % n, xyzlist_e_mem, torch.max(occs[0:1], dim=1)[0], # 
torch.zeros([1, 1, Z, Y, X]).float().cuda(), already_mem=True, sigma=1) else: # use some timesteps as input feat_input = feats[:, :self.num_given].squeeze(2) # feat_input is B x self.num_given x ZZ x ZY x ZX ## regular bottle3D # vel_e = self.regressor(feat_input) ## sparse-invar bottle3D comp_mask = 1.0 - (feat_input == 0).all(dim=1, keepdim=True).float() summ_writer.summ_feat('forecast/feat_input', feat_input, pca=False) summ_writer.summ_feat('forecast/feat_comp_mask', comp_mask, pca=False) vel_e = self.regressor(feat_input, comp_mask) vel_e = vel_e.reshape(B, self.num_need, 3) vel_g = xyzlist_cam[:, self.num_given:] - xyzlist_cam[:, self.num_given - 1:-1] xyzlist_e = torch.zeros_like(xyzlist_cam) xyzlist_g = torch.zeros_like(xyzlist_cam) for s in list(range(S)): # print('s = %d' % s) if s < self.num_given: # print('grabbing from gt ind %s' % s) xyzlist_e[:, s] = xyzlist_cam[:, s] xyzlist_g[:, s] = xyzlist_cam[:, s] else: # print('grabbing from s-self.num_given, which is ind %d' % (s-self.num_given)) xyzlist_e[:, s] = xyzlist_e[:, s - 1] + vel_e[:, s - self.num_given] xyzlist_g[:, s] = xyzlist_g[:, s - 1] + vel_g[:, s - self.num_given] xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X) xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X) summ_writer.summ_traj_on_occ('forecast/traj_e', xyzlist_e_mem, torch.max(occs, dim=1)[0], already_mem=True, sigma=2) summ_writer.summ_traj_on_occ('forecast/traj_g', xyzlist_g_mem, torch.max(occs, dim=1)[0], already_mem=True, sigma=2) scorelist_here = scorelist[:, self.num_given:, 0] sql2 = torch.sum((vel_g - vel_e)**2, dim=2) ## yes weightmask weightmask = torch.arange(0, self.num_need, dtype=torch.float32, device=torch.device('cuda')) weightmask = torch.exp(-weightmask**(1. / 4)) # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860, # 0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353 weightmask = weightmask.reshape(1, self.num_need) l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here * weightmask) utils_misc.add_loss('forecast/l2_loss', 0, l2_loss, 0, summ_writer) # # no weightmask: # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here) # total_loss = utils_misc.add_loss('forecast/l2_loss', total_loss, l2_loss, hyp.forecast_l2_coeff, summ_writer) dist = torch.norm(xyzlist_e - xyzlist_g, dim=2) meandist = utils_basic.reduce_masked_mean(dist, scorelist[:, :, 0]) utils_misc.add_loss('forecast/xyz_dist_0', 0, meandist, 0, summ_writer) l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here) # utils_misc.add_loss('forecast/vel_dist_noexp', 0, l2_loss, 0, summ_writer) total_loss = utils_misc.add_loss('forecast/l2_loss_noexp', total_loss, l2_loss_noexp, hyp.forecast_l2_coeff, summ_writer) return total_loss
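# The per-timestep rollout loop above (each future position = previous position
# plus a predicted velocity, starting from the last given frame) is equivalent
# to a cumulative sum. A compact restatement with the same variable names, in
# case the loop form obscures that:
def rollout_from_vel_sketch(xyzlist_cam, vel_e, num_given):
    # xyzlist_cam: B x S x 3 (only the first num_given entries are used here)
    # vel_e: B x (S - num_given) x 3 predicted per-step displacements
    start = xyzlist_cam[:, num_given - 1:num_given]              # B x 1 x 3
    futu = start + torch.cumsum(vel_e, dim=1)                    # B x (S - num_given) x 3
    return torch.cat([xyzlist_cam[:, :num_given], futu], dim=1)  # B x S x 3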
def forward(self, clist_cam, energy_map, occ_mems, summ_writer): total_loss = torch.tensor(0.0).cuda() B, S, C, Z, Y, X = list(occ_mems.shape) B2, S, D = list(clist_cam.shape) assert (B == B2) traj_past = clist_cam[:, :self.T_past] traj_futu = clist_cam[:, self.T_past:] # just xz traj_past = torch.stack([traj_past[:, :, 0], traj_past[:, :, 2]], dim=2) # xz traj_futu = torch.stack([traj_futu[:, :, 0], traj_futu[:, :, 2]], dim=2) # xz feat = occ_mems[:, 0].permute(0, 1, 3, 2, 4).reshape(B, C * Y, Z, X) mask = 1.0 - (feat == 0).all(dim=1, keepdim=True).float().cuda() halfgrid = utils_basic.meshgrid2D(B, int(Z / 2), int(X / 2), stack=True, norm=True).permute(0, 3, 1, 2) feat_map, _ = self.compressor(feat, mask, halfgrid) pred_map = self.conv2d(feat_map) # these are B x C x Z x X K = 12 # number of samples traj_past = traj_past.unsqueeze(0).repeat(K, 1, 1, 1) feat_map = feat_map.unsqueeze(0).repeat(K, 1, 1, 1, 1) pred_map = pred_map.unsqueeze(0).repeat(K, 1, 1, 1, 1) # to sample the K trajectories in parallel, we'll pack K onto the batch dim __p = lambda x: utils_basic.pack_seqdim(x, K) __u = lambda x: utils_basic.unpack_seqdim(x, K) traj_past_ = __p(traj_past) feat_map_ = __p(feat_map) pred_map_ = __p(pred_map) base_sample_ = torch.randn(K * B, self.T_futu, 2).cuda() traj_futu_e_ = self.compute_forward_mapping(feat_map_, pred_map_, base_sample_, traj_past_) traj_futu_e = __u(traj_futu_e_) # this is K x B x T x 2 # print('traj_futu_e', traj_futu_e.shape, traj_futu_e[0,0]) if summ_writer.save_this: o = [] for k in list(range(K)): o.append( utils_improc.preprocess_color( summ_writer.summ_traj_on_occ( '', utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]), Z, Y, X), occ_mems[:, 0], already_mem=True, only_return=True))) summ_writer.summ_traj_on_occ( 'rponet/traj_futu_sample_%d' % k, utils_vox.Ref2Mem(self.add_fake_y(traj_futu_e[k]), Z, Y, X), occ_mems[:, 0], already_mem=True) mean_vis = torch.max(torch.stack(o, dim=0), dim=0)[0] summ_writer.summ_rgb('rponet/traj_futu_e_mean', mean_vis) summ_writer.summ_traj_on_occ('rponet/traj_futu_g', utils_vox.Ref2Mem( self.add_fake_y(traj_futu), Z, Y, X), occ_mems[:, 0], already_mem=True) # forward loss: neg logprob of GT samples under the model # reverse loss: neg logprob of estim samples under the (approx) GT (i.e., spatial prior) forward_loss, reverse_loss = self.compute_loss(feat_map[0], pred_map[0], traj_past[0], traj_futu, traj_futu_e, energy_map) total_loss = utils_misc.add_loss('rpo/forward_loss', total_loss, forward_loss, hyp.rpo2D_forward_coeff, summ_writer) total_loss = utils_misc.add_loss('rpo/reverse_loss', total_loss, reverse_loss, hyp.rpo2D_reverse_coeff, summ_writer) return total_loss
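# pack_seqdim / unpack_seqdim (used above to push the K trajectory samples
# through the net in one pass) just fold the leading dim into the batch dim and
# back. A minimal sketch of the assumed behavior:
def pack_seqdim_sketch(x, N):
    # x: N x M x ...  ->  (N*M) x ...
    return x.reshape(N * x.shape[1], *x.shape[2:])

def unpack_seqdim_sketch(x, N):
    # x: (N*M) x ...  ->  N x M x ...
    return x.reshape(N, x.shape[0] // N, *x.shape[2:])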
def forward(self, clist_cam, occs, summ_writer, vox_util, suffix=''):
    total_loss = torch.tensor(0.0).cuda()
    B, S, C, Z, Y, X = list(occs.shape)
    B2, S2, D = list(clist_cam.shape)
    assert (B == B2) and (S == S2)
    assert (D == 3)
    if summ_writer.save_this:
        summ_writer.summ_traj_on_occ('motioncost/actual_traj',
                                     clist_cam,
                                     occs[:, self.T_past],
                                     vox_util,
                                     sigma=2)
    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)
    # occs_ = occs.reshape(B*S, C, Z, Y, X)
    occs_ = __p(occs)
    feats_ = occs_.permute(0, 1, 3, 2, 4).reshape(B * S, C * Y, Z, X)
    masks_ = 1.0 - (feats_ == 0).all(dim=1, keepdim=True).float().cuda()
    halfgrids_ = utils_basic.meshgrid2D(B * S, int(Z / 2), int(X / 2), stack=True, norm=True).permute(0, 3, 1, 2)
    # feats_ = torch.cat([feats_, grids_], dim=1)
    feats = __u(feats_)
    masks = __u(masks_)
    halfgrids = __u(halfgrids_)
    input_feats = feats[:, :self.T_past]
    input_masks = masks[:, :self.T_past]
    input_halfgrids = halfgrids[:, :self.T_past]
    dense_feats_, _ = self.densifier(__p(input_feats), __p(input_masks), __p(input_halfgrids))
    dense_feats = __u(dense_feats_)
    super_feat = dense_feats.reshape(B, self.T_past * self.dense_dim, int(Z / 2), int(X / 2))
    cost_maps = self.motioncoster(super_feat)
    cost_maps = F.interpolate(cost_maps, scale_factor=4, mode='bilinear')
    # this is B x T_futu x Z x X
    cost_maps = cost_maps.clamp(-1000, 1000)  # raquel says this adds stability
    summ_writer.summ_histogram('motioncost/cost_maps_hist', cost_maps)
    summ_writer.summ_oneds('motioncost/cost_maps', torch.unbind(cost_maps.unsqueeze(2), dim=1))
    # next i need to sample some trajectories
    N = hyp.motioncost_num_negs
    sampled_trajs_cam = self.sample_trajs(N, clist_cam)
    # this is B x N x S x 3
    if summ_writer.save_this:
        # for n in list(range(np.min([N, 3]))):
        #     # this is 1 x S x 3
        #     summ_writer.summ_traj_on_occ('motioncost/sample%d_clist' % n,
        #                                  sampled_trajs_cam[0, n].unsqueeze(0),
        #                                  occs[:,self.T_past],
        #                                  # torch.max(occs, dim=1)[0],
        #                                  # torch.zeros([1, 1, Z, Y, X]).float().cuda(),
        #                                  already_mem=False)
        o = []
        for n in list(range(N)):
            o.append(
                utils_improc.preprocess_color(
                    summ_writer.summ_traj_on_occ('',
                                                 sampled_trajs_cam[0, n].unsqueeze(0),
                                                 occs[0:1, self.T_past],
                                                 vox_util,
                                                 only_return=True,
                                                 sigma=0.5)))
        summ_vis = torch.max(torch.stack(o, dim=0), dim=0)[0]
        summ_writer.summ_rgb('motioncost/all_sampled_trajs', summ_vis)
    # smooth loss
    cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X)
    dz, dx = gradient2D(cost_maps_, absolute=True)
    dt = torch.abs(cost_maps[:, 1:] - cost_maps[:, 0:-1])
    smooth_spatial = torch.mean(dx + dz, dim=1, keepdim=True)
    smooth_time = torch.mean(dt, dim=1, keepdim=True)
    summ_writer.summ_oned('motioncost/smooth_loss_spatial', smooth_spatial)
    summ_writer.summ_oned('motioncost/smooth_loss_time', smooth_time)
    smooth_loss = torch.mean(smooth_spatial) + torch.mean(smooth_time)
    total_loss = utils_misc.add_loss('motioncost/smooth_loss', total_loss, smooth_loss,
                                     hyp.motioncost_smooth_coeff, summ_writer)

    # def clamp_xyz(xyz, X, Y, Z):
    #     x, y, z = torch.unbind(xyz, dim=-1)
    #     x = x.clamp(0, X)
    #     y = x.clamp(0, Y)
    #     z = x.clamp(0, Z)
    #     # if zero_y:
    #     #     y = torch.zeros_like(y)
    #     xyz = torch.stack([x,y,z], dim=-1)
    #     return xyz
    def clamp_xz(xz, X, Z):
        x, z = torch.unbind(xz, dim=-1)
        x = x.clamp(0, X)
        z = z.clamp(0, Z)
        xz = torch.stack([x, z], dim=-1)
        return xz
    clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
    # this is B x S x 3
    # sampled_trajs_cam is B x N x S x 3
    sampled_trajs_cam_ = sampled_trajs_cam.reshape(B, N * S, 3)
    sampled_trajs_mem_ =
utils_vox.Ref2Mem(sampled_trajs_cam_, Z, Y, X) sampled_trajs_mem = sampled_trajs_mem_.reshape(B, N, S, 3) # this is B x N x S x 3 xyz_pos_ = clist_mem[:, self.T_past:].reshape(B * self.T_futu, 1, 3) xyz_neg_ = sampled_trajs_mem[:, :, self.T_past:].permute(0, 2, 1, 3).reshape( B * self.T_futu, N, 3) # get rid of y xz_pos_ = torch.stack([xyz_pos_[:, :, 0], xyz_pos_[:, :, 2]], dim=2) xz_neg_ = torch.stack([xyz_neg_[:, :, 0], xyz_neg_[:, :, 2]], dim=2) xz_ = torch.cat([xz_pos_, xz_neg_], dim=1) xz_ = clamp_xz(xz_, X, Z) cost_maps_ = cost_maps.reshape(B * self.T_futu, 1, Z, X) cost_ = utils_samp.bilinear_sample2D(cost_maps_, xz_[:, :, 0], xz_[:, :, 1]).squeeze(1) # cost is B*T_futu x 1+N cost_pos = cost_[:, 0:1] # B*T_futu x 1 cost_neg = cost_[:, 1:] # B*T_futu x N cost_pos = cost_pos.unsqueeze(2) # B*T_futu x 1 x 1 cost_neg = cost_neg.unsqueeze(1) # B*T_futu x 1 x N utils_misc.add_loss('motioncost/mean_cost_pos', 0, torch.mean(cost_pos), 0, summ_writer) utils_misc.add_loss('motioncost/mean_cost_neg', 0, torch.mean(cost_neg), 0, summ_writer) utils_misc.add_loss('motioncost/mean_margin', 0, torch.mean(cost_neg - cost_pos), 0, summ_writer) xz_pos = xz_pos_.unsqueeze(2) # B*T_futu x 1 x 1 x 3 xz_neg = xz_neg_.unsqueeze(1) # B*T_futu x 1 x N x 3 dist = torch.norm(xz_pos - xz_neg, dim=3) # B*T_futu x 1 x N dist = dist / float( Z) * 5.0 # normalize for resolution, but upweight it a bit margin = F.relu(cost_pos - cost_neg + dist) margin = margin.reshape(B, self.T_futu, N) # mean over time (in the paper this is a sum) margin = torch.mean(margin, dim=1) # max over the negatives maxmargin = torch.max(margin, dim=1)[0] # B maxmargin_loss = torch.mean(maxmargin) total_loss = utils_misc.add_loss('motioncost/maxmargin_loss', total_loss, maxmargin_loss, hyp.motioncost_maxmargin_coeff, summ_writer) # now let's see some top k # we'll do this for the first el of the batch cost_neg = cost_neg.reshape(B, self.T_futu, N)[0].detach().cpu().numpy() futu_mem = sampled_trajs_mem[:, :, self.T_past:].reshape( B, N, self.T_futu, 3)[0:1] cost_neg = np.reshape(cost_neg, [self.T_futu, N]) cost_neg = np.sum(cost_neg, axis=0) inds = np.argsort(cost_neg, axis=0) for n in list(range(2)): xyzlist_e_mem = futu_mem[0:1, inds[n]] xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X) # this is B x S x 3 if summ_writer.save_this and n == 0: print('xyzlist_e_cam', xyzlist_e_cam[0:1]) print('xyzlist_g_cam', clist_cam[0:1, self.T_past:]) dist = torch.norm(clist_cam[0:1, self.T_past:] - xyzlist_e_cam[0:1], dim=2) # this is B x T_futu meandist = torch.mean(dist) utils_misc.add_loss('motioncost/xyz_dist_%d' % n, 0, meandist, 0, summ_writer) if summ_writer.save_this: # plot the best and worst trajs # print('sorted costs:', cost_neg[inds]) for n in list(range(2)): ind = inds[n] print('plotting good traj with cost %.2f' % (cost_neg[ind])) xyzlist_e_mem = sampled_trajs_mem[:, ind] # this is 1 x S x 3 summ_writer.summ_traj_on_occ('motioncost/best_sampled_traj%d' % n, xyzlist_e_mem[0:1], occs[0:1, self.T_past], vox_util, already_mem=True, sigma=2) for n in list(range(2)): ind = inds[-(n + 1)] print('plotting bad traj with cost %.2f' % (cost_neg[ind])) xyzlist_e_mem = sampled_trajs_mem[:, ind] # this is 1 x S x 3 summ_writer.summ_traj_on_occ( 'motioncost/worst_sampled_traj%d' % n, xyzlist_e_mem[0:1], occs[0:1, self.T_past], vox_util, already_mem=True, sigma=2) # xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X) # xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X) # summ_writer.summ_traj_on_occ('motioncost/traj_e', # xyzlist_e_mem, # 
torch.max(occs, dim=1)[0], # already_mem=True, # sigma=2) # summ_writer.summ_traj_on_occ('motioncost/traj_g', # xyzlist_g_mem, # torch.max(occs, dim=1)[0], # already_mem=True, # sigma=2) # scorelist_here = scorelist[:,self.num_given:,0] # sql2 = torch.sum((vel_g-vel_e)**2, dim=2) # ## yes weightmask # weightmask = torch.arange(0, self.num_need, dtype=torch.float32, device=torch.device('cuda')) # weightmask = torch.exp(-weightmask**(1./4)) # # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860, # # 0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353 # weightmask = weightmask.reshape(1, self.num_need) # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here * weightmask) # utils_misc.add_loss('motioncost/l2_loss', 0, l2_loss, 0, summ_writer) # # # no weightmask: # # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here) # # total_loss = utils_misc.add_loss('motioncost/l2_loss', total_loss, l2_loss, hyp.motioncost_l2_coeff, summ_writer) # dist = torch.norm(xyzlist_e - xyzlist_g, dim=2) # meandist = utils_basic.reduce_masked_mean(dist, scorelist[:,:,0]) # utils_misc.add_loss('motioncost/xyz_dist_0', 0, meandist, 0, summ_writer) # l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here) # # utils_misc.add_loss('motioncost/vel_dist_noexp', 0, l2_loss, 0, summ_writer) # total_loss = utils_misc.add_loss('motioncost/l2_loss_noexp', total_loss, l2_loss_noexp, hyp.motioncost_l2_coeff, summ_writer) return total_loss
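# Both the forecast and motioncost branches use the same distance-scaled
# max-margin objective: the GT point should be cheaper than each sampled
# negative by a margin proportional to how far that negative is from the GT.
# A standalone sketch of that reduction (illustrative shapes/names; the real
# code also masks by scorelist and rescales the distance by resolution):
def maxmargin_traj_loss_sketch(cost_pos, cost_neg, p_pos, p_neg, B, T, N):
    # cost_pos: (B*T) x 1 x 1, cost_neg: (B*T) x 1 x N
    # p_pos: (B*T) x 1 x 1 x D, p_neg: (B*T) x 1 x N x D
    dist = torch.norm(p_pos - p_neg, dim=3)        # (B*T) x 1 x N
    margin = F.relu(cost_pos - cost_neg + dist)    # hinge per negative
    margin = margin.reshape(B, T, N).mean(dim=1)   # mean over time
    maxmargin = torch.max(margin, dim=1)[0]        # hardest negative per example
    return torch.mean(maxmargin)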
def forward(self, feed, moc_init_done=False, debug=False): summ_writer = utils_improc.Summ_writer( writer = feed['writer'], global_step = feed['global_step'], set_name= feed['set_name'], fps=8) writer = feed['writer'] global_step = feed['global_step'] total_loss = torch.tensor(0.0).cuda() ### ... All things sensor ... ### sensor_rgbs = feed['sensor_imgs'] sensor_depths = feed['sensor_depths'] center_sensor_H, center_sensor_W = sensor_depths[0][0].shape[-1] // 2, sensor_depths[0][0].shape[-2] // 2 ### ... All things sensor end ... ### # 1. Form the memory tensor using the feat net and visual images. # check what all do you need for this and create only those things ## .... Input images .... ## rgb_camRs = feed['rgb_camRs'] rgb_camXs = feed['rgb_camXs'] ## .... Input images end .... ## ## ... Hyperparams ... ## B, H, W, V, S = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S __p = lambda x: pack_seqdim(x, B) __u = lambda x: unpack_seqdim(x, B) PH, PW = hyp.PH, hyp.PW Z, Y, X = hyp.Z, hyp.Y, hyp.X Z2, Y2, X2 = int(Z/2), int(Y/2), int(X/2) ## ... Hyperparams end ... ## ## .... VISUAL TRANSFORMS BEGIN .... ## pix_T_cams = feed['pix_T_cams'] pix_T_cams_ = __p(pix_T_cams) origin_T_camRs = feed['origin_T_camRs'] origin_T_camRs_ = __p(origin_T_camRs) origin_T_camXs = feed['origin_T_camXs'] origin_T_camXs_ = __p(origin_T_camXs) camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse( origin_T_camRs_), origin_T_camXs_) camXs_T_camRs_ = utils_geom.safe_inverse(camRs_T_camXs_) camRs_T_camXs = __u(camRs_T_camXs_) camXs_T_camRs = __u(camXs_T_camRs_) pix_T_cams_ = utils_geom.pack_intrinsics(pix_T_cams_[:, 0, 0], pix_T_cams_[:, 1, 1], pix_T_cams_[:, 0, 2], pix_T_cams_[:, 1, 2]) pix_T_camRs_ = torch.matmul(pix_T_cams_, camXs_T_camRs_) pix_T_camRs = __u(pix_T_camRs_) ## ... VISUAL TRANSFORMS END ... ## ## ... SENSOR TRANSFORMS BEGIN ... ## sensor_origin_T_camXs = feed['sensor_extrinsics'] sensor_origin_T_camXs_ = __p(sensor_origin_T_camXs) sensor_origin_T_camRs = feed['sensor_origin_T_camRs'] sensor_origin_T_camRs_ = __p(sensor_origin_T_camRs) sensor_camRs_T_origin_ = utils_geom.safe_inverse(sensor_origin_T_camRs_) sensor_camRs_T_camXs_ = torch.matmul(utils_geom.safe_inverse( sensor_origin_T_camRs_), sensor_origin_T_camXs_) sensor_camXs_T_camRs_ = utils_geom.safe_inverse(sensor_camRs_T_camXs_) sensor_camRs_T_camXs = __u(sensor_camRs_T_camXs_) sensor_camXs_T_camRs = __u(sensor_camXs_T_camRs_) sensor_pix_T_cams = feed['sensor_intrinsics'] sensor_pix_T_cams_ = __p(sensor_pix_T_cams) sensor_pix_T_cams_ = utils_geom.pack_intrinsics(sensor_pix_T_cams_[:, 0, 0], sensor_pix_T_cams_[:, 1, 1], sensor_pix_T_cams_[:, 0, 2], sensor_pix_T_cams_[:, 1, 2]) sensor_pix_T_camRs_ = torch.matmul(sensor_pix_T_cams_, sensor_camXs_T_camRs_) sensor_pix_T_camRs = __u(sensor_pix_T_camRs_) ## .... SENSOR TRANSFORMS END .... ## ## .... Visual Input point clouds .... ## xyz_camXs = feed['xyz_camXs'] xyz_camXs_ = __p(xyz_camXs) xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, xyz_camXs_) # (40, 4, 4) (B*S, N, 3) xyz_camRs = __u(xyz_camRs_) assert all([torch.allclose(xyz_camR, inp_xyz_camR) for xyz_camR, inp_xyz_camR in zip( xyz_camRs, feed['xyz_camRs'] )]), "computation of xyz_camR here and those computed in input do not match" ## .... Visual Input point clouds end .... ## ## ... Sensor input point clouds ... 
## sensor_xyz_camXs = feed['sensor_xyz_camXs'] sensor_xyz_camXs_ = __p(sensor_xyz_camXs) sensor_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_xyz_camXs_) sensor_xyz_camRs = __u(sensor_xyz_camRs_) assert all([torch.allclose(sensor_xyz, inp_sensor_xyz) for sensor_xyz, inp_sensor_xyz in zip( sensor_xyz_camRs, feed['sensor_xyz_camRs'] )]), "the sensor_xyz_camRs computed in forward do not match those computed in input" ## ... visual occupancy computation voxelize the pointcloud from above ... ## occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X) occXs_ = utils_vox.voxelize_xyz(xyz_camXs_, Z, Y, X) occRs_half_ = utils_vox.voxelize_xyz(xyz_camRs_, Z2, Y2, X2) occXs_half_ = utils_vox.voxelize_xyz(xyz_camXs_, Z2, Y2, X2) ## ... visual occupancy computation end ... NOTE: no unpacking ## ## .. visual occupancy computation for sensor inputs .. ## sensor_occRs_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z, Y, X) sensor_occXs_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z, Y, X) sensor_occRs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camRs_, Z2, Y2, X2) sensor_occXs_half_ = utils_vox.voxelize_xyz(sensor_xyz_camXs_, Z2, Y2, X2) ## ... unproject rgb images ... ## unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_camRs_) unpXs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, pix_T_cams_) ## ... unproject rgb finish ... NOTE: no unpacking ## ## ... Make depth images ... ## depth_camXs_, valid_camXs_ = utils_geom.create_depth_image(pix_T_cams_, xyz_camXs_, H, W) dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, pix_T_cams_) dense_xyz_camRs_ = utils_geom.apply_4x4(camRs_T_camXs_, dense_xyz_camXs_) inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y, X).float() inbound_camXs_ = torch.reshape(inbound_camXs_, [B*S, 1, H, W]) valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_) ## ... Make depth images ... ## ## ... Make sensor depth images ... ## sensor_depth_camXs_, sensor_valid_camXs_ = utils_geom.create_depth_image(sensor_pix_T_cams_, sensor_xyz_camXs_, H, W) sensor_dense_xyz_camXs_ = utils_geom.depth2pointcloud(sensor_depth_camXs_, sensor_pix_T_cams_) sensor_dense_xyz_camRs_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_dense_xyz_camXs_) sensor_inbound_camXs_ = utils_vox.get_inbounds(sensor_dense_xyz_camRs_, Z, Y, X).float() sensor_inbound_camXs_ = torch.reshape(sensor_inbound_camXs_, [B*hyp.sensor_S, 1, H, W]) sensor_valid_camXs = __u(sensor_valid_camXs_) * __u(sensor_inbound_camXs_) ### .. Done making sensor depth images .. ## ### ... Sanity check ... Write to tensorboard ... ### summ_writer.summ_oneds('2D_inputs/depth_camXs', torch.unbind(__u(depth_camXs_), dim=1)) summ_writer.summ_oneds('2D_inputs/valid_camXs', torch.unbind(valid_camXs, dim=1)) summ_writer.summ_rgbs('2D_inputs/rgb_camXs', torch.unbind(rgb_camXs, dim=1)) summ_writer.summ_rgbs('2D_inputs/rgb_camRs', torch.unbind(rgb_camRs, dim=1)) summ_writer.summ_occs('3d_inputs/occXs', torch.unbind(__u(occXs_), dim=1), reduce_axes=[2]) summ_writer.summ_unps('3d_inputs/unpXs', torch.unbind(__u(unpXs_), dim=1),\ torch.unbind(__u(occXs_), dim=1)) # A different approach for viewing occRs of sensors sensor_occRs = __u(sensor_occRs_) vis_sensor_occRs = torch.max(sensor_occRs, dim=1, keepdim=True)[0] # summ_writer.summ_occs('3d_inputs/sensor_occXs', torch.unbind(__u(sensor_occXs_), dim=1), # reduce_axes=[2]) summ_writer.summ_occs('3d_inputs/sensor_occRs', torch.unbind(vis_sensor_occRs, dim=1), reduce_axes=[2]) ### ... code for visualizing sensor depths and sensor rgbs ... 
### # summ_writer.summ_oneds('2D_inputs/depths_sensor', torch.unbind(sensor_depths, dim=1)) # summ_writer.summ_rgbs('2D_inputs/rgbs_sensor', torch.unbind(sensor_rgbs, dim=1)) # summ_writer.summ_oneds('2D_inputs/validXs_sensor', torch.unbind(sensor_valid_camXs, dim=1)) if summ_writer.save_this: unpRs_ = utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, matmul2(pix_T_cams_, camXs_T_camRs_)) unpRs = __u(unpRs_) occRs_ = utils_vox.voxelize_xyz(xyz_camRs_, Z, Y, X) summ_writer.summ_occs('3d_inputs/occRs', torch.unbind(__u(occRs_), dim=1), reduce_axes=[2]) summ_writer.summ_unps('3d_inputs/unpRs', torch.unbind(unpRs, dim=1),\ torch.unbind(__u(occRs_), dim=1)) ### ... Sanity check ... Writing to tensoboard complete ... ### results = list() mask_ = None ### ... Visual featnet part .... ### if hyp.do_feat: featXs_input = torch.cat([__u(occXs_), __u(occXs_)*__u(unpXs_)], dim=2) # B, S, 4, H, W, D featXs_input_ = __p(featXs_input) freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), occXs_half_) freeXs = __u(freeXs_) visXs = torch.clamp(__u(occXs_half_) + freeXs, 0.0, 1.0) if type(mask_) != type(None): assert(list(mask_.shape)[2:5] == list(featXs_input.shape)[2:5]) featXs_, validXs_, _ = self.featnet(featXs_input_, summ_writer, mask=occXs_) # total_loss += feat_loss # Note no need of loss validXs, featXs = __u(validXs_), __u(featXs_) # unpacked into B, S, C, D, H, W # bring everything to ref_frame validRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, validXs) visRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, visXs) featRs = utils_vox.apply_4x4_to_voxs(camRs_T_camXs, featXs) # This is now in memory coordinates emb3D_e = torch.mean(featRs[:, 1:], dim=1) # context, or the features of the scene emb3D_g = featRs[:, 0] # this is to predict, basically I will pass emb3D_e as input and hope to predict emb3D_g vis3D_e = torch.max(validRs[:, 1:], dim=1)[0] * torch.max(visRs[:, 1:], dim=1)[0] vis3D_g = validRs[:, 0] * visRs[:, 0] #### ... I do not think I need this ... #### results = {} # # if hyp.do_eval_recall: # # results['emb3D_e'] = emb3D_e # # results['emb3D_g'] = emb3D_g # #### ... Check if you need the above summ_writer.summ_feats('3D_feats/featXs_input', torch.unbind(featXs_input, dim=1), pca=True) summ_writer.summ_feats('3D_feats/featXs_output', torch.unbind(featXs, dim=1), pca=True) summ_writer.summ_feats('3D_feats/featRs_output', torch.unbind(featRs, dim=1), pca=True) summ_writer.summ_feats('3D_feats/validRs', torch.unbind(validRs, dim=1), pca=False) summ_writer.summ_feat('3D_feats/vis3D_e', vis3D_e, pca=False) summ_writer.summ_feat('3D_feats/vis3D_g', vis3D_g, pca=False) # I need to aggregate the features and detach to prevent the backward pass on featnet featRs = torch.mean(featRs, dim=1) featRs = featRs.detach() # ... HERE I HAVE THE VISUAL FEATURE TENSOR ... WHICH IS MADE USING 5 EVENLY SPACED VIEWS # # FOR THE TOUCH PART, I HAVE THE OCC and THE AIM IS TO PREDICT FEATURES FROM THEM # if hyp.do_touch_feat: # 1. Pass all the sensor depth images through the backbone network input_sensor_depths = __p(sensor_depths) sensor_features_ = self.backbone_2D(input_sensor_depths) # should normalize these feature tensors sensor_features_ = l2_normalize(sensor_features_, dim=1) sensor_features = __u(sensor_features_) assert torch.allclose(torch.norm(sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\ "normalization has no effect on you huh." 
if hyp.do_eval_recall: results['sensor_features'] = sensor_features_ results['sensor_depths'] = input_sensor_depths results['object_img'] = rgb_camRs results['sensor_imgs'] = __p(sensor_rgbs) # if moco is used do the same procedure as above but with a different network # if hyp.do_moc or hyp.do_eval_recall: # 1. Pass all the sensor depth images through the key network key_input_sensor_depths = copy.deepcopy(__p(sensor_depths)) # bx1024x1x16x16->(2048x1x16x16) self.key_touch_featnet.eval() with torch.no_grad(): key_sensor_features_ = self.key_touch_featnet(key_input_sensor_depths) key_sensor_features_ = l2_normalize(key_sensor_features_, dim=1) key_sensor_features = __u(key_sensor_features_) assert torch.allclose(torch.norm(key_sensor_features_, dim=1), torch.Tensor([1.0]).cuda()),\ "normalization has no effect on you huh." # doing the same procedure for moco but with a different network end # # do you want to do metric learning voxel point based using visual features and sensor features if hyp.do_touch_embML and not hyp.do_touch_forward: # trial 1: I do not pass the above obtained features through some encoder decoder in 3d # So compute the location is ref_frame which the center of these depth images will occupy # at all of these locations I will sample the from the visual tensor. It forms the positive pairs # negatives are simply everything except the positive sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S)) sensor_depths_centers_x = sensor_depths_centers_x.cuda() sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S)) sensor_depths_centers_y = sensor_depths_centers_y.cuda() sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W] # Next use Pixels2Camera to unproject all of these together. # merge the batch and the sequence dimension sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1) # BxHxW as required by Pixels2Camera sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1) sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1) fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_) sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y, sensor_depths_centers_z, fx, fy, x0, y0) # finally use apply4x4 to get the locations in ref_cam sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_) # NOTE: convert them to memory coordinates, the name is xyz so I presume it returns xyz but talk to ADAM sensor_depths_centers_in_mem_ = utils_vox.Ref2Mem(sensor_depths_centers_in_ref_cam_, Z2, Y2, X2) sensor_depths_centers_in_mem = sensor_depths_centers_in_mem_.reshape(hyp.B, hyp.sensor_S, -1) if debug: print('assert that you are not entering here') from IPython import embed; embed() # form a (0, 1) volume here at these locations and see if it resembles a cup dim1 = X2 * Y2 * Z2 dim2 = X2 * Y2 dim3 = X2 binary_voxel_grid = torch.zeros((hyp.B, X2, Y2, Z2)) # NOTE: Z is the leading dimension rounded_idxs = torch.round(sensor_depths_centers_in_mem) flat_idxs = dim2 * rounded_idxs[0, :, 0] + dim3 * rounded_idxs[0, :, 1] + rounded_idxs[0, :, 2] flat_idxs1 = dim2 * rounded_idxs[1, :, 0] + dim3 * rounded_idxs[1, :, 1] + rounded_idxs[1, :, 2] flat_idxs1 = flat_idxs1 + dim1 flat_idxs1 = flat_idxs1.long() flat_idxs = flat_idxs.long() flattened_grid = binary_voxel_grid.flatten() flattened_grid[flat_idxs] = 1. flattened_grid[flat_idxs1] = 1. 
binary_voxel_grid = flattened_grid.view(B, X2, Y2, Z2) assert binary_voxel_grid[0].sum() == len(torch.unique(flat_idxs)), "some indexes are missed here" assert binary_voxel_grid[1].sum() == len(torch.unique(flat_idxs1)), "some indexes are missed here" # o3d.io.write_voxel_grid("forward_pass_save/grid0.ply", binary_voxel_grid[0]) # o3d.io.write_voxel_grid("forward_pass_save/grid1.ply", binary_voxel_grid[0]) # need to save these voxels save_voxel(binary_voxel_grid[0].cpu().numpy(), "forward_pass_save/grid0.binvox") save_voxel(binary_voxel_grid[1].cpu().numpy(), "forward_pass_save/grid1.binvox") from IPython import embed; embed() # use grid sample to get the visual touch tensor at these locations, NOTE: visual tensor features shape is (B, C, N) visual_tensor_features = utils_samp.bilinear_sample3D(featRs, sensor_depths_centers_in_mem[:, :, 0], sensor_depths_centers_in_mem[:, :, 1], sensor_depths_centers_in_mem[:, :, 2]) visual_feature_tensor = visual_tensor_features.permute(0, 2, 1) # pack it visual_feature_tensor_ = __p(visual_feature_tensor) C = list(visual_feature_tensor.shape)[-1] print('C=', C) # do the metric learning this is the same as before. # the code is basically copied from embnet3d.py but some changes are being made very minor emb_vec = torch.stack((sensor_features_, visual_feature_tensor_), dim=1).view(B*self.num_samples*self.batch_k, C) y = torch.stack([torch.range(0,self.num_samples*B-1), torch.range(0,self.num_samples*B-1)], dim=1).view(self.num_samples*B*self.batch_k) a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec) # I need to write my own version of margin loss since the negatives and anchors may not be same dim d_ap = torch.sqrt(torch.sum((positives - anchors)**2, dim=1) + 1e-8) pos_loss = torch.clamp(d_ap - beta + self._margin, min=0.0) # TODO: expand the dims of anchors and tile them and compute the negative loss # do the pair count where you average by contributors only # this is your total loss # Further idea is to check what volumetric locations do each of the depth images corresponds to # unproject the entire depth image and convert to ref. and then sample. if hyp.do_touch_forward: ## ... Begin code for getting crops from visual memory ... ## sensor_depths_centers_x = center_sensor_W * torch.ones((hyp.B, hyp.sensor_S)) sensor_depths_centers_x = sensor_depths_centers_x.cuda() sensor_depths_centers_y = center_sensor_H * torch.ones((hyp.B, hyp.sensor_S)) sensor_depths_centers_y = sensor_depths_centers_y.cuda() sensor_depths_centers_z = sensor_depths[:, :, 0, center_sensor_H, center_sensor_W] # Next use Pixels2Camera to unproject all of these together. 
# merge the batch and the sequence dimension sensor_depths_centers_x = sensor_depths_centers_x.reshape(-1, 1, 1) sensor_depths_centers_y = sensor_depths_centers_y.reshape(-1, 1, 1) sensor_depths_centers_z = sensor_depths_centers_z.reshape(-1, 1, 1) fx, fy, x0, y0 = utils_geom.split_intrinsics(sensor_pix_T_cams_) sensor_depths_centers_in_camXs_ = utils_geom.Pixels2Camera(sensor_depths_centers_x, sensor_depths_centers_y, sensor_depths_centers_z, fx, fy, x0, y0) sensor_depths_centers_in_world_ = utils_geom.apply_4x4(sensor_origin_T_camXs_, sensor_depths_centers_in_camXs_) # not used by the algorithm ## this will be later used for visualization hence saving it here for now sensor_depths_centers_in_ref_cam_ = utils_geom.apply_4x4(sensor_camRs_T_camXs_, sensor_depths_centers_in_camXs_) # not used by the algorithm sensor_depths_centers_in_camXs = __u(sensor_depths_centers_in_camXs_).squeeze(2) # There has to be a better way to do this, for each of the cameras in the batch I want a box of size (ch, cw, cd) # TODO: rotation is the deviation of the box from the axis aligned do I want this tB, tN, _ = list(sensor_depths_centers_in_camXs.shape) # 2, 512, _ boxlist = torch.zeros(tB, tN, 9) # 2, 512, 9 boxlist[:, :, :3] = sensor_depths_centers_in_camXs # this lies on the object boxlist[:, :, 3:6] = torch.FloatTensor([hyp.contextW, hyp.contextH, hyp.contextD]) # convert the boxlist to lrtlist and to cuda # the rt here transforms the from box coordinates to camera coordinates box_lrtlist = utils_geom.convert_boxlist_to_lrtlist(boxlist) # Now I will use crop_zoom_from_mem functionality to get the features in each of the boxes # I will do it for each of the box separately as required by the api context_grid_list = list() for m in range(box_lrtlist.shape[1]): curr_box = box_lrtlist[:, m, :] context_grid = utils_vox.crop_zoom_from_mem(featRs, curr_box, 8, 8, 8, sensor_camRs_T_camXs[:, m, :, :]) context_grid_list.append(context_grid) context_grid_list = torch.stack(context_grid_list, dim=1) context_grid_list_ = __p(context_grid_list) ## ... till here I believe I have not introduced any randomness, so the points are still in ## ... End code for getting crops around this center of certain height, width and depth ... ## ## ... Begin code for passing the context grid through 3D CNN to obtain a vector ... ## sensor_cam_locs = feed['sensor_locs'] # these are in origin coordinates sensor_cam_quats = feed['sensor_quats'] # this too in in world_coordinates sensor_cam_locs_ = __p(sensor_cam_locs) sensor_cam_quats_ = __p(sensor_cam_quats) sensor_cam_locs_in_R_ = utils_geom.apply_4x4(sensor_camRs_T_origin_, sensor_cam_locs_.unsqueeze(1)).squeeze(1) # TODO TODO TODO confirm that this is right? 
TODO TODO TODO get_r_mat = lambda cam_quat: transformations.quaternion_matrix_py(cam_quat) rot_mat_Xs_ = torch.from_numpy(np.stack(list(map(get_r_mat, sensor_cam_quats_.cpu().numpy())))).to(sensor_cam_locs_.device).float() rot_mat_Rs_ = torch.bmm(sensor_camRs_T_origin_, rot_mat_Xs_) get_quat = lambda r_mat: transformations.quaternion_from_matrix_py(r_mat) sensor_quats_in_R_ = torch.from_numpy(np.stack(list(map(get_quat, rot_mat_Rs_.cpu().numpy())))).to(sensor_cam_locs_.device).float() pred_features_ = self.context_net(context_grid_list_,\ sensor_cam_locs_in_R_, sensor_quats_in_R_) # normalize pred_features_ = l2_normalize(pred_features_, dim=1) pred_features = __u(pred_features_) # if doing moco I have to pass the inputs through the key(slow) network as well # if hyp.do_moc or hyp.do_eval_recall: key_context_grid_list_ = copy.deepcopy(context_grid_list_) key_sensor_cam_locs_in_R_ = copy.deepcopy(sensor_cam_locs_in_R_) key_sensor_quats_in_R_ = copy.deepcopy(sensor_quats_in_R_) self.key_context_net.eval() with torch.no_grad(): key_pred_features_ = self.key_context_net(key_context_grid_list_,\ key_sensor_cam_locs_in_R_, key_sensor_quats_in_R_) # normalize, normalization is very important why though key_pred_features_ = l2_normalize(key_pred_features_, dim=1) key_pred_features = __u(key_pred_features_) # end passing of the input through the slow network this is necessary for moco # ## ... End code for passing the context grid through 3D CNN to obtain a vector ... ## ## ... Begin code for doing metric learning between pred_features and sensor features ... ## # 1. Subsample both based on the number of positive samples if hyp.do_touch_embML: assert(hyp.do_touch_forward) assert(hyp.do_touch_feat) perm = torch.randperm(len(pred_features_)) ## 1024 chosen_sensor_feats_ = sensor_features_[perm[:self.num_pos_samples*hyp.B]] chosen_pred_feats_ = pred_features_[perm[:self.num_pos_samples*B]] # 2. form the emb_vec and get pos and negative samples for the batch emb_vec = torch.stack((chosen_sensor_feats_, chosen_pred_feats_), dim=1).view(hyp.B*self.num_pos_samples*self.batch_k, -1) y = torch.stack([torch.range(0, self.num_pos_samples*B-1), torch.range(0, self.num_pos_samples*B-1)],\ dim=1).view(B*self.num_pos_samples*self.batch_k) # (0, 0, 1, 1, ..., 255, 255) a_indices, anchors, positives, negatives, _ = self.sampler(emb_vec) # 3. Compute the loss, ML loss and the l2 distance betwee the embeddings margin_loss, _ = self.criterion(anchors, positives, negatives, self.beta, y[a_indices]) total_loss = utils_misc.add_loss('embtouch/emb_touch_ml_loss', total_loss, margin_loss, hyp.emb_3D_ml_coeff, summ_writer) # the l2 loss between the embeddings l2_loss = torch.nn.functional.mse_loss(chosen_sensor_feats_, chosen_pred_feats_) total_loss = utils_misc.add_loss('embtouch/emb_l2_loss', total_loss, l2_loss, hyp.emb_3D_l2_coeff, summ_writer) ## ... End code for doing metric learning between pred_features and sensor_features ... ## ## ... Begin code for doing moc inspired ML between pred_features and sensor_features ... ## if hyp.do_moc and moc_init_done: moc_loss = self.moc_ml_net(sensor_features_, key_sensor_features_,\ pred_features_, key_pred_features_, summ_writer) total_loss += moc_loss ## ... End code for doing moc inspired ML between pred_features and sensor_feature ... ## ## ... add code for filling up results needed for eval recall ... 
## if hyp.do_eval_recall and moc_init_done: results['context_features'] = pred_features_ results['sensor_depth_centers_in_world'] = sensor_depths_centers_in_world_ results['sensor_depths_centers_in_ref_cam'] = sensor_depths_centers_in_ref_cam_ results['object_name'] = feed['object_name'] # I will do precision recall here at different recall values and summarize it using tensorboard recalls = [1, 5, 10, 50, 100, 200] # also should not include any gradients because of this # fast_sensor_emb_e = sensor_features_ # fast_context_emb_e = pred_features_ # slow_sensor_emb_g = key_sensor_features_ # slow_context_emb_g = key_context_features_ fast_sensor_emb_e = sensor_features_.clone().detach() fast_context_emb_e = pred_features_.clone().detach() # I will do multiple eval recalls here slow_sensor_emb_g = key_sensor_features_.clone().detach() slow_context_emb_g = key_pred_features_.clone().detach() # assuming the above thing goes well fast_sensor_emb_e = fast_sensor_emb_e.cpu().numpy() fast_context_emb_e = fast_context_emb_e.cpu().numpy() slow_sensor_emb_g = slow_sensor_emb_g.cpu().numpy() slow_context_emb_g = slow_context_emb_g.cpu().numpy() # now also move the vis to numpy and plot it using matplotlib vis_e = __p(sensor_rgbs) vis_g = __p(sensor_rgbs) np_vis_e = vis_e.cpu().detach().numpy() np_vis_e = np.transpose(np_vis_e, [0, 2, 3, 1]) np_vis_g = vis_g.cpu().detach().numpy() np_vis_g = np.transpose(np_vis_g, [0, 2, 3, 1]) # bring it back to original color np_vis_g = ((np_vis_g+0.5) * 255).astype(np.uint8) np_vis_e = ((np_vis_e+0.5) * 255).astype(np.uint8) # now compare fast_sensor_emb_e with slow_context_emb_g # since I am doing positive against this fast_sensor_emb_e_list = [fast_sensor_emb_e, np_vis_e] slow_context_emb_g_list = [slow_context_emb_g, np_vis_g] prec, vis, chosen_inds_and_neighbors_inds = compute_precision( fast_sensor_emb_e_list, slow_context_emb_g_list, recalls=recalls ) # finally plot the nearest neighbour retrieval and move ahead if feed['global_step'] % 1 == 0: plot_nearest_neighbours(vis, step=feed['global_step'], save_dir='/home/gauravp/eval_results', name='fast_sensor_slow_context') # plot the precisions at different recalls for pr, re in enumerate(recalls): summ_writer.summ_scalar(f'evrefast_sensor_slow_context/recall@{re}',\ prec[pr]) # now compare fast_context_emb_e with slow_sensor_emb_g fast_context_emb_e_list = [fast_context_emb_e, np_vis_e] slow_sensor_emb_g_list = [slow_sensor_emb_g, np_vis_g] prec, vis, chosen_inds_and_neighbors_inds = compute_precision( fast_context_emb_e_list, slow_sensor_emb_g_list, recalls=recalls ) if feed['global_step'] % 1 == 0: plot_nearest_neighbours(vis, step=feed['global_step'], save_dir='/home/gauravp/eval_results', name='fast_context_slow_sensor') # plot the precisions at different recalls for pr, re in enumerate(recalls): summ_writer.summ_scalar(f'evrefast_context_slow_sensor/recall@{re}',\ prec[pr]) # now finally compare both the fast, I presume we want them to go closer too fast_sensor_list = [fast_sensor_emb_e, np_vis_e] fast_context_list = [fast_context_emb_e, np_vis_g] prec, vis, chosen_inds_and_neighbors_inds = compute_precision( fast_sensor_list, fast_context_list, recalls=recalls ) if feed['global_step'] % 1 == 0: plot_nearest_neighbours(vis, step=feed['global_step'], save_dir='/home/gauravp/eval_results', name='fast_sensor_fast_context') for pr, re in enumerate(recalls): summ_writer.summ_scalar(f'evrefast_sensor_fast_context/recall@{re}',\ prec[pr]) ## ... done code for filling up results needed for eval recall ... 
## summ_writer.summ_scalar('loss', total_loss.cpu().item()) return total_loss, results, [key_sensor_features_, key_pred_features_]
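# compute_precision above reports retrieval precision at several recall values
# between the fast embeddings and the slow (key) embeddings. A minimal sketch
# of that metric under the assumption that query i's correct match is gallery
# item i; the repo helper's actual interface (which also carries images along
# for the nearest-neighbour plots) differs.
def retrieval_precision_sketch(query, gallery, recalls=(1, 5, 10)):
    # query, gallery: N x C numpy arrays of L2-normalized embeddings
    sims = query @ gallery.T                         # N x N cosine similarities
    ranks = np.argsort(-sims, axis=1)                # best match first
    hits = ranks == np.arange(len(query))[:, None]   # True where the match is self
    return [float(hits[:, :k].any(axis=1).mean()) for k in recalls]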
def forward(self, feed): results = dict() if 'log_freq' not in feed.keys(): feed['log_freq'] = None start_time = time.time() summ_writer = utils_improc.Summ_writer(writer=feed['writer'], global_step=feed['global_step'], set_name=feed['set_name'], log_freq=feed['log_freq'], fps=8) writer = feed['writer'] global_step = feed['global_step'] total_loss = torch.tensor(0.0).cuda() __p = lambda x: utils_basic.pack_seqdim(x, B) __u = lambda x: utils_basic.unpack_seqdim(x, B) __pb = lambda x: utils_basic.pack_boxdim(x, hyp.N) __ub = lambda x: utils_basic.unpack_boxdim(x, hyp.N) if hyp.aug_object_ent_dis: __pb_a = lambda x: utils_basic.pack_boxdim( x, hyp.max_obj_aug + hyp.max_obj_aug_dis) __ub_a = lambda x: utils_basic.unpack_boxdim( x, hyp.max_obj_aug + hyp.max_obj_aug_dis) else: __pb_a = lambda x: utils_basic.pack_boxdim(x, hyp.max_obj_aug) __ub_a = lambda x: utils_basic.unpack_boxdim(x, hyp.max_obj_aug) B, H, W, V, S, N = hyp.B, hyp.H, hyp.W, hyp.V, hyp.S, hyp.N PH, PW = hyp.PH, hyp.PW K = hyp.K BOX_SIZE = hyp.BOX_SIZE Z, Y, X = hyp.Z, hyp.Y, hyp.X Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2) Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4) D = 9 tids = torch.from_numpy(np.reshape(np.arange(B * N), [B, N])) rgb_camXs = feed["rgb_camXs_raw"] pix_T_cams = feed["pix_T_cams_raw"] camRs_T_origin = feed["camR_T_origin_raw"] origin_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_origin))) origin_T_camXs = feed["origin_T_camXs_raw"] camX0_T_camXs = utils_geom.get_camM_T_camXs(origin_T_camXs, ind=0) camRs_T_camXs = __u( torch.matmul(utils_geom.safe_inverse(__p(origin_T_camRs)), __p(origin_T_camXs))) camXs_T_camRs = __u(utils_geom.safe_inverse(__p(camRs_T_camXs))) camX0_T_camRs = camXs_T_camRs[:, 0] camX1_T_camRs = camXs_T_camRs[:, 1] camR_T_camX0 = utils_geom.safe_inverse(camX0_T_camRs) xyz_camXs = feed["xyz_camXs_raw"] depth_camXs_, valid_camXs_ = utils_geom.create_depth_image( __p(pix_T_cams), __p(xyz_camXs), H, W) dense_xyz_camXs_ = utils_geom.depth2pointcloud(depth_camXs_, __p(pix_T_cams)) xyz_camRs = __u( utils_geom.apply_4x4(__p(camRs_T_camXs), __p(xyz_camXs))) xyz_camX0s = __u( utils_geom.apply_4x4(__p(camX0_T_camXs), __p(xyz_camXs))) occXs = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z, Y, X)) occXs_to_Rs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, occXs) occXs_to_Rs_45 = cross_corr.rotate_tensor_along_y_axis(occXs_to_Rs, 45) occXs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camXs), Z2, Y2, X2)) occRs_half = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z2, Y2, X2)) occX0s_half = __u(utils_vox.voxelize_xyz(__p(xyz_camX0s), Z2, Y2, X2)) unpXs = __u( utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z, Y, X, __p(pix_T_cams))) unpXs_half = __u( utils_vox.unproject_rgb_to_mem(__p(rgb_camXs), Z2, Y2, X2, __p(pix_T_cams))) unpX0s_half = __u( utils_vox.unproject_rgb_to_mem( __p(rgb_camXs), Z2, Y2, X2, utils_basic.matmul2( __p(pix_T_cams), utils_geom.safe_inverse(__p(camX0_T_camXs))))) unpRs = __u( utils_vox.unproject_rgb_to_mem( __p(rgb_camXs), Z, Y, X, utils_basic.matmul2( __p(pix_T_cams), utils_geom.safe_inverse(__p(camRs_T_camXs))))) unpRs_half = __u( utils_vox.unproject_rgb_to_mem( __p(rgb_camXs), Z2, Y2, X2, utils_basic.matmul2( __p(pix_T_cams), utils_geom.safe_inverse(__p(camRs_T_camXs))))) dense_xyz_camRs_ = utils_geom.apply_4x4(__p(camRs_T_camXs), dense_xyz_camXs_) inbound_camXs_ = utils_vox.get_inbounds(dense_xyz_camRs_, Z, Y, X).float() inbound_camXs_ = torch.reshape(inbound_camXs_, [B * S, 1, H, W]) depth_camXs = __u(depth_camXs_) valid_camXs = __u(valid_camXs_) * __u(inbound_camXs_) 
summ_writer.summ_oneds('2D_inputs/depth_camXs', torch.unbind(depth_camXs, dim=1), maxdepth=21.0) summ_writer.summ_oneds('2D_inputs/valid_camXs', torch.unbind(valid_camXs, dim=1)) summ_writer.summ_rgbs('2D_inputs/rgb_camXs', torch.unbind(rgb_camXs, dim=1)) summ_writer.summ_occs('3D_inputs/occXs', torch.unbind(occXs, dim=1)) summ_writer.summ_unps('3D_inputs/unpXs', torch.unbind(unpXs, dim=1), torch.unbind(occXs, dim=1)) occRs = __u(utils_vox.voxelize_xyz(__p(xyz_camRs), Z, Y, X)) if hyp.do_eval_boxes: if hyp.dataset_name == "clevr_vqa": gt_boxes_origin_corners = feed['gt_box'] gt_scores_origin = feed['gt_scores'].detach().cpu().numpy() classes = feed['classes'] scores = gt_scores_origin tree_seq_filename = feed['tree_seq_filename'] gt_boxes_origin = nlu.get_ends_of_corner( gt_boxes_origin_corners) gt_boxes_origin_end = torch.reshape(gt_boxes_origin, [hyp.B, hyp.N, 2, 3]) gt_boxes_origin_theta = nlu.get_alignedboxes2thetaformat( gt_boxes_origin_end) gt_boxes_origin_corners = utils_geom.transform_boxes_to_corners( gt_boxes_origin_theta) gt_boxesR_corners = __ub( utils_geom.apply_4x4(camRs_T_origin[:, 0], __pb(gt_boxes_origin_corners))) gt_boxesR_theta = utils_geom.transform_corners_to_boxes( gt_boxesR_corners) gt_boxesR_end = nlu.get_ends_of_corner(gt_boxesR_corners) else: tree_seq_filename = feed['tree_seq_filename'] tree_filenames = [ join(hyp.root_dataset, i) for i in tree_seq_filename if i != "invalid_tree" ] invalid_tree_filenames = [ join(hyp.root_dataset, i) for i in tree_seq_filename if i == "invalid_tree" ] num_empty = len(invalid_tree_filenames) trees = [pickle.load(open(i, "rb")) for i in tree_filenames] len_valid = len(trees) if len_valid > 0: gt_boxesR, scores, classes = nlu.trees_rearrange(trees) if num_empty > 0: gt_boxesR = np.concatenate([ gt_boxesR, empty_gt_boxesR ]) if len_valid > 0 else empty_gt_boxesR scores = np.concatenate([ scores, empty_scores ]) if len_valid > 0 else empty_scores classes = np.concatenate([ classes, empty_classes ]) if len_valid > 0 else empty_classes gt_boxesR = torch.from_numpy( gt_boxesR).cuda().float() # torch.Size([2, 3, 6]) gt_boxesR_end = torch.reshape(gt_boxesR, [hyp.B, hyp.N, 2, 3]) gt_boxesR_theta = nlu.get_alignedboxes2thetaformat( gt_boxesR_end) #torch.Size([2, 3, 9]) gt_boxesR_corners = utils_geom.transform_boxes_to_corners( gt_boxesR_theta) class_names_ex_1 = "_".join(classes[0]) summ_writer.summ_text('eval_boxes/class_names', class_names_ex_1) gt_boxesRMem_corners = __ub( utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z2, Y2, X2)) gt_boxesRMem_end = nlu.get_ends_of_corner(gt_boxesRMem_corners) gt_boxesRMem_theta = utils_geom.transform_corners_to_boxes( gt_boxesRMem_corners) gt_boxesRUnp_corners = __ub( utils_vox.Ref2Mem(__pb(gt_boxesR_corners), Z, Y, X)) gt_boxesRUnp_end = nlu.get_ends_of_corner(gt_boxesRUnp_corners) gt_boxesX0_corners = __ub( utils_geom.apply_4x4(camX0_T_camRs, __pb(gt_boxesR_corners))) gt_boxesX0Mem_corners = __ub( utils_vox.Ref2Mem(__pb(gt_boxesX0_corners), Z2, Y2, X2)) gt_boxesX0Mem_theta = utils_geom.transform_corners_to_boxes( gt_boxesX0Mem_corners) gt_boxesX0Mem_end = nlu.get_ends_of_corner(gt_boxesX0Mem_corners) gt_boxesX0_end = nlu.get_ends_of_corner(gt_boxesX0_corners) gt_cornersX0_pix = __ub( utils_geom.apply_pix_T_cam(pix_T_cams[:, 0], __pb(gt_boxesX0_corners))) rgb_camX0 = rgb_camXs[:, 0] rgb_camX1 = rgb_camXs[:, 1] summ_writer.summ_box_by_corners('eval_boxes/gt_boxescamX0', rgb_camX0, gt_boxesX0_corners, torch.from_numpy(scores), tids, pix_T_cams[:, 0]) unps_vis = utils_improc.get_unps_vis(unpX0s_half, 
occX0s_half) unp_vis = torch.mean(unps_vis, dim=1) unps_visRs = utils_improc.get_unps_vis(unpRs_half, occRs_half) unp_visRs = torch.mean(unps_visRs, dim=1) unps_visRs_full = utils_improc.get_unps_vis(unpRs, occRs) unp_visRs_full = torch.mean(unps_visRs_full, dim=1) summ_writer.summ_box_mem_on_unp('eval_boxes/gt_boxesR_mem', unp_visRs, gt_boxesRMem_end, scores, tids) unpX0s_half = torch.mean(unpX0s_half, dim=1) unpX0s_half = nlu.zero_out(unpX0s_half, gt_boxesX0Mem_end, scores) occX0s_half = torch.mean(occX0s_half, dim=1) occX0s_half = nlu.zero_out(occX0s_half, gt_boxesX0Mem_end, scores) summ_writer.summ_unp('3D_inputs/unpX0s', unpX0s_half, occX0s_half) if hyp.do_feat: featXs_input = torch.cat([occXs, occXs * unpXs], dim=2) featXs_input_ = __p(featXs_input) freeXs_ = utils_vox.get_freespace(__p(xyz_camXs), __p(occXs_half)) freeXs = __u(freeXs_) visXs = torch.clamp(occXs_half + freeXs, 0.0, 1.0) mask_ = None if (type(mask_) != type(None)): assert (list(mask_.shape)[2:5] == list( featXs_input_.shape)[2:5]) featXs_, feat_loss = self.featnet(featXs_input_, summ_writer, mask=__p(occXs)) #mask_) total_loss += feat_loss validXs = torch.ones_like(visXs) _validX00 = validXs[:, 0:1] _validX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:], validXs[:, 1:]) validX0s = torch.cat([_validX00, _validX01], dim=1) validRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, validXs) visRs = utils_vox.apply_4x4s_to_voxs(camRs_T_camXs, visXs) featXs = __u(featXs_) _featX00 = featXs[:, 0:1] _featX01 = utils_vox.apply_4x4s_to_voxs(camX0_T_camXs[:, 1:], featXs[:, 1:]) featX0s = torch.cat([_featX00, _featX01], dim=1) emb3D_e = torch.mean(featX0s[:, 1:], dim=1) vis3D_e_R = torch.max(visRs[:, 1:], dim=1)[0] emb3D_g = featX0s[:, 0] vis3D_g_R = visRs[:, 0] validR_combo = torch.min(validRs, dim=1).values summ_writer.summ_feats('3D_feats/featXs_input', torch.unbind(featXs_input, dim=1), pca=True) summ_writer.summ_feats('3D_feats/featXs_output', torch.unbind(featXs, dim=1), valids=torch.unbind(validXs, dim=1), pca=True) summ_writer.summ_feats('3D_feats/featX0s_output', torch.unbind(featX0s, dim=1), valids=torch.unbind( torch.ones_like(validRs), dim=1), pca=True) summ_writer.summ_feats('3D_feats/validRs', torch.unbind(validRs, dim=1), pca=False) summ_writer.summ_feat('3D_feats/vis3D_e_R', vis3D_e_R, pca=False) summ_writer.summ_feat('3D_feats/vis3D_g_R', vis3D_g_R, pca=False) if hyp.do_munit: object_classes, filenames = nlu.create_object_classes( classes, [tree_seq_filename, tree_seq_filename], scores) if hyp.do_munit_fewshot: emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e) emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g) emb3D_R = emb3D_e_R emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors( [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end, scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE]) emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2 content, style = self.munitnet.net.gen_a.encode(emb3D_R_object) objects_taken, _ = self.munitnet.net.gen_a.decode( content, style) styles = style contents = content elif hyp.do_3d_style_munit: emb3D_e_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_e) emb3D_g_R = utils_vox.apply_4x4_to_vox(camR_T_camX0, emb3D_g) emb3D_R = emb3D_e_R # st() emb3D_e_R_object, emb3D_g_R_object, validR_combo_object = nlu.create_object_tensors( [emb3D_e_R, emb3D_g_R], [validR_combo], gt_boxesRMem_end, scores, [BOX_SIZE, BOX_SIZE, BOX_SIZE]) emb3D_R_object = (emb3D_e_R_object + emb3D_g_R_object) / 2 camX1_T_R = camXs_T_camRs[:, 1] camX0_T_R = 
camXs_T_camRs[:, 0] assert hyp.B == 2 assert emb3D_e_R_object.shape[0] == 2 munit_loss, sudo_input_0, sudo_input_1, recon_input_0, recon_input_1, sudo_input_0_cycle, sudo_input_1_cycle, styles, contents, adin = self.munitnet( emb3D_R_object[0:1], emb3D_R_object[1:2]) if hyp.store_content_style_range: if self.max_content == None: self.max_content = torch.zeros_like( contents[0][0]).cuda() - 100000000 if self.min_content == None: self.min_content = torch.zeros_like( contents[0][0]).cuda() + 100000000 if self.max_style == None: self.max_style = torch.zeros_like( styles[0][0]).cuda() - 100000000 if self.min_style == None: self.min_style = torch.zeros_like( styles[0][0]).cuda() + 100000000 self.max_content = torch.max( torch.max(self.max_content, contents[0][0]), contents[1][0]) self.min_content = torch.min( torch.min(self.min_content, contents[0][0]), contents[1][0]) self.max_style = torch.max( torch.max(self.max_style, styles[0][0]), styles[1][0]) self.min_style = torch.min( torch.min(self.min_style, styles[0][0]), styles[1][0]) data_to_save = { 'max_content': self.max_content.cpu().numpy(), 'min_content': self.min_content.cpu().numpy(), 'max_style': self.max_style.cpu().numpy(), 'min_style': self.min_style.cpu().numpy() } with open('content_style_range.p', 'wb') as f: pickle.dump(data_to_save, f) elif hyp.is_contrastive_examples: if hyp.normalize_contrast: content0 = (contents[0] - self.min_content) / ( self.max_content - self.min_content + 1e-5) content1 = (contents[1] - self.min_content) / ( self.max_content - self.min_content + 1e-5) style0 = (styles[0] - self.min_style) / ( self.max_style - self.min_style + 1e-5) style1 = (styles[1] - self.min_style) / ( self.max_style - self.min_style + 1e-5) else: content0 = contents[0] content1 = contents[1] style0 = styles[0] style1 = styles[1] # euclid_dist_content = torch.sum(torch.sqrt((content0 - content1)**2))/torch.prod(torch.tensor(content0.shape)) # euclid_dist_style = torch.sum(torch.sqrt((style0-style1)**2))/torch.prod(torch.tensor(style0.shape)) euclid_dist_content = (content0 - content1).norm(2) / ( content0.numel()) euclid_dist_style = (style0 - style1).norm(2) / (style0.numel()) content_0_pooled = torch.mean( content0.reshape(list(content0.shape[:2]) + [-1]), dim=-1) content_1_pooled = torch.mean( content1.reshape(list(content1.shape[:2]) + [-1]), dim=-1) euclid_dist_content_pooled = (content_0_pooled - content_1_pooled).norm(2) / ( content_0_pooled.numel()) content_0_normalized = content0 / content0.norm() content_1_normalized = content1 / content1.norm() style_0_normalized = style0 / style0.norm() style_1_normalized = style1 / style1.norm() content_0_pooled_normalized = content_0_pooled / content_0_pooled.norm( ) content_1_pooled_normalized = content_1_pooled / content_1_pooled.norm( ) cosine_dist_content = torch.sum(content_0_normalized * content_1_normalized) cosine_dist_style = torch.sum(style_0_normalized * style_1_normalized) cosine_dist_content_pooled = torch.sum( content_0_pooled_normalized * content_1_pooled_normalized) print("euclid dist [content, pooled-content, style]: ", euclid_dist_content, euclid_dist_content_pooled, euclid_dist_style) print("cosine sim [content, pooled-content, style]: ", cosine_dist_content, cosine_dist_content_pooled, cosine_dist_style) if hyp.run_few_shot_on_munit: if (global_step % 300) == 1 or (global_step % 300) == 0: wrong = False try: precision_style = float(self.tp_style) / self.all_style precision_content = float( self.tp_content) / self.all_content except ZeroDivisionError: wrong = True if 
not wrong: summ_writer.summ_scalar( 'precision/unsupervised_precision_style', precision_style) summ_writer.summ_scalar( 'precision/unsupervised_precision_content', precision_content) # st() self.embed_list_style = defaultdict(lambda: []) self.embed_list_content = defaultdict(lambda: []) self.tp_style = 0 self.all_style = 0 self.tp_content = 0 self.all_content = 0 self.check = False elif not self.check and not nlu.check_fill_dict( self.embed_list_content, self.embed_list_style): print("Filling \n") for index, class_val in enumerate(object_classes): if hyp.dataset_name == "clevr_vqa": class_val_content, class_val_style = class_val.split( "/") else: class_val_content, class_val_style = [ class_val.split("/")[0], class_val.split("/")[0] ] print(len(self.embed_list_style.keys()), "style class", len(self.embed_list_content), "content class", self.embed_list_content.keys()) if len(self.embed_list_style[class_val_style] ) < hyp.few_shot_nums: self.embed_list_style[class_val_style].append( styles[index].squeeze()) if len(self.embed_list_content[class_val_content] ) < hyp.few_shot_nums: if hyp.avg_3d: content_val = contents[index] content_val = torch.mean(content_val.reshape( [content_val.shape[1], -1]), dim=-1) # st() self.embed_list_content[ class_val_content].append(content_val) else: self.embed_list_content[ class_val_content].append( contents[index].reshape([-1])) else: self.check = True try: print(float(self.tp_content) / self.all_content) print(float(self.tp_style) / self.all_style) except Exception as e: pass average = True if average: for key, val in self.embed_list_style.items(): if isinstance(val, type([])): self.embed_list_style[key] = torch.mean( torch.stack(val, dim=0), dim=0) for key, val in self.embed_list_content.items(): if isinstance(val, type([])): self.embed_list_content[key] = torch.mean( torch.stack(val, dim=0), dim=0) else: for key, val in self.embed_list_style.items(): if isinstance(val, type([])): self.embed_list_style[key] = torch.stack(val, dim=0) for key, val in self.embed_list_content.items(): if isinstance(val, type([])): self.embed_list_content[key] = torch.stack( val, dim=0) for index, class_val in enumerate(object_classes): class_val = class_val if hyp.dataset_name == "clevr_vqa": class_val_content, class_val_style = class_val.split( "/") else: class_val_content, class_val_style = [ class_val.split("/")[0], class_val.split("/")[0] ] style_val = styles[index].squeeze().unsqueeze(0) if not average: embed_list_val_style = torch.cat(list( self.embed_list_style.values()), dim=0) embed_list_key_style = list( np.repeat( np.expand_dims( list(self.embed_list_style.keys()), 1), hyp.few_shot_nums, 1).reshape([-1])) else: embed_list_val_style = torch.stack(list( self.embed_list_style.values()), dim=0) embed_list_key_style = list( self.embed_list_style.keys()) embed_list_val_style = utils_basic.l2_normalize( embed_list_val_style, dim=1).permute(1, 0) style_val = utils_basic.l2_normalize(style_val, dim=1) scores_styles = torch.matmul(style_val, embed_list_val_style) index_key = torch.argmax(scores_styles, dim=1).squeeze() selected_class_style = embed_list_key_style[index_key] self.styles_prediction[class_val_style].append( selected_class_style) if class_val_style == selected_class_style: self.tp_style += 1 self.all_style += 1 if hyp.avg_3d: content_val = contents[index] content_val = torch.mean(content_val.reshape( [content_val.shape[1], -1]), dim=-1).unsqueeze(0) else: content_val = contents[index].reshape( [-1]).unsqueeze(0) if not average: embed_list_val_content = 
torch.cat(list( self.embed_list_content.values()), dim=0) embed_list_key_content = list( np.repeat( np.expand_dims( list(self.embed_list_content.keys()), 1), hyp.few_shot_nums, 1).reshape([-1])) else: embed_list_val_content = torch.stack(list( self.embed_list_content.values()), dim=0) embed_list_key_content = list( self.embed_list_content.keys()) embed_list_val_content = utils_basic.l2_normalize( embed_list_val_content, dim=1).permute(1, 0) content_val = utils_basic.l2_normalize(content_val, dim=1) scores_content = torch.matmul(content_val, embed_list_val_content) index_key = torch.argmax(scores_content, dim=1).squeeze() selected_class_content = embed_list_key_content[ index_key] self.content_prediction[class_val_content].append( selected_class_content) if class_val_content == selected_class_content: self.tp_content += 1 self.all_content += 1 # st() munit_loss = hyp.munit_loss_weight * munit_loss recon_input_obj = torch.cat([recon_input_0, recon_input_1], dim=0) recon_emb3D_R = nlu.update_scene_with_objects( emb3D_R, recon_input_obj, gt_boxesRMem_end, scores) sudo_input_obj = torch.cat([sudo_input_0, sudo_input_1], dim=0) styled_emb3D_R = nlu.update_scene_with_objects( emb3D_R, sudo_input_obj, gt_boxesRMem_end, scores) styled_emb3D_e_X1 = utils_vox.apply_4x4_to_vox( camX1_T_R, styled_emb3D_R) styled_emb3D_e_X0 = utils_vox.apply_4x4_to_vox( camX0_T_R, styled_emb3D_R) emb3D_e_X1 = utils_vox.apply_4x4_to_vox(camX1_T_R, recon_emb3D_R) emb3D_e_X0 = utils_vox.apply_4x4_to_vox(camX0_T_R, recon_emb3D_R) emb3D_e_X1_og = utils_vox.apply_4x4_to_vox(camX1_T_R, emb3D_R) emb3D_e_X0_og = utils_vox.apply_4x4_to_vox(camX0_T_R, emb3D_R) emb3D_R_aug_diff = torch.abs(emb3D_R - recon_emb3D_R) summ_writer.summ_feat(f'aug_feat/og', emb3D_R) summ_writer.summ_feat(f'aug_feat/og_gen', recon_emb3D_R) summ_writer.summ_feat(f'aug_feat/og_aug_diff', emb3D_R_aug_diff) if hyp.cycle_style_view_loss: sudo_input_obj_cycle = torch.cat( [sudo_input_0_cycle, sudo_input_1_cycle], dim=0) styled_emb3D_R_cycle = nlu.update_scene_with_objects( emb3D_R, sudo_input_obj_cycle, gt_boxesRMem_end, scores) styled_emb3D_e_X0_cycle = utils_vox.apply_4x4_to_vox( camX0_T_R, styled_emb3D_R_cycle) styled_emb3D_e_X1_cycle = utils_vox.apply_4x4_to_vox( camX1_T_R, styled_emb3D_R_cycle) summ_writer.summ_scalar('munit_loss', munit_loss.cpu().item()) total_loss += munit_loss if hyp.do_occ and hyp.occ_do_cheap: occX0_sup, freeX0_sup, _, freeXs = utils_vox.prep_occs_supervision( camX0_T_camXs, xyz_camXs, Z2, Y2, X2, agg=True) summ_writer.summ_occ('occ_sup/occ_sup', occX0_sup) summ_writer.summ_occ('occ_sup/free_sup', freeX0_sup) summ_writer.summ_occs('occ_sup/freeXs_sup', torch.unbind(freeXs, dim=1)) summ_writer.summ_occs('occ_sup/occXs_sup', torch.unbind(occXs_half, dim=1)) occ_loss, occX0s_pred_ = self.occnet( torch.mean(featX0s[:, 1:], dim=1), occX0_sup, freeX0_sup, torch.max(validX0s[:, 1:], dim=1)[0], summ_writer) occX0s_pred = __u(occX0s_pred_) total_loss += occ_loss if hyp.do_view: assert (hyp.do_feat) PH, PW = hyp.PH, hyp.PW sy = float(PH) / float(hyp.H) sx = float(PW) / float(hyp.W) assert (sx == 0.5) # else we need a fancier downsampler assert (sy == 0.5) projpix_T_cams = __u( utils_geom.scale_intrinsics(__p(pix_T_cams), sx, sy)) # st() if hyp.do_munit: feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], emb3D_e_X1, # use feat1 to predict rgb0 hyp.view_depth, PH, PW) feat_projX00_og = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], emb3D_e_X1_og, # use feat1 to predict 
rgb0 hyp.view_depth, PH, PW) # only for checking the style styled_feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], styled_emb3D_e_X1, # use feat1 to predict rgb0 hyp.view_depth, PH, PW) if hyp.cycle_style_view_loss: styled_feat_projX00_cycle = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], styled_emb3D_e_X1_cycle, # use feat1 to predict rgb0 hyp.view_depth, PH, PW) else: feat_projX00 = utils_vox.apply_pixX_T_memR_to_voxR( projpix_T_cams[:, 0], camX0_T_camXs[:, 1], featXs[:, 1], # use feat1 to predict rgb0 hyp.view_depth, PH, PW) rgb_X00 = utils_basic.downsample(rgb_camXs[:, 0], 2) rgb_X01 = utils_basic.downsample(rgb_camXs[:, 1], 2) valid_X00 = utils_basic.downsample(valid_camXs[:, 0], 2) view_loss, rgb_e, emb2D_e = self.viewnet(feat_projX00, rgb_X00, valid_X00, summ_writer, "rgb") if hyp.do_munit: _, rgb_e, emb2D_e = self.viewnet(feat_projX00_og, rgb_X00, valid_X00, summ_writer, "rgb_og") if hyp.do_munit: styled_view_loss, styled_rgb_e, styled_emb2D_e = self.viewnet( styled_feat_projX00, rgb_X00, valid_X00, summ_writer, "recon_style") if hyp.cycle_style_view_loss: styled_view_loss_cycle, styled_rgb_e_cycle, styled_emb2D_e_cycle = self.viewnet( styled_feat_projX00_cycle, rgb_X00, valid_X00, summ_writer, "recon_style_cycle") rgb_input_1 = torch.cat( [rgb_X01[1], rgb_X01[0], styled_rgb_e[0]], dim=2) rgb_input_2 = torch.cat( [rgb_X01[0], rgb_X01[1], styled_rgb_e[1]], dim=2) complete_vis = torch.cat([rgb_input_1, rgb_input_2], dim=1) summ_writer.summ_rgb('munit/munit_recons_vis', complete_vis.unsqueeze(0)) if not hyp.do_munit: total_loss += view_loss else: if hyp.basic_view_loss: total_loss += view_loss if hyp.style_view_loss: total_loss += styled_view_loss if hyp.cycle_style_view_loss: total_loss += styled_view_loss_cycle summ_writer.summ_scalar('loss', total_loss.cpu().item()) if hyp.save_embed_tsne: for index, class_val in enumerate(object_classes): class_val_content, class_val_style = class_val.split("/") style_val = styles[index].squeeze().unsqueeze(0) self.cluster_pool.update(style_val, [class_val_style]) print(self.cluster_pool.num) if self.cluster_pool.is_full(): embeds, classes = self.cluster_pool.fetch() with open("offline_cluster" + '/%st.txt' % 'classes', 'w') as f: for index, embed in enumerate(classes): class_val = classes[index] f.write("%s\n" % class_val) f.close() with open("offline_cluster" + '/%st.txt' % 'embeddings', 'w') as f: for index, embed in enumerate(embeds): # embed = utils_basic.l2_normalize(embed,dim=0) print("writing {} embed".format(index)) embed_l_s = [str(i) for i in embed.tolist()] embed_str = '\t'.join(embed_l_s) f.write("%s\n" % embed_str) f.close() st() return total_loss, results
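# The hyp.run_few_shot_on_munit branch in the forward above classifies each query style/content
# code by cosine similarity against per-class embeddings accumulated in self.embed_list_style /
# self.embed_list_content (optionally averaged into one prototype per class). The helper below is
# an illustrative, self-contained sketch of that retrieval step only; its name and the
# prototype-dict argument are hypothetical, while the real code uses utils_basic.l2_normalize and
# torch.matmul as shown above.
import torch
import torch.nn.functional as F

def nearest_class_sketch(query, class_protos):
    # query: 1 x D embedding; class_protos: dict mapping class name -> D-dim prototype
    # (e.g. the mean of the few stored examples for that class).
    names = list(class_protos.keys())
    protos = torch.stack([class_protos[n] for n in names], dim=0)  # N x D
    protos = F.normalize(protos, dim=1)        # unit-norm prototypes
    query = F.normalize(query, dim=1)          # unit-norm query
    scores = torch.matmul(query, protos.t())   # 1 x N cosine similarities
    return names[torch.argmax(scores, dim=1).item()]

# Accuracy is then tracked as in the code above: tp_style/all_style (and tp_content/all_content)
# count how often the retrieved class matches the label parsed from object_classes.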
def forward(self, feat, obj_lrtlist_cams, obj_scorelist_s, summ_writer, suffix=''):
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat.shape)
    K, B2, S, D = list(obj_lrtlist_cams.shape)
    assert (B == B2)
    # obj_scorelist_s is K x B x S
    # __p = lambda x: utils_basic.pack_seqdim(x, B)
    # __u = lambda x: utils_basic.unpack_seqdim(x, B)

    # obj_lrtlist_cams is K x B x S x 19
    obj_lrtlist_cams_ = obj_lrtlist_cams.reshape(K * B, S, 19)
    obj_clist_cam_ = utils_geom.get_clist_from_lrtlist(obj_lrtlist_cams_)
    obj_clist_cam = obj_clist_cam_.reshape(K, B, S, 1, 3)
    # obj_clist_cam is K x B x S x 1 x 3
    obj_clist_cam = obj_clist_cam.squeeze(3)
    # obj_clist_cam is K x B x S x 3
    clist_cam = obj_clist_cam.reshape(K * B, S, 3)
    clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
    # this is K*B x S x 3
    clist_mem = clist_mem.reshape(K, B, S, 3)

    energy_vol = self.conv3d(feat)
    # energy_vol is B x 1 x Z x Y x X
    summ_writer.summ_oned('pri/energy_vol', torch.mean(energy_vol, dim=3))
    summ_writer.summ_histogram('pri/energy_vol_hist', energy_vol)

    # for k in range(K):
    # let's start with the first object
    # loglike_per_traj = self.get_traj_loglike(clist_mem[0], energy_vol)
    # # this is B
    # ce_loss = -1.0*torch.mean(loglike_per_traj)
    # # this is []

    loglike_per_traj = self.get_trajs_loglike(clist_mem, obj_scorelist_s, energy_vol)
    # this is B x K
    valid = torch.max(obj_scorelist_s.permute(1, 0, 2), dim=2)[0]
    ce_loss = -1.0 * utils_basic.reduce_masked_mean(loglike_per_traj, valid)
    # this is []
    total_loss = utils_misc.add_loss('pri/ce_loss', total_loss, ce_loss,
                                     hyp.pri_ce_coeff, summ_writer)

    reg_loss = torch.sum(torch.abs(energy_vol))
    total_loss = utils_misc.add_loss('pri/reg_loss', total_loss, reg_loss,
                                     hyp.pri_reg_coeff, summ_writer)

    # smooth loss
    dz, dy, dx = utils_basic.gradient3D(energy_vol, absolute=True)
    smooth_vox = torch.mean(dz + dy + dx, dim=1, keepdim=True)
    summ_writer.summ_oned('pri/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('pri/smooth_loss', total_loss, smooth_loss,
                                     hyp.pri_smooth_coeff, summ_writer)

    # pri_e = F.sigmoid(energy_vol)
    # energy_volbinary = torch.round(pri_e)
    # # collect some accuracy stats
    # pri_match = pri_g*torch.eq(energy_volbinary, pri_g).float()
    # free_match = free_g*torch.eq(1.0-energy_volbinary, free_g).float()
    # either_match = torch.clamp(pri_match+free_match, 0.0, 1.0)
    # either_have = torch.clamp(pri_g+free_g, 0.0, 1.0)
    # acc_pri = reduce_masked_mean(pri_match, pri_g*valid)
    # acc_free = reduce_masked_mean(free_match, free_g*valid)
    # acc_total = reduce_masked_mean(either_match, either_have*valid)
    # summ_writer.summ_scalar('pri/acc_pri%s' % suffix, acc_pri.cpu().item())
    # summ_writer.summ_scalar('pri/acc_free%s' % suffix, acc_free.cpu().item())
    # summ_writer.summ_scalar('pri/acc_total%s' % suffix, acc_total.cpu().item())
    # # vis
    # summ_writer.summ_pri('pri/pri_g%s' % suffix, pri_g, reduce_axes=[2,3])
    # summ_writer.summ_pri('pri/free_g%s' % suffix, free_g, reduce_axes=[2,3])
    # summ_writer.summ_pri('pri/pri_e%s' % suffix, pri_e, reduce_axes=[2,3])
    # summ_writer.summ_pri('pri/valid%s' % suffix, valid, reduce_axes=[2,3])
    # prob_loss = self.compute_loss(energy_vol, pri_g, free_g, valid, summ_writer)
    # total_loss = utils_misc.add_loss('pri/prob_loss%s' % suffix, total_loss, prob_loss, hyp.pri_coeff, summ_writer)

    return total_loss  # , pri_e
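# The ce_loss above relies on self.get_trajs_loglike(clist_mem, obj_scorelist_s, energy_vol),
# which is defined elsewhere in the module. Purely as an illustration, here is a minimal
# single-trajectory sketch of reading a trajectory log-likelihood out of an energy volume: treat
# the energies as unnormalized log-probabilities over voxels, normalize them with a spatial
# softmax, and sample the resulting log-prob volume at the memory-coordinate waypoints. The
# function name and the normalization/sampling details are assumptions, not the project's
# implementation (which also handles K objects and their validity scores).
import torch
import torch.nn.functional as F

def traj_loglike_sketch(clist_mem, energy_vol):
    # clist_mem: B x S x 3 waypoints in memory coordinates (x, y, z)
    # energy_vol: B x 1 x Z x Y x X unnormalized energies
    B, C, Z, Y, X = energy_vol.shape
    logprob_vol = F.log_softmax(energy_vol.reshape(B, -1), dim=1).reshape(B, 1, Z, Y, X)
    # normalize waypoints to [-1, 1] for grid_sample; grid ordering is (x, y, z)
    x = clist_mem[:, :, 0] / (X - 1) * 2.0 - 1.0
    y = clist_mem[:, :, 1] / (Y - 1) * 2.0 - 1.0
    z = clist_mem[:, :, 2] / (Z - 1) * 2.0 - 1.0
    grid = torch.stack([x, y, z], dim=2).reshape(B, 1, 1, -1, 3)  # B x 1 x 1 x S x 3
    samples = F.grid_sample(logprob_vol, grid, align_corners=True)  # B x 1 x 1 x 1 x S
    # sum per-waypoint log-probs to get one log-likelihood per trajectory
    return samples.reshape(B, -1).sum(dim=1)  # B

# In the loss above, the per-trajectory log-likelihoods are masked by the per-object validity
# scores, averaged with utils_basic.reduce_masked_mean, and negated to form ce_loss.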