def forward(self, feat, obj_g=None, bkg_g=None, valid_g=None, summ_writer=None):
    """Predict per-voxel "sub" probabilities and compute their losses.

    Args:
        feat: input feature volume, passed through ``self.net`` to get logits.
        obj_g: optional object labels; when given, ``self.compute_loss`` is
            added as 'sub/prob_loss' (weighted by ``hyp.sub_coeff``).
        bkg_g: optional background labels, forwarded to ``self.compute_loss``.
        valid_g: optional validity mask, forwarded to ``self.compute_loss``.
        summ_writer: optional summary writer, forwarded to ``utils_misc.add_loss``.

    Returns:
        (total_loss, sub_e): the weighted scalar loss and the sigmoid of the
        network logits.
    """
    # Accumulate on the input's device instead of assuming the current CUDA device.
    total_loss = torch.tensor(0.0, device=feat.device)

    sub_e_ = self.net(feat)  # raw logits
    sub_e = torch.sigmoid(sub_e_)  # F.sigmoid is deprecated; identical result

    # Smoothness loss: mean absolute spatial gradient of the logits.
    dz, dy, dx = utils_basic.gradient3D(sub_e_, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('sub/smooth_loss', total_loss, smooth_loss,
                                     hyp.sub_smooth_coeff, summ_writer)

    if obj_g is not None:
        prob_loss = self.compute_loss(sub_e_, obj_g, bkg_g, valid_g, summ_writer)
        total_loss = utils_misc.add_loss('sub/prob_loss', total_loss, prob_loss,
                                         hyp.sub_coeff, summ_writer)

    return total_loss, sub_e
def forward(self, feat, summ_writer=None):
    """Run ``self.net`` on ``feat`` and add a feature-smoothness loss.

    Args:
        feat: input feature volume.
        summ_writer: optional summary writer; visual summaries are skipped
            when it is None (the default).

    Returns:
        (total_loss, feat): the weighted smoothness loss and the network output.
    """
    total_loss = torch.tensor(0.0, device=feat.device)
    feat = self.net(feat)

    # Smoothness loss: mean absolute spatial gradient of the output features.
    dz, dy, dx = utils_basic.gradient3D(feat, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    # Guarded: summ_writer defaults to None, and the unguarded call crashed then.
    if summ_writer is not None:
        summ_writer.summ_oned('up3D/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('up3D/smooth_loss', total_loss, smooth_loss,
                                     hyp.up3D_smooth_coeff, summ_writer)

    if summ_writer is not None:
        summ_writer.summ_feat('up3D/feat_output', feat, pca=True)
    return total_loss, feat
def forward(self, feat, summ_writer=None):
    """Encode a 3D feature volume, add a smoothness loss, and L2-normalize.

    Args:
        feat: input volume, B x C x Z x Y x X (the unpack below enforces rank 5).
        summ_writer: optional summary writer; summaries are skipped when None.

    Returns:
        (total_loss, feat, mask): the weighted smoothness loss, the
        channel-wise L2-normalized network output, and a mask of ones shaped
        like the output's first channel.
    """
    total_loss = torch.tensor(0.0, device=feat.device)
    _, C, _, _, _ = feat.shape  # also enforces a 5-D input
    if summ_writer is not None:
        summ_writer.summ_feat('feat3D/feat_input', feat, pca=(C > 3))

    feat = self.net(feat)
    # NOTE(review): mask is all ones, so the hyp.feat3D_sparse multiply below
    # has no effect. The input-occupancy mask previously computed before
    # self.net was always overwritten here unused, and has been removed.
    mask = torch.ones_like(feat[:, 0:1])

    # Smoothness loss: mean absolute spatial gradient of the features.
    dz, dy, dx = utils_basic.gradient3D(feat, absolute=True)
    smooth_vox = torch.mean(dz + dy + dx, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('feat3D/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('feat3D/smooth_loss', total_loss, smooth_loss,
                                     hyp.feat3D_smooth_coeff, summ_writer)

    feat = utils_basic.l2_normalize(feat, dim=1)
    if hyp.feat3D_sparse:
        feat = feat * mask

    if summ_writer is not None:
        summ_writer.summ_feat('feat3D/feat_output', feat, pca=True)
    return total_loss, feat, mask
def forward(self, feat0, feat1, flow_g=None, mask_g=None, summ_writer=None):
    """Estimate 3D flow from feat0 to feat1 by compositional coarse-to-fine alignment.

    At each scale in ``self.scales``, ``self.generate_flow`` predicts a residual
    flow between ``feat0`` and the current warp of ``feat1``. That flow is
    zeroed within ``self.max_disp`` voxels of the border and added to the
    running total, and ``feat1`` is re-warped with the total.

    Args:
        feat0, feat1: feature volumes of identical shape, B x C x D x H x W.
        flow_g: optional ground-truth flow; when given, l1/l2/warp losses are added.
        mask_g: optional weight mask for the supervised losses; treated as all
            ones when ``flow_g`` is given without it.
        summ_writer: optional summary writer; summaries are skipped when None.

    Returns:
        (total_loss, flow_total), where flow_total is float32, B x 3 x D x H x W.
    """
    total_loss = torch.tensor(0.0, device=feat0.device)

    B, C, D, H, W = list(feat0.shape)
    utils_basic.assert_same_shape(feat0, feat1)

    flow_total = torch.zeros([B, 3, D, H, W], dtype=torch.float32, device=feat0.device)
    feat1_aligned = feat1.clone()

    feat_diff = torch.mean(
        utils_basic.l2_on_axis(feat1_aligned - feat0, 1, keepdim=True))
    utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0, summ_writer)

    d = self.max_disp
    for sc in self.scales:
        flow = self.generate_flow(feat0, feat1_aligned, sc)

        # Keep flow only at least max_disp voxels away from every border.
        # Explicit end indices: the old `d:-d` slice was empty for d == 0,
        # which zeroed the entire flow.
        mask = torch.zeros_like(flow[:, 0:1])
        _, _, Dm, Hm, Wm = mask.shape
        mask[:, :, d:Dm - d, d:Hm - d, d:Wm - d] = 1.0
        flow = flow * mask

        flow_total = flow_total + flow

        # Compositional LK: warp the ORIGINAL feat1 using the cumulative flow.
        feat1_aligned = utils_samp.backwarp_using_3D_flow(feat1, flow_total)
        valid1_region = utils_samp.backwarp_using_3D_flow(
            torch.ones_like(feat1[:, 0:1]), flow_total)
        feat_diff = utils_basic.reduce_masked_mean(
            utils_basic.l2_on_axis(feat1_aligned - feat0, 1, keepdim=True),
            valid1_region)
        utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff, 0, summ_writer)

    if flow_g is not None:
        # Inference is done; compute losses/metrics against the final flow.
        # NOTE(review): `mask` and `sc` are those of the last scale; this
        # assumes self.scales is non-empty.
        if mask_g is None:
            mask_g = torch.ones_like(mask)
        valid = mask_g * mask

        l1_diff = torch.mean(self.smoothl1(flow_total, flow_g), dim=1, keepdim=True)
        l2_diff = torch.mean(self.mse(flow_total, flow_g), dim=1, keepdim=True)

        # Split the supervision into moving (nonzero gt flow) and static voxels.
        nonzero_mask = (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) > 0.01).float() * valid
        yeszero_mask = (1.0 - nonzero_mask) * valid

        l1_loss = utils_basic.reduce_masked_mean(l1_diff, valid)
        l2_loss = utils_basic.reduce_masked_mean(l2_diff, valid)

        l1_loss_nonzero = utils_basic.reduce_masked_mean(l1_diff, nonzero_mask)
        l1_loss_yeszero = utils_basic.reduce_masked_mean(l1_diff, yeszero_mask)
        l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5

        clip = 3.0  # fixed visualization range for the flow summaries
        if summ_writer is not None:
            summ_writer.summ_3D_flow('flow/flow_e_%.2f' % sc, flow_total, clip=clip)
            summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc, flow_g, clip=clip)
            summ_writer.summ_oned('flow/mask_%.2f' % sc, mask, bev=True, norm=False)
            summ_writer.summ_feat('flow/flow_e_pca_%.2f' % sc, flow_total, pca=True)

        # Logged only (zero weight).
        utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero, 0, summ_writer)
        utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero, 0, summ_writer)
        utils_misc.add_loss('flow/l1_loss_balanced', 0, l1_loss_balanced, 0, summ_writer)

        total_loss = utils_misc.add_loss('flow/l1_loss', total_loss, l1_loss,
                                         hyp.flow_l1_coeff, summ_writer)
        total_loss = utils_misc.add_loss('flow/l2_loss', total_loss, l2_loss,
                                         hyp.flow_l2_coeff, summ_writer)
        # feat_diff here is the alignment error at the last scale.
        total_loss = utils_misc.add_loss('flow/warp', total_loss, feat_diff,
                                         hyp.flow_warp_coeff, summ_writer)

    # Smoothness loss on the total flow.
    dz, dy, dx = utils_basic.gradient3D(flow_total, absolute=True)
    smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('flow/smooth_loss', torch.mean(smooth_vox, dim=3))
    smooth_loss = torch.mean(smooth_vox)
    total_loss = utils_misc.add_loss('flow/smooth_loss', total_loss, smooth_loss,
                                     hyp.flow_smooth_coeff, summ_writer)

    return total_loss, flow_total
