def compute_loss(self, pred, pos, neg, valid, summ_writer):
    label = pos * 2.0 - 1.0
    a = -label * pred
    b = F.relu(a)
    # numerically stable binary cross-entropy on logits:
    # log(1 + exp(a)) computed as b + log(exp(-b) + exp(a-b))
    loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))

    mask_ = (pos + neg > 0.0).float()
    loss_vis = torch.mean(loss * mask_ * valid, dim=3)
    summ_writer.summ_oned('sub/prob_loss', loss_vis)

    # pos_loss = reduce_masked_mean(loss, pos*valid)
    # neg_loss = reduce_masked_mean(loss, neg*valid)
    # balanced_loss = pos_loss + neg_loss

    # pos_occ_loss = utils_basic.reduce_masked_mean(loss, pos*valid*occ)
    # pos_free_loss = utils_basic.reduce_masked_mean(loss, pos*valid*free)
    # neg_occ_loss = utils_basic.reduce_masked_mean(loss, neg*valid*occ)
    # neg_free_loss = utils_basic.reduce_masked_mean(loss, neg*valid*free)
    # balanced_loss = pos_occ_loss + pos_free_loss + neg_occ_loss + neg_free_loss

    pos_loss = utils_basic.reduce_masked_mean(loss, pos * valid)
    neg_loss = utils_basic.reduce_masked_mean(loss, neg * valid)
    balanced_loss = pos_loss + neg_loss
    return balanced_loss
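# The balanced loss above leans on utils_basic.reduce_masked_mean. A minimal
# sketch of the assumed semantics (the real helper may differ in broadcasting
# details): the mean of x over the entries where mask is on, guarded against
# empty masks.
import torch

def reduce_masked_mean(x, mask, dim=None, keepdim=False, eps=1e-6):
    prod = x * mask
    if dim is None:
        return torch.sum(prod) / (eps + torch.sum(mask))
    numer = torch.sum(prod, dim=dim, keepdim=keepdim)
    denom = eps + torch.sum(mask, dim=dim, keepdim=keepdim)
    return numer / denom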
def forward(self, feat0, feat1, valid0, valid1, summ_writer=None):
    total_loss = torch.tensor(0.0).cuda()
    B, C, D, H, W = list(feat0.shape)

    neg_input = feat1 - feat0

    feat1_flat = feat1.reshape(B, C, -1)
    valid1_flat = valid1.reshape(B, 1, -1)
    perm = np.random.permutation(D*H*W)
    feat1_flat_shuf = feat1_flat[:, :, perm]
    feat1_shuf = feat1_flat_shuf.reshape(B, C, D, H, W)
    valid1_flat_shuf = valid1_flat[:, :, perm]
    valid1_shuf = valid1_flat_shuf.reshape(B, 1, D, H, W)

    # pos_input = torch.cat([feat0, feat1_shuf-feat0], dim=1)
    pos_input = feat1_shuf - feat0

    # noncorresps should be rejected
    pos_output = self.net(pos_input)
    # corresps should NOT be rejected
    neg_output = self.net(neg_input)

    pos_sig = torch.sigmoid(pos_output)
    neg_sig = torch.sigmoid(neg_output)
    if summ_writer is not None:
        summ_writer.summ_feat('reject/pos_input', pos_input, pca=True)
        summ_writer.summ_feat('reject/neg_input', neg_input, pca=True)
        summ_writer.summ_oned('reject/pos_sig', pos_sig, bev=True, norm=False)
        summ_writer.summ_oned('reject/neg_sig', neg_sig, bev=True, norm=False)

    pos_output_vec = pos_output.reshape(B, D*H*W)
    neg_output_vec = neg_output.reshape(B, D*H*W)
    pos_target_vec = torch.ones([B, D*H*W]).float().cuda()
    neg_target_vec = torch.zeros([B, D*H*W]).float().cuda()

    # if feat1_shuf is valid, then it is practically guaranteed to mismatch feat0
    pos_valid_vec = valid1_shuf.reshape(B, D*H*W)
    # both have to be valid to not reject
    neg_valid_vec = (valid0*valid1).reshape(B, D*H*W)

    pos_loss_vec = self.criterion(pos_output_vec, pos_target_vec)
    neg_loss_vec = self.criterion(neg_output_vec, neg_target_vec)
    pos_loss = utils_basic.reduce_masked_mean(pos_loss_vec, pos_valid_vec)
    # note: the original masked this with pos_valid_vec, which looks like a
    # bug; the negatives should be masked by their own validity
    neg_loss = utils_basic.reduce_masked_mean(neg_loss_vec, neg_valid_vec)
    ce_loss = pos_loss + neg_loss

    utils_misc.add_loss('reject3D/ce_loss_pos', 0, pos_loss, 0, summ_writer)
    utils_misc.add_loss('reject3D/ce_loss_neg', 0, neg_loss, 0, summ_writer)
    total_loss = utils_misc.add_loss('reject3D/ce_loss', total_loss, ce_loss,
                                     hyp.reject3D_ce_coeff, summ_writer)
    return total_loss, neg_sig, pos_sig
def compute_loss(self, pred, seg, free, summ_writer):
    # pred is B x C x Z x Y x X
    # seg is B x Z x Y x X

    # ensure the "free" voxels are labelled zero
    seg = seg * (1 - free)

    # seg_bak = seg.clone()
    # note that label0 in seg is invalid, so i need to ignore these,
    # but all "free" voxels count as the air class
    # seg = (seg-1).clamp(min=0)
    # ignore_mask = (seg[seg==-1]).float()
    # seg[seg==-1] = 0

    loss = F.cross_entropy(pred, seg, reduction='none')
    # loss is B x Z x Y x X

    loss_any = ((seg > 0).float() + (free > 0).float()).clamp(0, 1)
    loss_vis = torch.mean(loss * loss_any, dim=2).unsqueeze(1)
    summ_writer.summ_oned('seg/prob_loss', loss_vis)

    loss = loss.reshape(-1)
    seg = seg.reshape(-1)
    free = free.reshape(-1)

    # next, gather up the loss for each valid class, and balance these into a total
    losses = []
    for cls in list(range(self.num_classes)):
        if cls == 0:
            mask = free.clone()
        else:
            mask = (seg == cls).float()
        cls_loss = utils_basic.reduce_masked_mean(loss, mask)
        print('cls %d sum' % cls, torch.sum(mask).detach().cpu().numpy(),
              'loss', cls_loss.detach().cpu().numpy())
        # only count classes that actually appear in this batch
        if torch.sum(mask) >= 1:
            losses.append(cls_loss)
    # guard against a batch with no valid classes; torch.stack on an empty
    # list would raise
    if len(losses):
        total_loss = torch.mean(torch.stack(losses))
    else:
        total_loss = torch.tensor(0.0).cuda()
    return total_loss
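# A tiny sanity check of the class-balancing pattern above (toy numbers, not
# from the real data): averaging per-class means keeps a rare class from being
# drowned out by a common one.
import torch

per_voxel_ce = torch.tensor([1.0, 1.0, 10.0])  # CE at three voxels
seg_ids = torch.tensor([1, 1, 2])              # class of each voxel
cls_means = []
for cls in [1, 2]:
    mask = (seg_ids == cls).float()
    cls_means.append(torch.sum(per_voxel_ce * mask) / (torch.sum(mask) + 1e-6))
balanced = torch.mean(torch.stack(cls_means))  # ~5.5, vs plain mean 4.0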
def forward(self, feat, occ_g, free_g, summ_writer=None):
    B, C, Z, Y, X = list(feat.shape)
    total_loss = torch.tensor(0.0).cuda()

    summ_writer.summ_feat('genoccnet/feat_input', feat, pca=False)
    feat = feat.long()  # discrete input
    logit = self.run_net(feat)

    # smooth loss
    dz, dy, dx = gradient3D(logit, absolute=True)
    # the original summed dx+dy+dx, dropping dz; sum all three gradients
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    summ_writer.summ_oned('genoccnet/smooth_loss', torch.mean(smooth_vox, dim=3))
    # smooth_loss = utils_basic.reduce_masked_mean(smooth_vox, valid)
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('genoccnet/smooth_loss', total_loss, smooth_loss,
                                     hyp.genocc_smooth_coeff, summ_writer)

    summ_writer.summ_feat('genoccnet/feat_output', logit, pca=False)
    occ_e = torch.argmax(logit, dim=1, keepdim=True)
    summ_writer.summ_occ('genoccnet/occ_e', occ_e)
    summ_writer.summ_occ('genoccnet/occ_g', occ_g)

    loss_pos = self.criterion(logit, (occ_g[:, 0]).long())
    loss_neg = self.criterion(logit, (1 - free_g[:, 0]).long())
    summ_writer.summ_oned('genoccnet/loss',
                          torch.mean((loss_pos + loss_neg), dim=1, keepdim=True) *
                          torch.clamp(occ_g + (1 - free_g), 0, 1),
                          bev=True)
    loss_pos = utils_basic.reduce_masked_mean(loss_pos.unsqueeze(1), occ_g)
    loss_neg = utils_basic.reduce_masked_mean(loss_neg.unsqueeze(1), free_g)
    loss_bal = loss_pos + loss_neg
    total_loss = utils_misc.add_loss('genoccnet/loss_bal', total_loss, loss_bal,
                                     hyp.genocc_coeff, summ_writer)

    # sample = self.generate_sample(1, int(Z/2), int(Y/2), int(X/2))
    # occ_sample = torch.argmax(sample, dim=1, keepdim=True)
    # summ_writer.summ_occ('genoccnet/occ_sample', occ_sample)
    return logit, total_loss
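# The smooth loss above relies on gradient3D. A minimal sketch, assuming
# forward differences over the three spatial dims of a B x C x Z x Y x X
# tensor, zero-padded back to the input shape (the real helper may pad
# differently):
import torch
import torch.nn.functional as F

def gradient3D(x, absolute=False):
    dz = x[:, :, 1:] - x[:, :, :-1]
    dy = x[:, :, :, 1:] - x[:, :, :, :-1]
    dx = x[:, :, :, :, 1:] - x[:, :, :, :, :-1]
    # F.pad pads the last dim first: (X_l, X_r, Y_l, Y_r, Z_l, Z_r)
    dz = F.pad(dz, (0, 0, 0, 0, 0, 1))
    dy = F.pad(dy, (0, 0, 0, 1, 0, 0))
    dx = F.pad(dx, (0, 1, 0, 0, 0, 0))
    if absolute:
        dz, dy, dx = torch.abs(dz), torch.abs(dy), torch.abs(dx)
    return dz, dy, dx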
def forward(self, feat, rgb_g, valid, summ_writer, name):
    total_loss = torch.tensor(0.0).cuda()

    if hyp.dataset_name == "clevr":
        valid = torch.ones_like(valid)

    feat = self.net(feat)
    emb_e = self.emb_layer(feat)
    rgb_e = self.rgb_layer(feat)
    # postproc
    emb_e = l2_normalize(emb_e, dim=1)
    rgb_e = torch.tanh(rgb_e) * 0.5  # F.tanh is deprecated

    loss_im = l1_on_axis(rgb_e - rgb_g, 1, keepdim=True)
    summ_writer.summ_oned('view/rgb_loss', loss_im * valid)
    rgb_loss = utils_basic.reduce_masked_mean(loss_im, valid)
    total_loss = utils_misc.add_loss('view/rgb_l1_loss', total_loss, rgb_loss,
                                     hyp.view_l1_coeff, summ_writer)
    # vis
    summ_writer.summ_rgbs(f'view/{name}', [rgb_e, rgb_g])
    return total_loss, rgb_e, emb_e
def forward(self, image_input, summ_writer=None, is_train=True):
    B, H, W = list(image_input.shape)
    total_loss = torch.tensor(0.0).cuda()
    summ_writer.summ_oned('sigen2d/image_input', image_input.unsqueeze(1) / 512.0, norm=False)

    emb = self.embed(image_input)

    y = torch.randint(low=0, high=H, size=[B, self.num_choices, 1])
    x = torch.randint(low=0, high=W, size=[B, self.num_choices, 1])
    choice_mask = utils_improc.xy2mask(torch.cat([x, y], dim=2), H, W, norm=False)
    summ_writer.summ_oned('sigen2d/choice_mask', choice_mask, norm=False)

    # cover up the 3x3 region surrounding each choice
    xy = torch.cat([
        torch.cat([x - 1, y - 1], dim=2),
        torch.cat([x + 0, y - 1], dim=2),
        torch.cat([x + 1, y - 1], dim=2),
        torch.cat([x - 1, y], dim=2),
        torch.cat([x + 0, y], dim=2),
        torch.cat([x + 1, y], dim=2),
        torch.cat([x - 1, y + 1], dim=2),
        torch.cat([x + 0, y + 1], dim=2),
        torch.cat([x + 1, y + 1], dim=2),
    ], dim=1)
    input_mask = 1.0 - utils_improc.xy2mask(xy, H, W, norm=False)

    # if is_train:
    #     input_mask = (torch.rand((B, 1, H, W)).cuda() > 0.5).float()
    # else:
    #     input_mask = torch.ones((B, 1, H, W)).cuda().float()
    # input_mask = input_mask * (1.0 - choice_mask)
    # input_mask = 1.0 - choice_mask

    emb = emb * input_mask
    summ_writer.summ_oned('sigen2d/input_mask', input_mask, norm=False)

    image_output_logits, _ = self.net(emb, input_mask)
    image_output = torch.argmax(image_output_logits, dim=1, keepdim=True)
    summ_writer.summ_feat('sigen2d/emb', emb, pca=True)
    summ_writer.summ_oned('sigen2d/image_output', image_output.float() / 512.0, norm=False)

    ce_loss_image = F.cross_entropy(image_output_logits, image_input,
                                    reduction='none').unsqueeze(1)
    # summ_writer.summ_oned('sigen2d/ce_loss', ce_loss_image)
    # ce_loss = torch.mean(ce_loss_image)
    # only apply loss at the choice pixels
    ce_loss = utils_basic.reduce_masked_mean(ce_loss_image, choice_mask)
    total_loss = utils_misc.add_loss('sigen2d/ce_loss', total_loss, ce_loss,
                                     hyp.sigen2d_coeff, summ_writer)
    return total_loss
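# utils_improc.xy2mask above turns pixel coordinates into a binary image mask.
# A minimal sketch of the assumed semantics for integer coords (the real
# helper also supports normalized coords via norm=True; here out-of-range
# coords are clamped to the border, which may differ from the original):
import torch

def xy2mask(xy, H, W, norm=False):
    B, N, _ = xy.shape
    mask = torch.zeros(B, 1, H, W, device=xy.device)
    x = xy[:, :, 0].long().clamp(0, W - 1)
    y = xy[:, :, 1].long().clamp(0, H - 1)
    for b in range(B):
        mask[b, 0, y[b], x[b]] = 1.0
    return mask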
def forward(self, feats, xyzlist_cam, scorelist, vislist, occs, summ_writer, suffix=''):
    total_loss = torch.tensor(0.0).cuda()

    B, S, C, Z2, Y2, X2 = list(feats.shape)
    B, S, C, Z, Y, X = list(occs.shape)
    B2, S2, D = list(xyzlist_cam.shape)
    # an assert on a tuple is always true; test the conditions directly
    assert B == B2 and S == S2
    assert D == 3

    xyzlist_mem = utils_vox.Ref2Mem(xyzlist_cam, Z, Y, X)
    # these are B x S x 3

    scorelist = scorelist.unsqueeze(2)
    # this is B x S x 1
    # we only care that the object was visible in frame0
    vislist = vislist[:, 0].reshape(B, 1, 1)
    scorelist = scorelist * vislist

    if self.use_cost_vols:
        # NOTE: this branch references occs_highres and feat, which are not
        # defined in this scope; it appears to be stale
        if summ_writer.save_this:
            summ_writer.summ_traj_on_occ('forecast/actual_traj',
                                         xyzlist_mem * scorelist,
                                         torch.max(occs, dim=1)[0],
                                         already_mem=True, sigma=2)

        Z2, Y2, X2 = int(Z / 2), int(Y / 2), int(X / 2)
        Z4, Y4, X4 = int(Z / 4), int(Y / 4), int(X / 4)

        occ_hint0 = utils_vox.voxelize_xyz(xyzlist_cam[:, 0:1], Z4, Y4, X4)
        occ_hint1 = utils_vox.voxelize_xyz(xyzlist_cam[:, 1:2], Z4, Y4, X4)
        occ_hint0 = occ_hint0 * scorelist[:, 0].reshape(B, 1, 1, 1, 1)
        occ_hint1 = occ_hint1 * scorelist[:, 1].reshape(B, 1, 1, 1, 1)
        occ_hint = torch.cat([occ_hint0, occ_hint1], dim=1)
        occ_hint = F.interpolate(occ_hint, scale_factor=4, mode='nearest')
        # this is B x 2 x Z x Y x X
        summ_writer.summ_occ('forecast/occ_hint', (occ_hint0 + occ_hint1).clamp(0, 1))

        crops = []
        for s in list(range(S)):
            crop = utils_vox.center_mem_on_xyz(occs_highres[:, s], xyzlist_cam[:, s], Z2, Y2, X2)
            crops.append(crop)
        crops = torch.stack(crops, dim=0)
        summ_writer.summ_occs('forecast/crops', crops)

        # condition on the occ_hint
        feat = torch.cat([feat, occ_hint], dim=1)

        N = hyp.forecast_num_negs
        sampled_trajs_mem = self.sample_trajs_from_library(N, xyzlist_mem)
        if summ_writer.save_this:
            for n in list(range(np.min([N, 10]))):
                # use a fresh name here; the original overwrote xyzlist_mem,
                # which is still needed below
                xyzlist_mem_vis = sampled_trajs_mem[0, n].unsqueeze(0)
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ(
                    'forecast/lib%d_xyzlist' % n,
                    xyzlist_mem_vis,
                    torch.zeros([1, 1, Z, Y, X]).float().cuda(),
                    already_mem=True)

        cost_vols = self.cost_forecaster(feat)
        # cost_vols = F.sigmoid(cost_vols)
        cost_vols = F.interpolate(cost_vols, scale_factor=2, mode='trilinear')
        # cost_vols is B x S x Z x Y x X
        summ_writer.summ_histogram('forecast/cost_vols_hist', cost_vols)
        cost_vols = cost_vols.clamp(-1000, 1000)  # raquel says this adds stability
        summ_writer.summ_histogram('forecast/cost_vols_clamped_hist', cost_vols)

        cost_vols_vis = torch.mean(cost_vols, dim=3).unsqueeze(2)
        # cost_vols_vis is B x S x 1 x Z x X
        summ_writer.summ_oneds('forecast/cost_vols_vis', torch.unbind(cost_vols_vis, dim=1))

        # smooth loss
        cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X)
        dz, dy, dx = gradient3D(cost_vols_, absolute=True)
        dt = torch.abs(cost_vols[:, 1:] - cost_vols[:, 0:-1])
        smooth_vox_spatial = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        smooth_vox_time = torch.mean(dt, dim=1, keepdim=True)
        summ_writer.summ_oned('forecast/smooth_loss_spatial', torch.mean(smooth_vox_spatial, dim=3))
        summ_writer.summ_oned('forecast/smooth_loss_time', torch.mean(smooth_vox_time, dim=3))
        smooth_loss = torch.mean(smooth_vox_spatial) + torch.mean(smooth_vox_time)
        total_loss = utils_misc.add_loss('forecast/smooth_loss', total_loss, smooth_loss,
                                         hyp.forecast_smooth_coeff, summ_writer)

        def clamp_xyz(xyz, X, Y, Z):
            x, y, z = torch.unbind(xyz, dim=-1)
            # the original clamped x three times; clamp each axis to its own bound
            x = x.clamp(0, X)
            y = y.clamp(0, Y)
            z = z.clamp(0, Z)
            return torch.stack([x, y, z], dim=-1)

        # xyzlist_mem is B x S x 3
        # sampled_trajs_mem is B x N x S x 3
        xyz_pos_ = xyzlist_mem.reshape(B * S, 1, 3)
        xyz_neg_ = sampled_trajs_mem.permute(0, 2, 1, 3).reshape(B * S, N, 3)
        xyz_ = torch.cat([xyz_pos_, xyz_neg_], dim=1)
        xyz_ = clamp_xyz(xyz_, X, Y, Z)
        cost_vols_ = cost_vols.reshape(B * S, 1, Z, Y, X)
        x, y, z = torch.unbind(xyz_, dim=2)
        cost_ = utils_samp.bilinear_sample3D(cost_vols_, x, y, z).squeeze(1)
        # cost_ is B*S x 1+N
        cost_pos = cost_[:, 0:1]  # B*S x 1
        cost_neg = cost_[:, 1:]   # B*S x N

        cost_pos = cost_pos.unsqueeze(2)  # B*S x 1 x 1
        cost_neg = cost_neg.unsqueeze(1)  # B*S x 1 x N

        utils_misc.add_loss('forecast/mean_cost_pos', 0, torch.mean(cost_pos), 0, summ_writer)
        utils_misc.add_loss('forecast/mean_cost_neg', 0, torch.mean(cost_neg), 0, summ_writer)
        utils_misc.add_loss('forecast/mean_margin', 0, torch.mean(cost_neg - cost_pos), 0, summ_writer)

        xyz_pos = xyz_pos_.unsqueeze(2)  # B*S x 1 x 1 x 3
        xyz_neg = xyz_neg_.unsqueeze(1)  # B*S x 1 x N x 3
        dist = torch.norm(xyz_pos - xyz_neg, dim=3)  # B*S x 1 x N
        dist = dist / float(Z) * 5.0  # normalize for resolution, but upweight it a bit

        margin = F.relu(cost_pos - cost_neg + dist)
        margin = margin.reshape(B, S, N)
        # mean over time (in the paper this is a sum)
        margin = utils_basic.reduce_masked_mean(margin, scorelist.repeat(1, 1, N), dim=1)
        # max over the negatives
        maxmargin = torch.max(margin, dim=1)[0]  # B
        maxmargin_loss = torch.mean(maxmargin)
        total_loss = utils_misc.add_loss('forecast/maxmargin_loss', total_loss, maxmargin_loss,
                                         hyp.forecast_maxmargin_coeff, summ_writer)

        cost_neg = cost_neg.reshape(B, S, N)[0].detach().cpu().numpy()
        sampled_trajs_mem = sampled_trajs_mem.reshape(B, N, S, 3)[0:1]
        cost_neg = np.reshape(cost_neg, [S, N])
        cost_neg = np.sum(cost_neg, axis=0)
        inds = np.argsort(cost_neg, axis=0)

        for n in list(range(2)):
            xyzlist_e_mem = sampled_trajs_mem[0:1, inds[n]]
            xyzlist_e_cam = utils_vox.Mem2Ref(xyzlist_e_mem, Z, Y, X)
            # this is B x S x 3
            dist = torch.norm(xyzlist_cam[0:1] - xyzlist_e_cam[0:1], dim=2)
            # this is B x S
            meandist = utils_basic.reduce_masked_mean(dist, scorelist[0:1].squeeze(2))
            utils_misc.add_loss('forecast/xyz_dist_%d' % n, 0, meandist, 0, summ_writer)
            # mpe = torch.mean(torch.norm(xyzlist_cam[0:1,int(S/2)] - xyzlist_e_cam[0:1,int(S/2)], dim=1))
            # utils_misc.add_loss('forecast/xyz_mpe_%d' % n, 0, dist, 0, summ_writer)
            # epe = torch.mean(torch.norm(xyzlist_cam[0:1,-1] - xyzlist_e_cam[0:1,-1], dim=1))
            # utils_misc.add_loss('forecast/xyz_epe_%d' % n, 0, dist, 0, summ_writer)

        if summ_writer.save_this:
            # plot the best and worst trajs
            for n in list(range(2)):
                ind = inds[n]
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ(
                    'forecast/best_sampled_traj%d' % n,
                    xyzlist_e_mem,
                    torch.max(occs[0:1], dim=1)[0],
                    already_mem=True, sigma=1)
            for n in list(range(2)):
                ind = inds[-(n + 1)]
                xyzlist_e_mem = sampled_trajs_mem[:, ind]
                # this is 1 x S x 3
                summ_writer.summ_traj_on_occ(
                    'forecast/worst_sampled_traj%d' % n,
                    xyzlist_e_mem,
                    torch.max(occs[0:1], dim=1)[0],
                    already_mem=True, sigma=1)
    else:
        # use some timesteps as input
        feat_input = feats[:, :self.num_given].squeeze(2)
        # feat_input is B x self.num_given x ZZ x ZY x ZX

        ## regular bottle3D
        # vel_e = self.regressor(feat_input)
        ## sparse-invar bottle3D
        comp_mask = 1.0 - (feat_input == 0).all(dim=1, keepdim=True).float()
        summ_writer.summ_feat('forecast/feat_input', feat_input, pca=False)
        summ_writer.summ_feat('forecast/feat_comp_mask', comp_mask, pca=False)
        vel_e = self.regressor(feat_input, comp_mask)
        vel_e = vel_e.reshape(B, self.num_need, 3)
        vel_g = xyzlist_cam[:, self.num_given:] - xyzlist_cam[:, self.num_given - 1:-1]

        xyzlist_e = torch.zeros_like(xyzlist_cam)
        xyzlist_g = torch.zeros_like(xyzlist_cam)
        for s in list(range(S)):
            if s < self.num_given:
                # grab from gt
                xyzlist_e[:, s] = xyzlist_cam[:, s]
                xyzlist_g[:, s] = xyzlist_cam[:, s]
            else:
                # integrate the velocities
                xyzlist_e[:, s] = xyzlist_e[:, s - 1] + vel_e[:, s - self.num_given]
                xyzlist_g[:, s] = xyzlist_g[:, s - 1] + vel_g[:, s - self.num_given]

        xyzlist_e_mem = utils_vox.Ref2Mem(xyzlist_e, Z, Y, X)
        xyzlist_g_mem = utils_vox.Ref2Mem(xyzlist_g, Z, Y, X)
        summ_writer.summ_traj_on_occ('forecast/traj_e', xyzlist_e_mem,
                                     torch.max(occs, dim=1)[0],
                                     already_mem=True, sigma=2)
        summ_writer.summ_traj_on_occ('forecast/traj_g', xyzlist_g_mem,
                                     torch.max(occs, dim=1)[0],
                                     already_mem=True, sigma=2)

        scorelist_here = scorelist[:, self.num_given:, 0]
        sql2 = torch.sum((vel_g - vel_e) ** 2, dim=2)

        ## yes weightmask
        weightmask = torch.arange(0, self.num_need, dtype=torch.float32,
                                  device=torch.device('cuda'))
        weightmask = torch.exp(-weightmask ** (1. / 4))
        # 1.0000, 0.3679, 0.3045, 0.2682, 0.2431, 0.2242, 0.2091, 0.1966, 0.1860,
        # 0.1769, 0.1689, 0.1618, 0.1555, 0.1497, 0.1445, 0.1397, 0.1353
        weightmask = weightmask.reshape(1, self.num_need)
        l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here * weightmask)
        utils_misc.add_loss('forecast/l2_loss', 0, l2_loss, 0, summ_writer)
        # # no weightmask:
        # l2_loss = utils_basic.reduce_masked_mean(sql2, scorelist_here)
        # total_loss = utils_misc.add_loss('forecast/l2_loss', total_loss, l2_loss, hyp.forecast_l2_coeff, summ_writer)

        dist = torch.norm(xyzlist_e - xyzlist_g, dim=2)
        meandist = utils_basic.reduce_masked_mean(dist, scorelist[:, :, 0])
        utils_misc.add_loss('forecast/xyz_dist_0', 0, meandist, 0, summ_writer)

        l2_loss_noexp = utils_basic.reduce_masked_mean(sql2, scorelist_here)
        # utils_misc.add_loss('forecast/vel_dist_noexp', 0, l2_loss, 0, summ_writer)
        total_loss = utils_misc.add_loss('forecast/l2_loss_noexp', total_loss, l2_loss_noexp,
                                         hyp.forecast_l2_coeff, summ_writer)

    return total_loss
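# utils_samp.bilinear_sample3D (used above to read the cost volumes at
# trajectory points) can be sketched with F.grid_sample; x, y, z are voxel
# coordinates, each B x N, and grid_sample wants them normalized to [-1, 1]
# in (x, y, z) order. This is an assumed implementation, not the original.
import torch
import torch.nn.functional as F

def bilinear_sample3D(vox, x, y, z):
    B, C, Z, Y, X = vox.shape
    xn = 2.0 * x / (X - 1) - 1.0
    yn = 2.0 * y / (Y - 1) - 1.0
    zn = 2.0 * z / (Z - 1) - 1.0
    grid = torch.stack([xn, yn, zn], dim=2).view(B, -1, 1, 1, 3)  # B x N x 1 x 1 x 3
    out = F.grid_sample(vox, grid, align_corners=True)            # B x C x N x 1 x 1
    return out.view(B, C, -1)                                     # B x C x N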
def forward(self, feat0, feat1, flow_g=None, mask_g=None, summ_writer=None):
    total_loss = torch.tensor(0.0).cuda()

    B, C, D, H, W = list(feat0.shape)
    utils_basic.assert_same_shape(feat0, feat1)

    # feats = torch.cat([feat0, feat1], dim=0)
    # feats = self.compressor(feats)
    # feats = utils_basic.l2_normalize(feats, dim=1)
    # feat0, feat1 = feats[:B], feats[B:]

    flow_total = torch.zeros([B, 3, D, H, W]).float().cuda()
    feat1_aligned = feat1.clone()

    # summ_writer.summ_feats('flow/feats_aligned_%.2f' % 0.0, [feat0, feat1_aligned])
    feat_diff = torch.mean(utils_basic.l2_on_axis(feat1_aligned - feat0, 1, keepdim=True))
    utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0, summ_writer)

    for sc in self.scales:
        flow = self.generate_flow(feat0, feat1_aligned, sc)

        mask = torch.zeros_like(flow[:, 0:1])
        mask[:, :,
             self.max_disp:-self.max_disp,
             self.max_disp:-self.max_disp,
             self.max_disp:-self.max_disp] = 1.0
        flow = flow * mask

        flow_total = flow_total + flow

        # compositional LK: warp the original thing using the cumulative flow
        feat1_aligned = utils_samp.backwarp_using_3D_flow(feat1, flow_total)
        valid1_region = utils_samp.backwarp_using_3D_flow(
            torch.ones_like(feat1[:, 0:1]), flow_total)
        # summ_writer.summ_feats('flow/feats_aligned_%.2f' % sc, [feat0, feat1_aligned],
        #                        valids=[torch.ones_like(valid1_region), valid1_region])
        feat_diff = utils_basic.reduce_masked_mean(
            utils_basic.l2_on_axis(feat1_aligned - feat0, 1, keepdim=True),
            valid1_region)
        utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff, 0, summ_writer)

    if flow_g is not None:
        # inference is done; now for losses/metrics:
        l1_diff_3chan = self.smoothl1(flow_total, flow_g)
        l1_diff = torch.mean(l1_diff_3chan, dim=1, keepdim=True)
        l2_diff_3chan = self.mse(flow_total, flow_g)
        l2_diff = torch.mean(l2_diff_3chan, dim=1, keepdim=True)

        nonzero_mask = ((torch.sum(torch.abs(flow_g), dim=1, keepdim=True) > 0.01).float()) * mask_g * mask
        yeszero_mask = (1.0 - nonzero_mask) * mask_g * mask
        l1_loss = utils_basic.reduce_masked_mean(l1_diff, mask_g * mask)
        l2_loss = utils_basic.reduce_masked_mean(l2_diff, mask_g * mask)
        l1_loss_nonzero = utils_basic.reduce_masked_mean(l1_diff, nonzero_mask)
        l1_loss_yeszero = utils_basic.reduce_masked_mean(l1_diff, yeszero_mask)
        l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5
        l2_loss_nonzero = utils_basic.reduce_masked_mean(l2_diff, nonzero_mask)
        l2_loss_yeszero = utils_basic.reduce_masked_mean(l2_diff, yeszero_mask)
        l2_loss_balanced = (l2_loss_nonzero + l2_loss_yeszero) * 0.5

        # clip = np.squeeze(torch.max(torch.abs(torch.mean(flow_g[0], dim=0))).detach().cpu().numpy()).item()
        clip = 3.0
        if summ_writer is not None:
            # summ_writer.summ_3D_flow('flow/flow_e_%.2f' % sc, flow_total*mask_g, clip=clip)
            summ_writer.summ_3D_flow('flow/flow_e_%.2f' % sc, flow_total, clip=clip)
            summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc, flow_g, clip=clip)
            summ_writer.summ_oned('flow/mask_%.2f' % sc, mask, bev=True, norm=False)
            summ_writer.summ_feat('flow/flow_e_pca_%.2f' % sc, flow_total, pca=True)

        utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero, 0, summ_writer)
        utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero, 0, summ_writer)
        utils_misc.add_loss('flow/l1_loss_balanced', 0, l1_loss_balanced, 0, summ_writer)
        total_loss = utils_misc.add_loss('flow/l1_loss', total_loss, l1_loss,
                                         hyp.flow_l1_coeff, summ_writer)
        total_loss = utils_misc.add_loss('flow/l2_loss', total_loss, l2_loss,
                                         hyp.flow_l2_coeff, summ_writer)

    total_loss = utils_misc.add_loss('flow/warp', total_loss, feat_diff,
                                     hyp.flow_warp_coeff, summ_writer)

    # smooth loss
    dx, dy, dz = utils_basic.gradient3D(flow_total, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('flow/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('flow/smooth_loss', total_loss, smooth_loss,
                                     hyp.flow_smooth_coeff, summ_writer)

    return total_loss, flow_total
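# utils_samp.backwarp_using_3D_flow (the compositional-LK warp above) can be
# sketched with grid_sample: build the identity sampling grid, add the flow
# (assumed to be B x 3 x Z x Y x X in voxel units, channels in x, y, z order),
# normalize to [-1, 1], and sample. torch.meshgrid(indexing='ij') needs
# torch >= 1.10. This is an assumed implementation, not the original.
import torch
import torch.nn.functional as F

def backwarp_using_3D_flow(vox, flow):
    B, C, Z, Y, X = vox.shape
    zs = torch.arange(Z, dtype=torch.float32, device=vox.device)
    ys = torch.arange(Y, dtype=torch.float32, device=vox.device)
    xs = torch.arange(X, dtype=torch.float32, device=vox.device)
    gz, gy, gx = torch.meshgrid(zs, ys, xs, indexing='ij')
    x_t = gx.unsqueeze(0) + flow[:, 0]  # where to sample in the source volume
    y_t = gy.unsqueeze(0) + flow[:, 1]
    z_t = gz.unsqueeze(0) + flow[:, 2]
    grid = torch.stack([2.0 * x_t / (X - 1) - 1.0,
                        2.0 * y_t / (Y - 1) - 1.0,
                        2.0 * z_t / (Z - 1) - 1.0], dim=4)  # B x Z x Y x X x 3
    return F.grid_sample(vox, grid, align_corners=True)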
def forward(self, pix_T_cam0, cam0_T_cam1, feat_mem1, rgb_g, vox_util,
            valid=None, summ_writer=None, test=False, suffix=''):
    total_loss = torch.tensor(0.0).cuda()

    B, C, H, W = list(rgb_g.shape)
    PH, PW = hyp.PH, hyp.PW
    if (PH < H) or (PW < W):
        sy = float(PH) / float(H)
        sx = float(PW) / float(W)
        pix_T_cam0 = utils_geom.scale_intrinsics(pix_T_cam0, sx, sy)
        if valid is not None:
            valid = F.interpolate(valid, scale_factor=0.5, mode='nearest')
        rgb_g = F.interpolate(rgb_g, scale_factor=0.5, mode='bilinear')

    # feat_prep = self.prep_layer(feat_mem1)
    # feat_proj = utils_vox.apply_pixX_T_memR_to_voxR(
    #     pix_T_cam0, cam0_T_cam1, feat_prep,
    #     hyp.view_depth, PH, PW)
    feat_proj = vox_util.apply_pixX_T_memR_to_voxR(
        pix_T_cam0, cam0_T_cam1, feat_mem1,
        hyp.view_depth, PH, PW)
    # logspace_slices=(hyp.dataset_name=='carla'))

    # def flatten_depth(feat_3d):
    #     B, C, Z, Y, X = list(feat_3d.shape)
    #     return feat_3d.view(B, C*Z, Y, X)
    # feat_pool = self.pool_layer(feat_proj)
    # feat_im = flatten_depth(feat_pool)
    # rgb_e = self.decoder(feat_im)

    feat = self.net(feat_proj)
    rgb = self.rgb_layer(feat)
    emb = self.emb_layer(feat)
    emb = utils_basic.l2_normalize(emb, dim=1)

    # if hyp.do_emb2D:
    #     emb_e = self.emb_layer(feat)
    #     # postproc
    #     emb_e = l2_normalize(emb_e, dim=1)
    # else:
    #     emb_e = None

    if test:
        return None, rgb, None

    # loss_im = torch.mean(F.mse_loss(rgb, rgb_g, reduction='none'), dim=1, keepdim=True)
    loss_im = utils_basic.l1_on_axis(rgb - rgb_g, 1, keepdim=True)
    if valid is not None:
        rgb_loss = utils_basic.reduce_masked_mean(loss_im, valid)
    else:
        rgb_loss = torch.mean(loss_im)
    total_loss = utils_misc.add_loss('view/rgb_l1_loss', total_loss, rgb_loss,
                                     hyp.view_l1_coeff, summ_writer)

    # smooth loss
    dy, dx = utils_basic.gradient2D(rgb, absolute=True)
    smooth_im = torch.mean(dy + dx, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('view/smooth_loss', smooth_im)
    smooth_loss = torch.mean(smooth_im)
    total_loss = utils_misc.add_loss('view/smooth_loss', total_loss, smooth_loss,
                                     hyp.view_smooth_coeff, summ_writer)

    # vis
    if summ_writer is not None:
        summ_writer.summ_oned('view/rgb_loss', loss_im)
        summ_writer.summ_rgbs('view/rgb', [rgb.clamp(-0.5, 0.5), rgb_g])
        summ_writer.summ_rgb('view/rgb_e', rgb.clamp(-0.5, 0.5))
        summ_writer.summ_rgb('view/rgb_g', rgb_g.clamp(-0.5, 0.5))
        summ_writer.summ_feat('view/emb', emb, pca=True)
        if valid is not None:
            summ_writer.summ_rgb('view/rgb_e_valid', valid * rgb.clamp(-0.5, 0.5))
            summ_writer.summ_rgb('view/rgb_g_valid', valid * rgb_g.clamp(-0.5, 0.5))

    return total_loss, rgb, emb
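# utils_geom.scale_intrinsics above rescales the pinhole matrix when the
# render target is smaller than the input. A minimal sketch of the assumed
# behavior for a B x 3 x 3 (or 4 x 4) intrinsics matrix:
import torch

def scale_intrinsics(K, sx, sy):
    K = K.clone()
    K[:, 0, 0] *= sx  # fx
    K[:, 1, 1] *= sy  # fy
    K[:, 0, 2] *= sx  # cx
    K[:, 1, 2] *= sy  # cy
    return K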
def forward(self, boxes_g, scores_g, feat_zyx, summ_writer, mask=None):
    if hyp.deeper_det:
        feat_zyx = self.resnet(feat_zyx)
    total_loss = torch.tensor(0.0).cuda()

    B, C, Z, Y, X = feat_zyx.shape
    _, N, _ = boxes_g.shape  # dim-2 is xc,yc,zc,lx,ly,lz,rx,ry,rz

    pred_dim = self.pred_dim  # 7 in total: 6 deltas, 1 objectness
    feat = feat_zyx.permute(0, 1, 4, 3, 2)
    # get feat in xyz order; now B x C x X x Y x Z

    corners = utils_geom.transform_boxes_to_corners(boxes_g)
    # corners is B x N x 8 x 3, last dim in xyz order
    corners_max = torch.max(corners, dim=2)[0]  # B x N x 3
    corners_min = torch.min(corners, dim=2)[0]
    corners_min_max_g = torch.stack([corners_min, corners_max], dim=3)
    # this is B x N x 3 x 2

    # trim down, to save some time
    N = hyp.K
    boxes_g = boxes_g[:, :N, :6]
    corners_min_max_g = corners_min_max_g[:, :N]
    scores_g = scores_g[:, :N]  # B x N

    # boxes_g is [-0.5~63.5, -0.5~15.5, -0.5~63.5]
    centers_g = boxes_g[:, :, :3]
    # centers_g is B x N x 3
    grid = meshgrid3D_xyz(B, Z, Y, X)[0]
    # just one grid please; this is X x Y x Z x 3
    delta_positions_raw = centers_g.view(B, N, 1, 1, 1, 3) - grid.view(1, 1, X, Y, Z, 3)
    delta_positions = delta_positions_raw / hyp.det_anchor_size

    lengths_g = boxes_g[:, :, 3:6]  # B x N x 3
    delta_lengths = torch.log(lengths_g / hyp.det_anchor_size)
    # avoid -infs turning into nans
    delta_lengths = torch.max(delta_lengths, -1e6 * torch.ones_like(delta_lengths))

    lengths_g = lengths_g.view(B, N, 1, 1, 1, 3).repeat(1, 1, X, Y, Z, 1)
    # B x N x X x Y x Z x 3
    delta_lengths = delta_lengths.view(B, N, 1, 1, 1, 3).repeat(1, 1, X, Y, Z, 1)
    # B x N x X x Y x Z x 3
    valid_mask = scores_g.view(B, N, 1, 1, 1, 1).repeat(1, 1, X, Y, Z, 1)
    # B x N x X x Y x Z x 1

    delta_gt = torch.cat([delta_positions, delta_lengths], -1)
    # B x N x X x Y x Z x 6

    object_dist = torch.max(torch.abs(delta_positions_raw) / (lengths_g * 0.5 + 1e-5), dim=5)[0]
    # B x N x X x Y x Z
    object_dist_mask = (torch.ones_like(object_dist) - binarize(object_dist, 0.5)).unsqueeze(dim=5)
    # B x N x X x Y x Z x 1
    object_dist_mask = object_dist_mask * valid_mask
    object_neg_dist_mask = torch.ones_like(object_dist) - binarize(object_dist, 0.8)
    object_neg_dist_mask = object_neg_dist_mask * valid_mask.squeeze(dim=5)
    # B x N x X x Y x Z

    anchor_deltas_gt = None
    for obj_id in list(range(N)):
        if anchor_deltas_gt is None:
            anchor_deltas_gt = delta_gt[:, obj_id] * object_dist_mask[:, obj_id]
            current_mask = object_dist_mask[:, obj_id]
        else:
            # don't overwrite anchor positions that are already taken
            overlap = current_mask * object_dist_mask[:, obj_id]
            anchor_deltas_gt += (torch.ones_like(overlap) - overlap) * \
                delta_gt[:, obj_id] * object_dist_mask[:, obj_id]
            current_mask = current_mask + object_dist_mask[:, obj_id]
            current_mask = binarize(current_mask, 0.5)

    pos_equal_one = binarize(torch.sum(object_dist_mask, dim=1), 0.5).squeeze(dim=4)
    # B x X x Y x Z
    neg_equal_one = binarize(torch.sum(object_neg_dist_mask, dim=1), 0.5)
    neg_equal_one = torch.ones_like(neg_equal_one) - neg_equal_one
    # B x X x Y x Z

    pos_equal_one_sum = torch.sum(pos_equal_one, [1, 2, 3])  # B
    neg_equal_one_sum = torch.sum(neg_equal_one, [1, 2, 3])
    # set the min to one in case there is no object, to avoid nan
    pos_equal_one_sum_safe = torch.max(pos_equal_one_sum, torch.ones_like(pos_equal_one_sum))  # B
    neg_equal_one_sum_safe = torch.max(neg_equal_one_sum, torch.ones_like(neg_equal_one_sum))  # B

    pred = self.conv1(feat)
    # this is B x 7 x X x Y x Z
    pred = pred.permute(0, 2, 3, 4, 1)  # B x X x Y x Z x 7
    pred_anchor_deltas = pred[..., 1:]  # B x X x Y x Z x 6
    pred_objectness_logits = pred[..., 0]  # B x X x Y x Z
    pred_objectness = torch.sigmoid(pred_objectness_logits)  # B x X x Y x Z

    overall_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        input=pred_objectness_logits,
        target=pos_equal_one,
        reduction='none')
    if mask is not None:
        overall_loss = overall_loss * mask
    cls_pos_loss = utils_basic.reduce_masked_mean(overall_loss, pos_equal_one)
    cls_neg_loss = utils_basic.reduce_masked_mean(overall_loss, neg_equal_one)
    loss_prob = torch.sum(hyp.alpha_pos * cls_pos_loss + hyp.beta_neg * cls_neg_loss)

    pos_mask = pos_equal_one.unsqueeze(dim=4)  # B x X x Y x Z x 1
    loss_l1 = smooth_l1_loss(pos_mask * pred_anchor_deltas,
                             pos_mask * anchor_deltas_gt)
    # B x X x Y x Z x 6
    if mask is not None:
        loss_l1 = loss_l1 * mask.unsqueeze(-1)
    loss_reg = torch.sum(loss_l1 / pos_equal_one_sum_safe.view(-1, 1, 1, 1, 1)) / hyp.B

    total_loss = utils_misc.add_loss('det/detect_prob', total_loss, loss_prob,
                                     hyp.det_prob_coeff, summ_writer)
    total_loss = utils_misc.add_loss('det/detect_reg', total_loss, loss_reg,
                                     hyp.det_reg_coeff, summ_writer)

    # finally, turn the preds into hard boxes, with nms
    (bs_selected_boxes_co,
     bs_selected_scores,
     bs_overlaps) = rpn_proposal_graph(pred_objectness, pred_anchor_deltas,
                                       scores_g, corners_min_max_g,
                                       iou_thresh=0.2)
    # these are lists of length B, each leading with dim "?",
    # since there is a variable number of objects per frame

    N = hyp.K * 2
    tidlist = torch.linspace(1.0, N, N).long().to('cuda')
    tidlist = tidlist.unsqueeze(0).repeat(B, 1)
    padded_boxes_e = torch.zeros(B, N, 9).float().cuda()
    padded_scores_e = torch.zeros(B, N).float().cuda()
    if bs_selected_boxes_co is not None:
        for b in list(range(B)):
            # make the boxes 1 x N x 9 (instead of B x ? x 6)
            padded_boxes0_e = bs_selected_boxes_co[b].unsqueeze(0)
            padded_scores0_e = bs_selected_scores[b].unsqueeze(0)
            # pad out
            padded_boxes0_e = torch.cat([
                padded_boxes0_e,
                torch.zeros([1, N, 6], device=torch.device('cuda'))], dim=1)
            padded_scores0_e = torch.cat([
                padded_scores0_e,
                torch.zeros([1, N], device=torch.device('cuda'))], dim=1)
            # clip to N
            padded_boxes0_e = padded_boxes0_e[:, :N]
            padded_scores0_e = padded_scores0_e[:, :N]
            # append zeroed rotations, making the boxes 9-dof
            padded_boxes0_e = torch.cat([
                padded_boxes0_e,
                torch.zeros([1, N, 3], device=torch.device('cuda'))], dim=2)
            padded_boxes_e[b] = padded_boxes0_e[0]
            padded_scores_e[b] = padded_scores0_e[0]

    return total_loss, padded_boxes_e, padded_scores_e, tidlist, bs_selected_scores, bs_overlaps
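# binarize and smooth_l1_loss above are assumed to have the usual semantics;
# minimal sketches (the real helpers may differ, e.g. in the Huber beta):
import torch

def binarize(x, threshold):
    # 1.0 where x exceeds the threshold, else 0.0
    return (x > threshold).float()

def smooth_l1_loss(pred, target, beta=1.0):
    # elementwise Huber loss, like F.smooth_l1_loss(..., reduction='none')
    diff = torch.abs(pred - target)
    return torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)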
def run_test(self, feed):
    results = dict()
    global_step = feed['global_step']
    total_loss = torch.tensor(0.0).cuda()

    __p = lambda x: utils_basic.pack_seqdim(x, self.B)
    __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

    self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(self.lrt_camX0s)
    self.original_centroid = self.scene_centroid.clone()

    obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
    obj_length = obj_lengths[:, 0]

    for b in list(range(self.B)):
        if self.score_s[b, 0] < 1.0:
            # we need the template to exist
            print('returning early, since score_s[%d,0] = %.1f' % (b, self.score_s[b, 0].cpu().numpy()))
            return total_loss, results, True
        # if torch.sum(self.score_s[b]) < (self.S/2):
        if not (torch.sum(self.score_s[b]) == self.S):
            # the full traj should be valid
            print('returning early, since sum(score_s) = %d, while S = %d' % (
                torch.sum(self.score_s[b]).cpu().numpy(), self.S))
            return total_loss, results, True

    if hyp.do_feat3D:
        feat_memX0_input = torch.cat([
            self.occ_memX0s[:, 0],
            self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
        ], dim=1)
        _, feat_memX0, valid_memX0 = self.featnet3D(feat_memX0_input)
        B, C, Z, Y, X = list(feat_memX0.shape)
        S = self.S

        obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
            self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)

        # only take the occupied voxels
        occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y, X)
        # obj_mask_memX0 = obj_mask_memX0s[:,0] * occ_memX0
        obj_mask_memX0 = obj_mask_memX0s[:, 0]

        # discard the known freespace
        _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs[:, 0:1],
            self.xyz_camXs[:, 0:1],
            Z, Y, X, agg=True)
        free_memX0 = free_memX0_.squeeze(1)
        obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)

        for b in list(range(self.B)):
            if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                print('returning early, since there are not enough valid object points')
                return total_loss, results, True

        feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1)
        # this is B x C x huge
        feat0_vec = feat0_vec.permute(0, 2, 1)
        # this is B x huge x C

        obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
        occ_mask0_vec = occ_memX0.reshape(B, -1).round()
        free_mask0_vec = free_memX0.reshape(B, -1).round()
        # these are B x huge

        orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
        # this is B x huge x 3

        obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
        obj_length = obj_lengths[:, 0]
        cam0_T_obj = cams_T_obj0[:, 0]
        # this is B x S x 4 x 4

        mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
        cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)

        lrt_camIs_g = self.lrt_camX0s.clone()
        lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
        # we will fill this up

        ious = torch.zeros([B, S]).float().cuda()
        point_counts = np.zeros([B, S])
        inb_counts = np.zeros([B, S])

        feat_vis = []
        occ_vis = []

        for s in range(self.S):
            if not (s == 0):
                # remake the vox util and all the mem data
                self.scene_centroid = utils_geom.get_clist_from_lrtlist(lrt_camIs_e[:, s - 1:s])[:, 0]
                delta = self.scene_centroid - self.original_centroid
                self.vox_util = vox_util.Vox_util(
                    self.Z, self.Y, self.X, self.set_name,
                    scene_centroid=self.scene_centroid, assert_cube=True)
                self.occ_memXs = __u(self.vox_util.voxelize_xyz(
                    __p(self.xyz_camXs), self.Z, self.Y, self.X))
                self.occ_memX0s = __u(self.vox_util.voxelize_xyz(
                    __p(self.xyz_camX0s), self.Z, self.Y, self.X))
                self.unp_memXs = __u(self.vox_util.unproject_rgb_to_mem(
                    __p(self.rgb_camXs), self.Z, self.Y, self.X, __p(self.pix_T_cams)))
                self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
                    self.camX0s_T_camXs, self.unp_memXs)
                self.summ_writer.summ_occ('track/reloc_occ_%d' % s, self.occ_memX0s[:, s])
            else:
                self.summ_writer.summ_occ('track/init_occ_%d' % s, self.occ_memX0s[:, s])
                delta = torch.zeros([B, 3]).float().cuda()
            # print('scene centroid:', self.scene_centroid.detach().cpu().numpy())
            occ_vis.append(self.summ_writer.summ_occ('', self.occ_memX0s[:, s], only_return=True))

            # NOTE: self.X here (rather than a quarter-resolution X4) matches
            # the original, though it looks suspicious next to Z4 and Y4
            inb = self.vox_util.get_inbounds(self.xyz_camX0s[:, s], self.Z4, self.Y4, self.X, already_mem=False)
            num_inb = torch.sum(inb.float(), dim=1)
            inb_counts[:, s] = num_inb.cpu().numpy()

            feat_memI_input = torch.cat([
                self.occ_memX0s[:, s],
                self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
            ], dim=1)
            _, feat_memI, valid_memI = self.featnet3D(feat_memI_input)

            self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s, feat_memI_input, pca=True)
            self.summ_writer.summ_feat('3D_feats/feat_%d' % s, feat_memI, pca=True)
            feat_vis.append(self.summ_writer.summ_feat('', feat_memI, pca=True, only_return=True))

            # collect freespace here, to discard bad matches
            _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, s:s + 1],
                self.xyz_camXs[:, s:s + 1],
                Z, Y, X, agg=True)
            free_memI = free_memI_.squeeze(1)

            feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1)
            # this is B x C x huge
            feat_vec = feat_vec.permute(0, 2, 1)
            # this is B x huge x C

            memI_T_mem0 = utils_geom.eye_4x4(B)
            # we will fill this up

            # # put these on cpu, to save mem
            # feat0_vec = feat0_vec.detach().cpu()
            # feat_vec = feat_vec.detach().cpu()

            # to simplify the impl, we will iterate over the batch dim
            for b in list(range(B)):
                feat_vec_b = feat_vec[b]
                feat0_vec_b = feat0_vec[b]
                obj_mask0_vec_b = obj_mask0_vec[b]
                occ_mask0_vec_b = occ_mask0_vec[b]
                free_mask0_vec_b = free_mask0_vec[b]
                orig_xyz_b = orig_xyz[b]
                # these are huge x C

                careful = False
                if careful:
                    # start with occ points, since these are definitely observed
                    obj_inds_b = torch.where((occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # these are med x C

                    # also take random non-free non-occ points in the mask
                    ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (1.0 - free_mask0_vec_b)
                    alt_inds_b = torch.where(ok_mask > 0)
                    alt_vec_b = feat0_vec_b[alt_inds_b]
                    alt_xyz0 = orig_xyz_b[alt_inds_b]
                    # these are med x C

                    # issues arise when "med" is too large
                    num = len(alt_xyz0)
                    max_pts = 2000
                    if num > max_pts:
                        perm = np.random.permutation(num)
                        alt_vec_b = alt_vec_b[perm[:max_pts]]
                        alt_xyz0 = alt_xyz0[perm[:max_pts]]

                    obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                    xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                    if s == 0:
                        print('have %d pts in total' % (len(xyz0)))
                else:
                    # take any points within the mask
                    obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                    obj_vec_b = feat0_vec_b[obj_inds_b]
                    xyz0 = orig_xyz_b[obj_inds_b]
                    # these are med x C

                    # issues arise when "med" is too large; trim down to max_pts
                    num = len(xyz0)
                    max_pts = 2000
                    if num > max_pts:
                        print('have %d pts; taking a random set of %d pts inside' % (num, max_pts))
                        perm = np.random.permutation(num)
                        obj_vec_b = obj_vec_b[perm[:max_pts]]
                        xyz0 = xyz0[perm[:max_pts]]

                obj_vec_b = obj_vec_b.permute(1, 0)
                # this is C x med

                corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                # this is huge x med

                heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                # this is med x 1 x Z x Y x X

                # # for numerical stability, we sub the max, and mult by the resolution
                # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                # heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                # heat_b = heat_b - heat_b_max
                # heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                # make the min zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_min
                # zero out the freespace
                heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                # make the max zero
                heat_b_ = heat_b.reshape(-1, Z * Y * X)
                heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                heat_b = heat_b - heat_b_max
                # scale up, for numerical stability
                heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                # this is med x 3
                xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y, X)
                xyzI_cam += delta
                xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)

                # record #points, since ransac depends on this
                point_counts[b, s] = len(xyz0)
            # done stepping through batch

            mem0_T_memI = utils_geom.safe_inverse(memI_T_mem0)
            cam0_T_camI = utils_basic.matmul3(cam_T_mem, mem0_T_memI, mem_T_cam)

            # eval
            camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0, mem_T_cam, cam0_T_obj)
            # this is B x 4 x 4
            lrt_camIs_e[:, s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:, s:s + 1]).squeeze(1)
        results['ious'] = ious
        # if ious[0,-1] > 0.5:
        #     print('returning early, since acc is too high')
        #     return total_loss, results, True

        self.summ_writer.summ_rgbs('track/feats', feat_vis)
        self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)

        for s in range(self.S):
            self.summ_writer.summ_scalar('track/mean_iou_%02d' % s, torch.mean(ious[:, s]).cpu().item())
        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())
        self.summ_writer.summ_scalar('track/point_counts', np.mean(point_counts))
        self.summ_writer.summ_scalar('track/inb_counts', np.mean(inb_counts))

        lrt_camX0s_e = lrt_camIs_e.clone()
        lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(self.camXs_T_camX0s, lrt_camX0s_e)

        if self.include_vis:
            visX_e = []
            for s in list(range(self.S)):
                visX_e.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_e' % s, self.rgb_camXs[:, s],
                    lrt_camXs_e[:, s:s + 1],
                    self.score_s[:, s:s + 1],
                    self.tid_s[:, s:s + 1],
                    self.pix_T_cams[:, 0], only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
            visX_g = []
            for s in list(range(self.S)):
                visX_g.append(self.summ_writer.summ_lrtlist(
                    'track/box_camX%d_g' % s, self.rgb_camXs[:, s],
                    self.lrt_camXs[:, s:s + 1],
                    self.score_s[:, s:s + 1],
                    self.tid_s[:, s:s + 1],
                    self.pix_T_cams[:, 0], only_return=True))
            self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)

        dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
        # this is B x S
        mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
        median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
        # this is []
        self.summ_writer.summ_scalar('track/centroid_dist_mean', mean_dist.cpu().item())
        self.summ_writer.summ_scalar('track/centroid_dist_median', median_dist.cpu().item())

        if True:  # previously gated by self.include_vis
            self.summ_writer.summ_traj_on_occ('track/traj_e', obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0], self.vox_util,
                                              already_mem=False, sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g', self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0], self.vox_util,
                                              already_mem=False, sigma=2)

        total_loss += mean_dist
        # we won't backprop, but it's nice to plot and print this anyway
    else:
        ious = torch.zeros([self.B, self.S]).float().cuda()
        for s in list(range(self.S)):
            ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                self.lrt_camX0s[:, 0:1], self.lrt_camX0s[:, s:s + 1]).squeeze(1)
        results['ious'] = ious
        for s in range(self.S):
            self.summ_writer.summ_scalar('track/mean_iou_%02d' % s, torch.mean(ious[:, s]).cpu().item())
        self.summ_writer.summ_scalar('track/mean_iou', torch.mean(ious).cpu().item())

        lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
        obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
        self.summ_writer.summ_traj_on_occ('track/traj_e', obj_clist_camX0_e,
                                          self.occ_memX0s[:, 0], self.vox_util,
                                          already_mem=False, sigma=2)
        self.summ_writer.summ_traj_on_occ('track/traj_g', self.obj_clist_camX0,
                                          self.occ_memX0s[:, 0], self.vox_util,
                                          already_mem=False, sigma=2)

    self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
    return total_loss, results, False
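# The soft argmax above (utils_basic.argmax3D with hard=False) can be sketched
# as a softmax-weighted expectation over the voxel grid; this is also why the
# heatmap is shifted and scaled beforehand (it acts as a softmax temperature).
# Assumed semantics: input B x 1 x Z x Y x X, output B x 3 in (x, y, z) order.
import torch
import torch.nn.functional as F

def argmax3D(heat, hard=False, stack=True):
    B, _, Z, Y, X = heat.shape
    if hard:
        idx = torch.argmax(heat.view(B, -1), dim=1)
        z = (idx // (Y * X)).float()
        y = ((idx % (Y * X)) // X).float()
        x = (idx % X).float()
    else:
        prob = F.softmax(heat.view(B, -1), dim=1).view(B, Z, Y, X)
        zs = torch.arange(Z, dtype=torch.float32, device=heat.device)
        ys = torch.arange(Y, dtype=torch.float32, device=heat.device)
        xs = torch.arange(X, dtype=torch.float32, device=heat.device)
        z = torch.sum(prob.sum(dim=(2, 3)) * zs, dim=1)  # expectation along Z
        y = torch.sum(prob.sum(dim=(1, 3)) * ys, dim=1)  # expectation along Y
        x = torch.sum(prob.sum(dim=(1, 2)) * xs, dim=1)  # expectation along X
    return torch.stack([x, y, z], dim=1) if stack else (x, y, z)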
def run_train(self, feed):
    results = dict()
    global_step = feed['global_step']
    total_loss = torch.tensor(0.0).cuda()

    __p = lambda x: utils_basic.pack_seqdim(x, self.B)
    __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

    if hyp.do_feat3D:
        feat_memX0s_input = torch.cat([
            self.occ_memX0s,
            self.unp_memX0s * self.occ_memX0s,
        ], dim=2)
        feat3D_loss, feat_memX0s_, valid_memX0s_ = self.featnet3D(
            __p(feat_memX0s_input[:, 1:]),
            self.summ_writer,
        )
        feat_memX0s = __u(feat_memX0s_)
        valid_memX0s = __u(valid_memX0s_)
        total_loss += feat3D_loss

        feat_memX0 = utils_basic.reduce_masked_mean(
            feat_memX0s,
            valid_memX0s.repeat(1, 1, hyp.feat3D_dim, 1, 1, 1),
            dim=1)
        valid_memX0 = torch.sum(valid_memX0s, dim=1).clamp(0, 1)
        self.summ_writer.summ_feat('3D_feats/feat_memX0', feat_memX0, valid=valid_memX0, pca=True)
        self.summ_writer.summ_feat('3D_feats/valid_memX0', valid_memX0, pca=False)

        if hyp.do_emb3D:
            _, altfeat_memX0, altvalid_memX0 = self.featnet3D_slow(feat_memX0s_input[:, 0])
            self.summ_writer.summ_feat('3D_feats/altfeat_memX0', altfeat_memX0, valid=altvalid_memX0, pca=True)
            self.summ_writer.summ_feat('3D_feats/altvalid_memX0', altvalid_memX0, pca=False)

    if hyp.do_emb3D:
        if hyp.do_feat3D:
            _, _, Z_, Y_, X_ = list(feat_memX0.shape)
        else:
            Z_, Y_, X_ = self.Z2, self.Y2, self.X2
            # Z_, Y_, X_ = self.Z, self.Y, self.X
        occ_memX0s, free_memX0s, _, _ = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs,
            self.xyz_camXs,
            Z_, Y_, X_,
            agg=False)

        not_ok = torch.zeros_like(occ_memX0s[:, 0])
        # it's not ok for a voxel to be marked occ only once
        not_ok += (torch.sum(occ_memX0s, dim=1) == 1.0).float()
        # it's not ok for a voxel to be marked occ AND free
        occ_agg = torch.sum(occ_memX0s, dim=1).clamp(0, 1)
        free_agg = torch.sum(free_memX0s, dim=1).clamp(0, 1)
        have_either = (occ_agg + free_agg).clamp(0, 1)
        have_both = occ_agg * free_agg
        not_ok += have_either * have_both
        # it's not ok for a voxel to be totally unobserved
        not_ok += (have_either == 0.0).float()
        not_ok = not_ok.clamp(0, 1)

        self.summ_writer.summ_occ('rely/not_ok', not_ok)
        self.summ_writer.summ_occ('rely/not_ok_occ',
                                  not_ok * torch.max(self.occ_memX0s_half, dim=1)[0])
        self.summ_writer.summ_occ('rely/ok_occ',
                                  (1.0 - not_ok) * torch.max(self.occ_memX0s_half, dim=1)[0])
        self.summ_writer.summ_occ('rely/aggressive_occ',
                                  torch.max(self.occ_memX0s_half, dim=1)[0])

        be_safe = False
        if hyp.do_feat3D and be_safe:
            # update the valid masks
            valid_memX0 = valid_memX0 * (1.0 - not_ok)
            altvalid_memX0 = altvalid_memX0 * (1.0 - not_ok)

    if hyp.do_occ:
        _, _, Z_, Y_, X_ = list(feat_memX0.shape)
        occ_memX0_sup, free_memX0_sup, _, free_memX0s = self.vox_util.prep_occs_supervision(
            self.camX0s_T_camXs,
            self.xyz_camXs,
            Z_, Y_, X_,
            agg=True)
        self.summ_writer.summ_occ('occ_sup/occ_sup', occ_memX0_sup)
        self.summ_writer.summ_occ('occ_sup/free_sup', free_memX0_sup)
        self.summ_writer.summ_occs('occ_sup/freeX0s_sup', torch.unbind(free_memX0s, dim=1))
        self.summ_writer.summ_occs('occ_sup/occX0s_sup', torch.unbind(self.occ_memX0s_half, dim=1))
        occ_loss, occ_memX0_pred = self.occnet(
            altfeat_memX0,
            occ_memX0_sup,
            free_memX0_sup,
            altvalid_memX0,
            self.summ_writer)
        total_loss += occ_loss

    if hyp.do_emb3D:
        # compute 3D ML
        emb_loss_3D = self.embnet3D(
            feat_memX0,
            altfeat_memX0,
            valid_memX0.round(),
            altvalid_memX0.round(),
            self.summ_writer)
        total_loss += emb_loss_3D

    self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
    return total_loss, results, False
def get_gt_flow(obj_lrtlist_camRs,
                obj_scorelist,
                camRs_T_camXs,
                Z, Y, X,
                K=2,
                mod='',
                vis=True,
                summ_writer=None):
    # this constructs the flow field according to the given
    # box trajectories (obj_lrtlist_camRs) (collected from a moving camR)
    # and egomotion (encoded in camRs_T_camXs);
    # the box trajectories alone do not take egomotion into account,
    # so we first generate the flow for all the objects,
    # then in the background, put the ego flow

    N, B, S, D = list(obj_lrtlist_camRs.shape)
    assert S == 2  # as a flow util, this expects S=2

    flows = []
    masks = []
    for k in list(range(K)):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z, Y, X,
            coeff=1.0)
        # this is B x 1(N) x 1(C) x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X

        camR_T_cam0 = camRs_T_camXs[:, 0]
        camR_T_cam1 = camRs_T_camXs[:, 1]
        cam0_T_camR = utils_geom.safe_inverse(camR_T_cam0)
        cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)

        # camR0_T_camR1 = camR0_T_camRs[:,1]
        # camR1_T_camR0 = utils_geom.safe_inverse(camR0_T_camR1)
        # obj_masklistA1 = utils_vox.apply_4x4_to_vox(camR1_T_camR0, obj_masklistA0)

        if vis and (summ_writer is not None):
            summ_writer.summ_oned('flow/obj%d_mask0_%s' % (k, mod), torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])
        # this is B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)
        obj1_T_ref = utils_geom.safe_inverse(ref_T_obj1)
        # these are B x 4 x 4

        mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
        ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0, camR_T_cam0)
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

        xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
        xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
        xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
        xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)

        # only use these displaced points within the obj mask
        obj_mask0 = obj_mask0.view(B, Z, Y, X, 1)
        cond = (obj_mask0 > 0.0).float()
        xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0

        flow = xyz_mem1 - xyz_mem0
        flow = flow.permute(0, 4, 1, 2, 3)
        obj_mask0 = obj_mask0.permute(0, 4, 1, 2, 3)

        # if vis and k==0:
        if vis:
            summ_writer.summ_3D_flow('flow/gt_%d_%s' % (k, mod), flow, clip=4.0)

        masks.append(obj_mask0)
        flows.append(flow)

    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam0_T_camR = utils_geom.safe_inverse(camR_T_cam0)
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)

    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

    cam1_T_cam0 = utils_basic.matmul2(cam1_T_camR, camR_T_cam0)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    flow = flow.permute(0, 4, 1, 2, 3)
    bkg_flow = flow

    # allow zero motion in the bkg
    any_mask = torch.max(torch.stack(masks, dim=0), dim=0)[0]
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, dim=0)
    masks = torch.stack(masks, dim=0)
    masks = masks.repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)

    if vis:
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)

    # flow is shaped B x 3 x D x H x W
    return flow
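# utils_basic.gridcloud3D above produces the dense cloud of voxel coordinates
# whose displacement under mem1_T_mem0 defines the flow. A minimal sketch of
# the assumed layout (B x Z*Y*X x 3, last dim in x, y, z order):
import torch

def gridcloud3D(B, Z, Y, X, device='cpu'):
    zs = torch.arange(Z, dtype=torch.float32, device=device)
    ys = torch.arange(Y, dtype=torch.float32, device=device)
    xs = torch.arange(X, dtype=torch.float32, device=device)
    gz, gy, gx = torch.meshgrid(zs, ys, xs, indexing='ij')
    xyz = torch.stack([gx, gy, gz], dim=3).reshape(1, -1, 3)
    return xyz.repeat(B, 1, 1)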
def forward(self, feat0, feat1, flow_g, mask_g, is_synth, summ_writer): total_loss = torch.tensor(0.0).cuda() B, C, D, H, W = list(feat0.shape) utils_basic.assert_same_shape(feat0, feat1) # feats = torch.cat([feat0, feat1], dim=0) # feats = self.compressor(feats) # feats = utils_basic.l2_normalize(feats, dim=1) # feat0, feat1 = feats[:B], feats[B:] flow_total_forw = torch.zeros_like(flow_g) flow_total_back = torch.zeros_like(flow_g) feat0_aligned = feat0.clone() feat1_aligned = feat1.clone() # cycle_losses = [] # l1_losses = [] # torch does not like it when we overwrite, so let's pre-allocate l1_loss_cumu = torch.tensor(0.0).cuda() l2_loss_cumu = torch.tensor(0.0).cuda() warp_loss_cumu = torch.tensor(0.0).cuda() summ_writer.summ_feats('flow/feats_aligned_%.2f' % 0.0, [feat0, feat1_aligned]) feat_diff = torch.mean( utils_basic.l2_on_axis((feat1_aligned - feat0), 1, keepdim=True)) utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0, summ_writer) # print('feat0, feat1_aligned, mask') # print(feat0.shape) # print(feat1_aligned.shape) # print(mask_g.shape) hinge_loss_vox = utils_basic.l2_on_axis((feat1_aligned - feat0), 1, keepdim=True) hinge_loss_vox = F.relu(0.2 - hinge_loss_vox) summ_writer.summ_oned('flow/hinge_loss', torch.mean(hinge_loss_vox, dim=3)) hinge_mask_vox = (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) > 1.0).float() summ_writer.summ_oned('flow/hinge_mask', torch.mean(hinge_mask_vox, dim=3), norm=False) # hinge_loss = torch.mean(hinge_loss_vox) hinge_loss = utils_basic.reduce_masked_mean(hinge_loss_vox, hinge_mask_vox) total_loss = utils_misc.add_loss('flow/hinge', total_loss, hinge_loss, hyp.flow_hinge_coeff, summ_writer) for sc in self.scales: # flow_forw, new_feat0, new_feat1 = self.generate_flow(feat0, feat1_aligned, sc) # flow_back, new_feat1, new_feat1 = self.generate_flow(feat1, feat0_aligned, sc) # flow_forw, heat = self.generate_flow(feat0, feat1_aligned, sc) flow_forw = self.generate_flow(feat0, feat1_aligned, sc) flow_back = self.generate_flow(feat1, feat0_aligned, sc) flow_total_forw = flow_total_forw + flow_forw flow_total_back = flow_total_back + flow_back # compositional LK: warp the original thing using the cumulative flow feat1_aligned = utils_samp.backwarp_using_3D_flow( feat1, flow_total_forw) feat0_aligned = utils_samp.backwarp_using_3D_flow( feat0, flow_total_back) valid1_region = utils_samp.backwarp_using_3D_flow( torch.ones_like(feat1[:, 0:1]), flow_total_forw) valid0_region = utils_samp.backwarp_using_3D_flow( torch.ones_like(feat0[:, 0:1]), flow_total_forw) summ_writer.summ_feats('flow/feats_aligned_%.2f' % sc, [feat0, feat1_aligned], valids=[valid0_region, valid1_region]) # summ_writer.summ_oned('flow/mean_heat_%.2f' % sc, torch.mean(heat, dim=3)) # feat_diff = torch.mean(utils_basic.l2_on_axis((feat1_aligned-feat0), 1, keepdim=True)) feat_diff = utils_basic.reduce_masked_mean( utils_basic.l2_on_axis((feat1_aligned - feat0), 1, keepdim=True), valid1_region * valid0_region) utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff, 0, summ_writer) if sc == 1.0: warp_loss_cumu = warp_loss_cumu + feat_diff * sc l1_diff_3chan = self.smoothl1(flow_total_forw, flow_g) l1_diff = torch.mean(l1_diff_3chan, dim=1, keepdim=True) l2_diff_3chan = self.mse(flow_total_forw, flow_g) l2_diff = torch.mean(l2_diff_3chan, dim=1, keepdim=True) nonzero_mask = ( (torch.sum(torch.abs(flow_g), axis=1, keepdim=True) > 0.01).float()) * mask_g yeszero_mask = (1.0 - nonzero_mask) * mask_g l1_loss_nonzero = utils_basic.reduce_masked_mean( l1_diff, 
nonzero_mask) l1_loss_yeszero = utils_basic.reduce_masked_mean( l1_diff, yeszero_mask) l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5 l2_loss_nonzero = utils_basic.reduce_masked_mean( l2_diff, nonzero_mask) l2_loss_yeszero = utils_basic.reduce_masked_mean( l2_diff, yeszero_mask) l2_loss_balanced = (l2_loss_nonzero + l2_loss_yeszero) * 0.5 # l1_loss_cumu = l1_loss_cumu + l1_loss_balanced*sc l1_loss_cumu = l1_loss_cumu + l1_loss_balanced * sc l2_loss_cumu = l2_loss_cumu + l2_loss_balanced * sc # warp flow flow_back_aligned_to_forw = utils_samp.backwarp_using_3D_flow( flow_total_back, flow_total_forw.detach()) flow_forw_aligned_to_back = utils_samp.backwarp_using_3D_flow( flow_total_forw, flow_total_back.detach()) cancelled_flow_forw = flow_total_forw + flow_back_aligned_to_forw cancelled_flow_back = flow_total_back + flow_forw_aligned_to_back cycle_forw = self.smoothl1_mean( cancelled_flow_forw, torch.zeros_like(cancelled_flow_forw)) cycle_back = self.smoothl1_mean( cancelled_flow_back, torch.zeros_like(cancelled_flow_back)) cycle_loss = cycle_forw + cycle_back total_loss = utils_misc.add_loss('flow/cycle_loss', total_loss, cycle_loss, hyp.flow_cycle_coeff, summ_writer) summ_writer.summ_3D_flow('flow/flow_e_forw_%.2f' % sc, flow_total_forw * mask_g, clip=0.0) summ_writer.summ_3D_flow('flow/flow_e_back_%.2f' % sc, flow_total_back, clip=0.0) summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc, flow_g, clip=0.0) utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero, 0, summ_writer) utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero, 0, summ_writer) utils_misc.add_loss('flow/l1_loss_balanced', 0, l1_loss_balanced, 0, summ_writer) utils_misc.add_loss('flow/l2_loss_balanced', 0, l2_loss_balanced, 0, summ_writer) # total_loss = utils_misc.add_loss('flow/l1_loss_balanced', total_loss, l1_loss_balanced, hyp.flow_l1_coeff, summ_writer) # total_loss = utils_misc.add_loss('flow/l1_loss_balanced', total_loss, l1_loss_balanced, hyp.flow_l1_coeff, summ_writer) # total_loss = utils_misc.add_loss('flow/l1_loss', total_loss, l1_loss, hyp.flow_l1_coeff*(sc==1.0), summ_writer) if is_synth: total_loss = utils_misc.add_loss('flow/synth_l1_cumu', total_loss, l1_loss_cumu, hyp.flow_synth_l1_coeff, summ_writer) total_loss = utils_misc.add_loss('flow/synth_l2_cumu', total_loss, l2_loss_cumu, hyp.flow_synth_l2_coeff, summ_writer) else: total_loss = utils_misc.add_loss('flow/l1_cumu', total_loss, l1_loss_cumu, hyp.flow_l1_coeff, summ_writer) total_loss = utils_misc.add_loss('flow/l2_cumu', total_loss, l2_loss_cumu, hyp.flow_l2_coeff, summ_writer) # total_loss = utils_misc.add_loss('flow/warp', total_loss, feat_diff, hyp.flow_warp_coeff, summ_writer) total_loss = utils_misc.add_loss('flow/warp_cumu', total_loss, warp_loss_cumu, hyp.flow_warp_coeff, summ_writer) # feat1_aligned = utils_samp.backwarp_using_3D_flow(feat1, flow_g) # valid_region = utils_samp.backwarp_using_3D_flow(torch.ones_like(feat1[:,0:1]), flow_g) # summ_writer.summ_feats('flow/feats_aligned_g', [feat0, feat1_aligned], # valids=[valid_region, valid_region]) # feat_diff = utils_basic.reduce_masked_mean(utils_basic.l2_on_axis((feat1_aligned-feat0), 1, keepdim=True), valid_region) # total_loss = utils_misc.add_loss('flow/warp_g', total_loss, feat_diff, hyp.flow_warp_g_coeff, summ_writer) # # hinge_loss_vox = F.relu(0.2 - hinge_loss_vox) # total_loss = utils_misc.add_loss('flow/cycle_loss', total_loss, torch.sum(torch.stack(cycle_losses)), hyp.flow_cycle_coeff, summ_writer) # total_loss = utils_misc.add_loss('flow/l1_loss', 
    # smooth loss: penalize spatial gradients of both flow fields so the
    # predicted motion varies smoothly across the volume
    dx, dy, dz = utils_basic.gradient3D(flow_total_forw, absolute=True)
    smooth_vox_forw = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    dx, dy, dz = utils_basic.gradient3D(flow_total_back, absolute=True)
    smooth_vox_back = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    summ_writer.summ_oned('flow/smooth_loss_forw',
                          torch.mean(smooth_vox_forw, dim=3))
    smooth_loss = torch.mean((smooth_vox_forw + smooth_vox_back) * 0.5)
    total_loss = utils_misc.add_loss(
        'flow/smooth_loss', total_loss, smooth_loss,
        hyp.flow_smooth_coeff, summ_writer)

    # flow_e = F.sigmoid(flow_e_)
    # flow_e_binary = torch.round(flow_e)
    # # collect some accuracy stats
    # flow_match = flow_g*torch.eq(flow_e_binary, flow_g).float()
    # free_match = free_g*torch.eq(1.0-flow_e_binary, free_g).float()
    # either_match = torch.clamp(flow_match+free_match, 0.0, 1.0)
    # either_have = torch.clamp(flow_g+free_g, 0.0, 1.0)
    # acc_flow = reduce_masked_mean(flow_match, flow_g*valid)
    # acc_free = reduce_masked_mean(free_match, free_g*valid)
    # acc_total = reduce_masked_mean(either_match, either_have*valid)
    # summ_writer.summ_scalar('flow/acc_flow', acc_flow.cpu().item())
    # summ_writer.summ_scalar('flow/acc_free', acc_free.cpu().item())
    # summ_writer.summ_scalar('flow/acc_total', acc_total.cpu().item())

    # # vis
    # summ_writer.summ_flow('flow/flow_g', flow_g)
    # summ_writer.summ_flow('flow/free_g', free_g)
    # summ_writer.summ_flow('flow/flow_e', flow_e)
    # summ_writer.summ_flow('flow/valid', valid)

    # prob_loss = self.compute_loss(flow_e_, flow_g, free_g, valid, summ_writer)
    # total_loss = utils_misc.add_loss('flow/prob_loss', total_loss, prob_loss, hyp.flow_coeff, summ_writer)
    # return total_loss, flow_e

    return total_loss, flow_total_forw
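# ----------------------------------------------------------------------
# Hedged sketch: utils_basic.gradient3D (used in the smoothness terms
# above and below) is assumed to return per-axis finite differences of a
# B x C x D x H x W tensor, optionally as absolute values, zero-padded so
# each output matches the input shape. A minimal stand-alone version
# under that assumption; the real helper may differ in padding and sign
# conventions. Relies on the file's existing `torch` / `F` imports.
# ----------------------------------------------------------------------
def gradient3D_sketch(x, absolute=False):
    # x is B x C x D x H x W; returns (dx, dy, dz), each same shape as x
    dz = x[:, :, 1:] - x[:, :, :-1]              # difference along D
    dy = x[:, :, :, 1:] - x[:, :, :, :-1]        # difference along H
    dx = x[:, :, :, :, 1:] - x[:, :, :, :, :-1]  # difference along W
    # F.pad's order for 5D inputs is (W_left, W_right, H_top, H_bottom,
    # D_front, D_back); pad one slab of zeros on the trailing side
    dz = F.pad(dz, (0, 0, 0, 0, 0, 1))
    dy = F.pad(dy, (0, 0, 0, 1, 0, 0))
    dx = F.pad(dx, (0, 1, 0, 0, 0, 0))
    if absolute:
        dx, dy, dz = torch.abs(dx), torch.abs(dy), torch.abs(dz)
    return dx, dy, dz

# usage matching the smoothness terms above:
# dx, dy, dz = gradient3D_sketch(flow_total_forw, absolute=True)
# smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)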
def forward(self, feat, obj_lrtlist_cams, obj_scorelist_s, summ_writer, suffix=''):
    total_loss = torch.tensor(0.0).cuda()
    B, C, Z, Y, X = list(feat.shape)
    K, B2, S, D = list(obj_lrtlist_cams.shape)
    assert (B == B2)
    # obj_scorelist_s is K x B x S

    # __p = lambda x: utils_basic.pack_seqdim(x, B)
    # __u = lambda x: utils_basic.unpack_seqdim(x, B)

    # obj_lrtlist_cams is K x B x S x 19; extract each object's centroid
    # trajectory in camera coords, then map it into memory coords
    obj_lrtlist_cams_ = obj_lrtlist_cams.reshape(K * B, S, 19)
    obj_clist_cam_ = utils_geom.get_clist_from_lrtlist(obj_lrtlist_cams_)
    obj_clist_cam = obj_clist_cam_.reshape(K, B, S, 1, 3)
    # obj_clist_cam is K x B x S x 1 x 3
    obj_clist_cam = obj_clist_cam.squeeze(3)
    # obj_clist_cam is K x B x S x 3
    clist_cam = obj_clist_cam.reshape(K * B, S, 3)
    clist_mem = utils_vox.Ref2Mem(clist_cam, Z, Y, X)
    # this is K*B x S x 3
    clist_mem = clist_mem.reshape(K, B, S, 3)

    energy_vol = self.conv3d(feat)
    # energy_vol is B x 1 x Z x Y x X
    summ_writer.summ_oned('pri/energy_vol', torch.mean(energy_vol, dim=3))
    summ_writer.summ_histogram('pri/energy_vol_hist', energy_vol)

    # for k in range(K):
    # let's start with the first object
    # loglike_per_traj = self.get_traj_loglike(clist_mem[0], energy_vol)
    # # this is B
    # ce_loss = -1.0*torch.mean(loglike_per_traj)
    # # this is []

    loglike_per_traj = self.get_trajs_loglike(
        clist_mem, obj_scorelist_s, energy_vol)
    # this is B x K
    # an object counts as valid if it has a nonzero score anywhere in the seq
    valid = torch.max(obj_scorelist_s.permute(1, 0, 2), dim=2)[0]
    ce_loss = -1.0 * utils_basic.reduce_masked_mean(
        loglike_per_traj, valid)
    # this is a scalar
    total_loss = utils_misc.add_loss(
        'pri/ce_loss', total_loss, ce_loss,
        hyp.pri_ce_coeff, summ_writer)

    # L1 regularization pulls the energies toward zero
    reg_loss = torch.sum(torch.abs(energy_vol))
    total_loss = utils_misc.add_loss(
        'pri/reg_loss', total_loss, reg_loss,
        hyp.pri_reg_coeff, summ_writer)

    # smooth loss
    dx, dy, dz = utils_basic.gradient3D(energy_vol, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    summ_writer.summ_oned('pri/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss(
        'pri/smooth_loss', total_loss, smooth_loss,
        hyp.pri_smooth_coeff, summ_writer)

    # pri_e = F.sigmoid(energy_vol)
    # energy_volbinary = torch.round(pri_e)
    # # collect some accuracy stats
    # pri_match = pri_g*torch.eq(energy_volbinary, pri_g).float()
    # free_match = free_g*torch.eq(1.0-energy_volbinary, free_g).float()
    # either_match = torch.clamp(pri_match+free_match, 0.0, 1.0)
    # either_have = torch.clamp(pri_g+free_g, 0.0, 1.0)
    # acc_pri = reduce_masked_mean(pri_match, pri_g*valid)
    # acc_free = reduce_masked_mean(free_match, free_g*valid)
    # acc_total = reduce_masked_mean(either_match, either_have*valid)
    # summ_writer.summ_scalar('pri/acc_pri%s' % suffix, acc_pri.cpu().item())
    # summ_writer.summ_scalar('pri/acc_free%s' % suffix, acc_free.cpu().item())
    # summ_writer.summ_scalar('pri/acc_total%s' % suffix, acc_total.cpu().item())

    # # vis
    # summ_writer.summ_pri('pri/pri_g%s' % suffix, pri_g, reduce_axes=[2,3])
    # summ_writer.summ_pri('pri/free_g%s' % suffix, free_g, reduce_axes=[2,3])
    # summ_writer.summ_pri('pri/pri_e%s' % suffix, pri_e, reduce_axes=[2,3])
    # summ_writer.summ_pri('pri/valid%s' % suffix, valid, reduce_axes=[2,3])

    # prob_loss = self.compute_loss(energy_vol, pri_g, free_g, valid, summ_writer)
    # total_loss = utils_misc.add_loss('pri/prob_loss%s' % suffix, total_loss, prob_loss, hyp.pri_coeff, summ_writer)

    return total_loss  # , pri_e
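# ----------------------------------------------------------------------
# Hedged sketch: utils_basic.reduce_masked_mean, used throughout the
# losses above, is assumed to average x over only the masked-in entries,
# i.e. sum(x * mask) / (sum(mask) + eps). A minimal version under that
# assumption; the real helper may handle broadcasting differently.
# Relies on the file's existing `torch` import.
# ----------------------------------------------------------------------
def reduce_masked_mean_sketch(x, mask, dim=None, keepdim=False, eps=1e-6):
    # x and mask are same-shape tensors; mask is (soft) 0/1
    prod = x * mask
    if dim is None:
        return torch.sum(prod) / (torch.sum(mask) + eps)
    numer = torch.sum(prod, dim=dim, keepdim=keepdim)
    denom = torch.sum(mask, dim=dim, keepdim=keepdim) + eps
    return numer / denom

# usage matching the CE term above:
# ce_loss = -1.0 * reduce_masked_mean_sketch(loglike_per_traj, valid)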