def gradient3DForBboxFace(self, emb3D_scenes, bbox, scores):
    """Gradient loss evaluated on one-voxel slabs at the faces of each scored box.

    Args:
        emb3D_scenes: embedding volumes, B x C x D x H x W.
        bbox: per-scene boxes. Each box unbinds into (lower, upper) corners
            whose coordinates are read in x, y, z order. Negative coordinates
            are clamped to 0.
        scores: per-scene, per-box scores; only boxes with score > 0 contribute.

    Returns:
        Mean over scenes of the summed ``self.get_gradient_loss_on_bbox_surface``
        terms. The per-scene terms are concatenated with ``torch.cat``, so the
        result stays attached to the autograd graph. The previous
        ``torch.tensor(list)`` construction copied the values into a fresh leaf
        tensor and silently dropped all gradients.
    """
    dz_batch, dy_batch, dx_batch = utils_basic.gradient3D(
        emb3D_scenes, absolute=False, square=False)
    bbox = torch.clamp(bbox, min=0)
    # Largest valid voxel index per axis, in z, y, x order; reversed to x, y, z
    # to match how box corners are stored.
    sizes_zyx = [hyp.Z2 - 1, hyp.Y2 - 1, hyp.X2 - 1]
    sizes_xyz = sizes_zyx[::-1]

    gs_loss_list = []  # one gradient-smoothness loss per scene
    for b in range(len(emb3D_scenes)):
        grads_zyx = (dz_batch[b:b + 1], dy_batch[b:b + 1], dx_batch[b:b + 1])
        gsloss = emb3D_scenes.new_zeros(())
        for i_box, box in enumerate(bbox[b]):
            if not scores[b][i_box] > 0:
                continue
            lower, upper = torch.unbind(box)
            xmin, ymin, zmin = [max(torch.floor(v).to(torch.int32), 0) for v in lower]
            # Clip each upper coordinate against ITS OWN axis. Previously the
            # x, y, z corners were clipped against the z, y, x sizes, which is
            # wrong whenever Z2 != X2.
            xmax, ymax, zmax = [
                min(torch.ceil(v).to(torch.int32), size)
                for v, size in zip(upper, sizes_xyz)
            ]
            lows = (zmin, ymin, xmin)
            highs = (zmax, ymax, xmax)
            # Per axis (z, then y, then x), visit the min face and then the max
            # face. Each face adds the slab [bound, bound+1) of that axis'
            # gradient and, when bound is below the last index, the next slab
            # [bound+1, bound+2). The other two axes span [min, max].
            for axis in range(3):
                for bound in (lows[axis], highs[axis]):
                    offsets = (0, 1) if bound < sizes_zyx[axis] else (0,)
                    for off in offsets:
                        ranges = [(lows[a], highs[a]) for a in range(3)]
                        ranges[axis] = (bound + off, bound + off + 1)
                        (z0, z1), (y0, y1), (x0, x1) = ranges
                        gsloss = gsloss + self.get_gradient_loss_on_bbox_surface(
                            grads_zyx[axis], z0, z1, y0, y1, x0, x1)
        gs_loss_list.append(gsloss)

    if not gs_loss_list:  # empty batch: nothing to average
        return emb3D_scenes.new_zeros(())
    return torch.cat([g.reshape(-1) for g in gs_loss_list]).mean()
def forward(self, emb_e, emb_g, vis_g, summ_writer):
    """Pull predicted 3D embeddings toward (detached) reference embeddings.

    Args:
        emb_e: predicted embeddings, B x C x D x H x W.
        emb_g: reference embeddings with the same shape; detached for the
            margin and l2 terms.
        vis_g: visibility weights with one channel (B x 1 x D x H x W). This
            tensor is NOT modified.
        summ_writer: summary writer, forwarded to the loss helpers; visual
            summaries are skipped when it is None.

    Returns:
        total_loss: weighted sum of the margin, l2 and smoothness losses.
    """
    total_loss = torch.tensor(0.0, device=emb_e.device)
    assert not (torch.isnan(emb_e).any() or torch.isnan(emb_g).any()), \
        'emb3D: NaN in input embeddings'

    B, C, D, H, W = list(emb_e.shape)
    # Put channels last and flatten the voxels.
    emb_e_vec = emb_e.permute(0, 2, 3, 4, 1).reshape(B, D * H * W, C)
    emb_g_vec = emb_g.permute(0, 2, 3, 4, 1).reshape(B, D * H * W, C)

    # A voxel whose embedding is entirely zero was probably masked or warped
    # out, so require both embeddings to be nonzero.
    valid_e = 1.0 - (emb_e == 0).all(dim=1, keepdim=True).float()
    valid_g = 1.0 - (emb_g == 0).all(dim=1, keepdim=True).float()
    # Mask visibility out of place. The old `vis_g_vec *= valid_vec` wrote
    # through to the caller's vis_g whenever permute().reshape() returned a
    # view. Masking here keeps the same (masked) weights for both losses below.
    vis_g = vis_g * valid_e * valid_g
    vis_g_vec = vis_g.permute(0, 2, 3, 4, 1).reshape(B, D * H * W, 1)

    assert self.num_samples < (B * D * H * W)

    # Where g is valid, use it as the reference and pull e toward it.
    margin_loss = self.compute_margin_loss(B, C, D, H, W, emb_e_vec, emb_g_vec.detach(),
                                           vis_g_vec, 'g', True, summ_writer)
    l2_loss = reduce_masked_mean(
        sql2_on_axis(emb_e - emb_g.detach(), 1, keepdim=True), vis_g)

    total_loss = utils_misc.add_loss('emb3D/emb_3D_ml_loss', total_loss, margin_loss,
                                     hyp.emb_3D_ml_coeff, summ_writer)
    total_loss = utils_misc.add_loss('emb3D/emb_3D_l2_loss', total_loss, l2_loss,
                                     hyp.emb_3D_l2_coeff, summ_writer)

    if summ_writer is not None:
        l2_loss_im = torch.mean(sql2_on_axis(emb_e - emb_g, 1, keepdim=True), dim=3)
        summ_writer.summ_oned('emb3D/emb_3D_l2_loss', l2_loss_im)

    # Smoothness of the reference embedding, averaged over its valid voxels.
    dz, dy, dx = utils_basic.gradient3D(emb_g, absolute=True)
    smooth_loss = torch.sum(dz + dy + dx, dim=1, keepdim=True)
    if summ_writer is not None:
        summ_writer.summ_oned('emb3D/emb_3D_smooth_loss', torch.mean(smooth_loss, dim=3))
    emb_smooth_loss = reduce_masked_mean(smooth_loss, valid_g)
    total_loss = utils_misc.add_loss('emb3D/emb_3D_smooth_loss', total_loss, emb_smooth_loss,
                                     hyp.emb_3D_smooth_coeff, summ_writer)

    if summ_writer is not None:
        summ_writer.summ_feats('emb3D/embs_3D', [emb_e, emb_g], pca=True)
    return total_loss
def forward(self, feat0, feat1, flow_g, mask_g, is_synth, summ_writer):
    """Estimate forward and backward 3D flow with cycle-consistency supervision.

    At each scale in ``self.scales`` the method does the following:
    - predict residual forward (0->1) and backward (1->0) flows;
    - accumulate them;
    - re-warp the original features with the cumulative flows;
    - accumulate scale-weighted balanced l1/l2 errors against ``flow_g``;
    - add a cycle-consistency loss between the two flow fields.

    Args:
        feat0, feat1: feature volumes of identical shape.
        flow_g: ground-truth forward flow. It also sets the shape of the flow
            accumulators.
        mask_g: weight mask for the supervised flow losses.
        is_synth: selects the ``hyp.flow_synth_*`` coefficients instead of
            ``hyp.flow_*`` for the cumulative l1/l2 losses.
        summ_writer: summary writer; used unconditionally here.

    Returns:
        (total_loss, flow_total_forw).
    """
    device = feat0.device
    total_loss = torch.tensor(0.0, device=device)
    utils_basic.assert_same_shape(feat0, feat1)

    flow_total_forw = torch.zeros_like(flow_g)
    flow_total_back = torch.zeros_like(flow_g)
    feat0_aligned = feat0.clone()
    feat1_aligned = feat1.clone()

    # Accumulators are rebound (never updated in place), which keeps autograd happy.
    l1_loss_cumu = torch.tensor(0.0, device=device)
    l2_loss_cumu = torch.tensor(0.0, device=device)
    warp_loss_cumu = torch.tensor(0.0, device=device)

    summ_writer.summ_feats('flow/feats_aligned_%.2f' % 0.0, [feat0, feat1_aligned])
    feat_diff = torch.mean(
        utils_basic.l2_on_axis(feat1_aligned - feat0, 1, keepdim=True))
    utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0, summ_writer)

    # Hinge on the unaligned feature distance, applied where gt motion is large.
    hinge_loss_vox = utils_basic.l2_on_axis(feat1_aligned - feat0, 1, keepdim=True)
    hinge_loss_vox = F.relu(0.2 - hinge_loss_vox)
    summ_writer.summ_oned('flow/hinge_loss', torch.mean(hinge_loss_vox, dim=3))
    hinge_mask_vox = (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) > 1.0).float()
    summ_writer.summ_oned('flow/hinge_mask', torch.mean(hinge_mask_vox, dim=3), norm=False)
    hinge_loss = utils_basic.reduce_masked_mean(hinge_loss_vox, hinge_mask_vox)
    total_loss = utils_misc.add_loss('flow/hinge', total_loss, hinge_loss,
                                     hyp.flow_hinge_coeff, summ_writer)

    # Moving vs. static voxels of the ground truth. These are loop-invariant,
    # so compute them once.
    nonzero_mask = (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) > 0.01).float() * mask_g
    yeszero_mask = (1.0 - nonzero_mask) * mask_g

    for sc in self.scales:
        flow_forw = self.generate_flow(feat0, feat1_aligned, sc)
        flow_back = self.generate_flow(feat1, feat0_aligned, sc)

        flow_total_forw = flow_total_forw + flow_forw
        flow_total_back = flow_total_back + flow_back

        # Compositional LK: warp the ORIGINAL features using the cumulative flows.
        feat1_aligned = utils_samp.backwarp_using_3D_flow(feat1, flow_total_forw)
        feat0_aligned = utils_samp.backwarp_using_3D_flow(feat0, flow_total_back)
        valid1_region = utils_samp.backwarp_using_3D_flow(
            torch.ones_like(feat1[:, 0:1]), flow_total_forw)
        # NOTE(review): warped with the FORWARD flow, so this equals
        # valid1_region; flow_total_back may have been intended - confirm.
        valid0_region = utils_samp.backwarp_using_3D_flow(
            torch.ones_like(feat0[:, 0:1]), flow_total_forw)
        summ_writer.summ_feats('flow/feats_aligned_%.2f' % sc, [feat0, feat1_aligned],
                               valids=[valid0_region, valid1_region])
        feat_diff = utils_basic.reduce_masked_mean(
            utils_basic.l2_on_axis(feat1_aligned - feat0, 1, keepdim=True),
            valid1_region * valid0_region)
        utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff, 0, summ_writer)
        if sc == 1.0:
            warp_loss_cumu = warp_loss_cumu + feat_diff * sc

        l1_diff = torch.mean(self.smoothl1(flow_total_forw, flow_g), dim=1, keepdim=True)
        l2_diff = torch.mean(self.mse(flow_total_forw, flow_g), dim=1, keepdim=True)

        l1_loss_nonzero = utils_basic.reduce_masked_mean(l1_diff, nonzero_mask)
        l1_loss_yeszero = utils_basic.reduce_masked_mean(l1_diff, yeszero_mask)
        l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5
        l2_loss_nonzero = utils_basic.reduce_masked_mean(l2_diff, nonzero_mask)
        l2_loss_yeszero = utils_basic.reduce_masked_mean(l2_diff, yeszero_mask)
        l2_loss_balanced = (l2_loss_nonzero + l2_loss_yeszero) * 0.5

        l1_loss_cumu = l1_loss_cumu + l1_loss_balanced * sc
        l2_loss_cumu = l2_loss_cumu + l2_loss_balanced * sc

        # Cycle consistency: each flow warped by the (detached) other one
        # should cancel it out.
        flow_back_aligned_to_forw = utils_samp.backwarp_using_3D_flow(
            flow_total_back, flow_total_forw.detach())
        flow_forw_aligned_to_back = utils_samp.backwarp_using_3D_flow(
            flow_total_forw, flow_total_back.detach())
        cancelled_flow_forw = flow_total_forw + flow_back_aligned_to_forw
        cancelled_flow_back = flow_total_back + flow_forw_aligned_to_back
        cycle_forw = self.smoothl1_mean(cancelled_flow_forw, torch.zeros_like(cancelled_flow_forw))
        cycle_back = self.smoothl1_mean(cancelled_flow_back, torch.zeros_like(cancelled_flow_back))
        cycle_loss = cycle_forw + cycle_back
        total_loss = utils_misc.add_loss('flow/cycle_loss', total_loss, cycle_loss,
                                         hyp.flow_cycle_coeff, summ_writer)

        summ_writer.summ_3D_flow('flow/flow_e_forw_%.2f' % sc, flow_total_forw * mask_g, clip=0.0)
        summ_writer.summ_3D_flow('flow/flow_e_back_%.2f' % sc, flow_total_back, clip=0.0)
        summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc, flow_g, clip=0.0)

        # Logged only (zero weight).
        utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero, 0, summ_writer)
        utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero, 0, summ_writer)
        utils_misc.add_loss('flow/l1_loss_balanced', 0, l1_loss_balanced, 0, summ_writer)
        utils_misc.add_loss('flow/l2_loss_balanced', 0, l2_loss_balanced, 0, summ_writer)

    if is_synth:
        total_loss = utils_misc.add_loss('flow/synth_l1_cumu', total_loss, l1_loss_cumu,
                                         hyp.flow_synth_l1_coeff, summ_writer)
        total_loss = utils_misc.add_loss('flow/synth_l2_cumu', total_loss, l2_loss_cumu,
                                         hyp.flow_synth_l2_coeff, summ_writer)
    else:
        total_loss = utils_misc.add_loss('flow/l1_cumu', total_loss, l1_loss_cumu,
                                         hyp.flow_l1_coeff, summ_writer)
        total_loss = utils_misc.add_loss('flow/l2_cumu', total_loss, l2_loss_cumu,
                                         hyp.flow_l2_coeff, summ_writer)

    total_loss = utils_misc.add_loss('flow/warp_cumu', total_loss, warp_loss_cumu,
                                     hyp.flow_warp_coeff, summ_writer)

    # Smoothness loss on both flow fields. All three gradient axes are summed;
    # the previous `dx + dy + dx` double-counted one axis and dropped another.
    dz, dy, dx = utils_basic.gradient3D(flow_total_forw, absolute=True)
    smooth_vox_forw = torch.mean(dz + dy + dx, dim=1, keepdim=True)
    dz, dy, dx = utils_basic.gradient3D(flow_total_back, absolute=True)
    smooth_vox_back = torch.mean(dz + dy + dx, dim=1, keepdim=True)
    summ_writer.summ_oned('flow/smooth_loss_forw', torch.mean(smooth_vox_forw, dim=3))
    smooth_loss = torch.mean((smooth_vox_forw + smooth_vox_back) * 0.5)
    total_loss = utils_misc.add_loss('flow/smooth_loss', total_loss, smooth_loss,
                                     hyp.flow_smooth_coeff, summ_writer)

    return total_loss, flow_total_forw