Example #1
    def forward(self,
                feat,
                obj_g=None,
                bkg_g=None,
                valid_g=None,
                summ_writer=None):
        total_loss = torch.tensor(0.0).cuda()

        sub_e_ = self.net(feat)
        sub_e = torch.sigmoid(sub_e_)  # F.sigmoid is deprecated in favor of torch.sigmoid
        sub_e_binary = torch.round(sub_e)

        # smooth loss
        dz, dy, dx = utils_basic.gradient3D(sub_e_, absolute=True)
        smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('sub/smooth_loss', total_loss,
                                         smooth_loss, hyp.sub_smooth_coeff,
                                         summ_writer)

        if obj_g is not None:

            # # collect some accuracy stats
            # pos_match = sub_g*torch.eq(sub_e_binary, sub_g).float()
            # neg_match = (1.0 - sub_g)*torch.eq(1.0-sub_e_binary, 1.0 - sub_g).float()
            # either_match = torch.clamp(pos_match+neg_match, 0.0, 1.0)
            # either_have = sub_g.clone()
            # acc_pos = utils_basic.reduce_masked_mean(pos_match, sub_g*valid)
            # acc_neg = utils_basic.reduce_masked_mean(neg_match, (1.0-sub_g)*valid)
            # acc_total = utils_basic.reduce_masked_mean(either_match, either_have*valid)
            # acc_bal = (acc_pos + acc_neg)*0.5

            # summ_writer.summ_scalar('unscaled_sub/acc_pos', acc_pos.cpu().item())
            # summ_writer.summ_scalar('unscaled_sub/acc_neg', acc_neg.cpu().item())
            # summ_writer.summ_scalar('unscaled_sub/acc_total', acc_total.cpu().item())
            # summ_writer.summ_scalar('unscaled_sub/acc_bal', acc_bal.cpu().item())

            prob_loss = self.compute_loss(sub_e_, obj_g, bkg_g, valid_g,
                                          summ_writer)
            # prob_loss = self.compute_loss(sub_e_, sub_g, (1.0 - sub_g), valid, summ_writer)
            total_loss = utils_misc.add_loss('sub/prob_loss', total_loss,
                                             prob_loss, hyp.sub_coeff,
                                             summ_writer)

        # if summ_writer is not None:
        #     if sub_g is not None:
        #         summ_writer.summ_occ('sub/sub_g', sub_g)
        #         summ_writer.summ_oned('sub/sub_g_', sub_g, bev=True, norm=False)
        #     summ_writer.summ_occ('sub/sub_e', sub_e)
        #     summ_writer.summ_oned('sub/sub_e', sub_e, bev=True, norm=False)
        return total_loss, sub_e
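
Example #1 (and every snippet below) leans on utils_basic.gradient3D, whose source is not shown here. The sketch below is a minimal reading inferred from the call sites (a B x C x Z x Y x X volume in, per-axis differences out in (dz, dy, dx) order, with optional absolute/square flags); the zero-padding scheme at the borders is an assumption.

    import torch
    import torch.nn.functional as F

    def gradient3D(x, absolute=False, square=False):
        # forward differences along Z, Y, X of a B x C x Z x Y x X volume
        dz = x[:, :, 1:] - x[:, :, :-1]
        dy = x[:, :, :, 1:] - x[:, :, :, :-1]
        dx = x[:, :, :, :, 1:] - x[:, :, :, :, :-1]
        # zero-pad back to the input shape; F.pad orders pads last-dim-first
        dz = F.pad(dz, (0, 0, 0, 0, 0, 1))
        dy = F.pad(dy, (0, 0, 0, 1, 0, 0))
        dx = F.pad(dx, (0, 1, 0, 0, 0, 0))
        if absolute:
            dz, dy, dx = dz.abs(), dy.abs(), dx.abs()
        if square:
            dz, dy, dx = dz ** 2, dy ** 2, dx ** 2
        return dz, dy, dx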
Example #2
    def forward(self, feat, summ_writer=None):
        total_loss = torch.tensor(0.0).cuda()
        # B, C, Z, Y, X = list(feat.shape)

        feat = self.net(feat)

        # smooth loss
        dz, dy, dx = utils_basic.gradient3D(feat, absolute=True)
        smooth_vox = torch.mean(dx + dy + dz, dim=1, keepdim=True)
        if summ_writer is not None:
            summ_writer.summ_oned('up3D/smooth_loss', torch.mean(smooth_vox, dim=3))
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('up3D/smooth_loss', total_loss, smooth_loss, hyp.up3D_smooth_coeff, summ_writer)

        # feat = utils_basic.l2_normalize(feat, dim=1)
        # print('feat', feat.shape)
        
        if summ_writer is not None:
            summ_writer.summ_feat('up3D/feat_output', feat, pca=True)
        return total_loss, feat
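
utils_misc.add_loss appears in all of these snippets with the same five arguments: a log key, the running total, the raw loss, a coefficient from hyp, and the writer. A plausible minimal implementation, assuming the writer's summ_scalar method seen in the commented-out code of Example #1 (the log-key naming is a guess):

    def add_loss(name, total_loss, loss, coeff, summ_writer):
        # log the raw and the scaled loss, then fold the scaled term
        # into the running total
        if summ_writer is not None:
            summ_writer.summ_scalar('unscaled_%s' % name, loss)
            summ_writer.summ_scalar('scaled_%s' % name, coeff * loss)
        return total_loss + coeff * loss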
Example #3
    def forward(self, feat, summ_writer=None):
        total_loss = torch.tensor(0.0).cuda()
        B, C, Z, Y, X = list(feat.shape)

        # occupancy mask from the input; recomputed below once the net has run
        mask = (feat[:,0:1] > 0.0).float()
        # if summ_writer is not None:
        #     summ_writer.summ_feat('feat3D/feat_mask', mask, pca=False)
        
        if summ_writer is not None:
            summ_writer.summ_feat('feat3D/feat_input', feat, pca=(C>3))

        feat = self.net(feat)
        mask = torch.ones_like(feat[:,0:1])

        # smooth loss
        dz, dy, dx = utils_basic.gradient3D(feat, absolute=True)
        smooth_vox = torch.mean(dz + dy + dx, dim=1, keepdim=True)
        if summ_writer is not None:
            summ_writer.summ_oned('feat3D/smooth_loss', torch.mean(smooth_vox, dim=3))
        smooth_loss = torch.mean(smooth_vox)
        total_loss = utils_misc.add_loss('feat3D/smooth_loss', total_loss, smooth_loss, hyp.feat3D_smooth_coeff, summ_writer)
            
        feat = utils_basic.l2_normalize(feat, dim=1)
        if hyp.feat3D_sparse:
            feat = feat * mask
        
        if summ_writer is not None:
            summ_writer.summ_feat('feat3D/feat_output', feat, pca=True)
            # summ_writer.summ_feat('feat3D/feat_mask', mask, pca=False)
            
        # if hyp.feat3D_skip:
        #     feat = feat[:,:,
        #                 self.crop[0]:-self.crop[0],
        #                 self.crop[1]:-self.crop[1],
        #                 self.crop[2]:-self.crop[2]]
        #     mask = mask[:,:,
        #                 self.crop[0]:-self.crop[0],
        #                 self.crop[1]:-self.crop[1],
        #                 self.crop[2]:-self.crop[2]]
            
        return total_loss, feat, mask
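
Example #3 normalizes features with utils_basic.l2_normalize, and several snippets reduce masked losses with reduce_masked_mean. Both are standard one-liners; the epsilon values below are assumptions:

    import torch

    EPS = 1e-6

    def l2_normalize(x, dim=1):
        # scale vectors along `dim` to unit L2 norm
        # (equivalent to F.normalize(x, dim=dim))
        return x / (torch.norm(x, dim=dim, keepdim=True) + EPS)

    def reduce_masked_mean(x, mask):
        # mean of x over the entries where mask is on;
        # EPS keeps an all-zero mask from dividing by zero
        return torch.sum(x * mask) / (torch.sum(mask) + EPS)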
Example #4
    def forward(self,
                feat0,
                feat1,
                flow_g=None,
                mask_g=None,
                summ_writer=None):
        total_loss = torch.tensor(0.0).cuda()

        B, C, D, H, W = list(feat0.shape)
        utils_basic.assert_same_shape(feat0, feat1)

        # feats = torch.cat([feat0, feat1], dim=0)
        # feats = self.compressor(feats)
        # feats = utils_basic.l2_normalize(feats, dim=1)
        # feat0, feat1 = feats[:B], feats[B:]

        flow_total = torch.zeros([B, 3, D, H, W]).float().cuda()
        feat1_aligned = feat1.clone()

        # summ_writer.summ_feats('flow/feats_aligned_%.2f' % 0.0, [feat0, feat1_aligned])
        feat_diff = torch.mean(
            utils_basic.l2_on_axis((feat1_aligned - feat0), 1, keepdim=True))
        utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0,
                            summ_writer)

        for sc in self.scales:
            flow = self.generate_flow(feat0, feat1_aligned, sc)

            mask = torch.zeros_like(flow[:, 0:1])
            # print('mask', mask.shape)
            mask[:, :, self.max_disp:-self.max_disp,
                 self.max_disp:-self.max_disp,
                 self.max_disp:-self.max_disp] = 1.0
            flow = flow * mask

            flow_total = flow_total + flow

            # compositional LK: warp the original thing using the cumulative flow
            feat1_aligned = utils_samp.backwarp_using_3D_flow(
                feat1, flow_total)
            valid1_region = utils_samp.backwarp_using_3D_flow(
                torch.ones_like(feat1[:, 0:1]), flow_total)
            # summ_writer.summ_feats('flow/feats_aligned_%.2f' % sc, [feat0, feat1_aligned],
            #                        valids=[torch.ones_like(valid1_region), valid1_region])
            feat_diff = utils_basic.reduce_masked_mean(
                utils_basic.l2_on_axis((feat1_aligned - feat0),
                                       1,
                                       keepdim=True), valid1_region)
            utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff,
                                0, summ_writer)

        if flow_g is not None:
            # ok done inference
            # now for losses/metrics:
            l1_diff_3chan = self.smoothl1(flow_total, flow_g)
            l1_diff = torch.mean(l1_diff_3chan, dim=1, keepdim=True)
            l2_diff_3chan = self.mse(flow_total, flow_g)
            l2_diff = torch.mean(l2_diff_3chan, dim=1, keepdim=True)

            nonzero_mask = ((torch.sum(torch.abs(flow_g), dim=1, keepdim=True)
                             > 0.01).float()) * mask_g * mask
            yeszero_mask = (1.0 - nonzero_mask) * mask_g * mask
            l1_loss = utils_basic.reduce_masked_mean(l1_diff, mask_g * mask)
            l2_loss = utils_basic.reduce_masked_mean(l2_diff, mask_g * mask)
            l1_loss_nonzero = utils_basic.reduce_masked_mean(
                l1_diff, nonzero_mask)
            l1_loss_yeszero = utils_basic.reduce_masked_mean(
                l1_diff, yeszero_mask)
            l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5
            l2_loss_nonzero = utils_basic.reduce_masked_mean(
                l2_diff, nonzero_mask)
            l2_loss_yeszero = utils_basic.reduce_masked_mean(
                l2_diff, yeszero_mask)
            l2_loss_balanced = (l2_loss_nonzero + l2_loss_yeszero) * 0.5

            # clip = np.squeeze(torch.max(torch.abs(torch.mean(flow_g[0], dim=0))).detach().cpu().numpy()).item()
            clip = 3.0
            if summ_writer is not None:
                # summ_writer.summ_3D_flow('flow/flow_e_%.2f' % sc, flow_total*mask_g, clip=clip)
                summ_writer.summ_3D_flow('flow/flow_e_%.2f' % sc,
                                         flow_total,
                                         clip=clip)
                summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc,
                                         flow_g,
                                         clip=clip)
                summ_writer.summ_oned('flow/mask_%.2f' % sc,
                                      mask,
                                      bev=True,
                                      norm=False)
                summ_writer.summ_feat('flow/flow_e_pca_%.2f' % sc,
                                      flow_total,
                                      pca=True)

            utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero, 0,
                                summ_writer)
            utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero, 0,
                                summ_writer)
            utils_misc.add_loss('flow/l1_loss_balanced', 0, l1_loss_balanced,
                                0, summ_writer)

            total_loss = utils_misc.add_loss('flow/l1_loss', total_loss,
                                             l1_loss, hyp.flow_l1_coeff,
                                             summ_writer)
            total_loss = utils_misc.add_loss('flow/l2_loss', total_loss,
                                             l2_loss, hyp.flow_l2_coeff,
                                             summ_writer)

            total_loss = utils_misc.add_loss('flow/warp', total_loss,
                                             feat_diff, hyp.flow_warp_coeff,
                                             summ_writer)

            # smooth loss
            dz, dy, dx = utils_basic.gradient3D(flow_total, absolute=True)
            smooth_vox = torch.mean(dz + dy + dx, dim=1, keepdim=True)
            if summ_writer is not None:
                summ_writer.summ_oned('flow/smooth_loss',
                                      torch.mean(smooth_vox, dim=3))
            smooth_loss = torch.mean(smooth_vox)
            total_loss = utils_misc.add_loss('flow/smooth_loss', total_loss,
                                             smooth_loss,
                                             hyp.flow_smooth_coeff,
                                             summ_writer)

        return total_loss, flow_total
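
The compositional-LK loop above repeatedly warps feat1 with utils_samp.backwarp_using_3D_flow. Below is a hedged sketch of trilinear backwarping via F.grid_sample; the flow's channel order ((x, y, z) displacements in voxel units) and the align_corners convention are assumptions, and meshgrid's indexing argument needs PyTorch >= 1.10:

    import torch
    import torch.nn.functional as F

    def backwarp_using_3D_flow(x, flow):
        # sample x (B x C x D x H x W) at positions displaced by flow
        # (B x 3 x D x H x W): out[..., z, y, x] ~ x[..., z+fz, y+fy, x+fx]
        B, C, D, H, W = x.shape
        zs, ys, xs = torch.meshgrid(
            torch.arange(D, dtype=x.dtype, device=x.device),
            torch.arange(H, dtype=x.dtype, device=x.device),
            torch.arange(W, dtype=x.dtype, device=x.device),
            indexing='ij')
        grid = torch.stack([xs, ys, zs], dim=0).unsqueeze(0)  # 1 x 3 x D x H x W
        pos = grid + flow  # absolute sampling positions in voxel units
        # normalize each axis to [-1, 1], the range grid_sample expects
        pos_x = 2.0 * pos[:, 0] / max(W - 1, 1) - 1.0
        pos_y = 2.0 * pos[:, 1] / max(H - 1, 1) - 1.0
        pos_z = 2.0 * pos[:, 2] / max(D - 1, 1) - 1.0
        samp = torch.stack([pos_x, pos_y, pos_z], dim=-1)  # B x D x H x W x 3
        # zeros padding means out-of-range samples return 0, which is what
        # makes the warped all-ones volume usable as valid1_region above
        return F.grid_sample(x, samp, align_corners=True)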
Example #5
    def gradient3DForBboxFace(self, emb3D_scenes, bbox, scores):
        # emb3D_scenes should be B x C x D x H x W
        dz_batch, dy_batch, dx_batch = utils_basic.gradient3D(emb3D_scenes,
                                                              absolute=False,
                                                              square=False)

        bbox = torch.clamp(bbox, min=0)
        sizes_val = [hyp.Z2 - 1, hyp.Y2 - 1, hyp.X2 - 1]

        gs_loss_list = []  # gradient smoothness loss

        for index_batch, emb_scene in enumerate(emb3D_scenes):
            # accumulate as a tensor so autograd can track the surface terms
            gsloss = torch.tensor(0.0, device=emb3D_scenes.device)
            dz, dy, dx = dz_batch[index_batch:index_batch + 1], dy_batch[
                index_batch:index_batch +
                1], dx_batch[index_batch:index_batch + 1]
            for index_box, box in enumerate(bbox[index_batch]):
                if scores[index_batch][index_box] > 0:

                    lower, upper = torch.unbind(box)
                    lower = [torch.floor(i).to(torch.int32) for i in lower]
                    upper = [torch.ceil(i).to(torch.int32) for i in upper]
                    xmin, ymin, zmin = [max(i, 0) for i in lower]

                    xmax, ymax, zmax = [
                        min(i, sizes_val[index])
                        for index, i in enumerate(upper)
                    ]

                    #zmin face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dz, zmin, zmin + 1, ymin, ymax, xmin, xmax)
                    if zmin < sizes_val[0]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dz, zmin + 1, zmin + 2, ymin, ymax, xmin, xmax)

                    #zmax face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dz, zmax, zmax + 1, ymin, ymax, xmin, xmax)
                    if zmax < sizes_val[0]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dz, zmax + 1, zmax + 2, ymin, ymax, xmin, xmax)

                    #ymin face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dy, zmin, zmax, ymin, ymin + 1, xmin, xmax)
                    if ymin < sizes_val[1]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dy, zmin, zmax, ymin + 1, ymin + 2, xmin, xmax)

                    #ymax face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dy, zmin, zmax, ymax, ymax + 1, xmin, xmax)
                    if ymax < sizes_val[1]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dy, zmin, zmax, ymax + 1, ymax + 2, xmin, xmax)

                    #xmin face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dx, zmin, zmax, ymin, ymax, xmin, xmin + 1)
                    if xmin < sizes_val[2]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dx, zmin, zmax, ymin, ymax, xmin + 1, xmin + 2)

                    #xmax face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dx, zmin, zmax, ymin, ymax, xmax, xmax + 1)
                    if xmax < sizes_val[2]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dx, zmin, zmax, ymin, ymax, xmax + 1, xmax + 2)

            gs_loss_list.append(gsloss)

        # stack (rather than torch.tensor) keeps the per-box terms in the graph
        gsloss = torch.mean(torch.stack(gs_loss_list))
        return gsloss
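
get_gradient_loss_on_bbox_surface is not shown in the source. Judging by its call pattern (a signed gradient volume plus a one-voxel-thick index range on each axis), a plausible reading is the mean gradient magnitude over that face slab; this is a guess at the helper, not its confirmed body:

    def get_gradient_loss_on_bbox_surface(self, d, z0, z1, y0, y1, x0, x1):
        # mean |gradient| over a one-voxel-thick slab of d (1 x C x Z x Y x X)
        # at a box face; the mean-of-abs aggregation is an assumption
        return torch.mean(torch.abs(d[:, :, z0:z1, y0:y1, x0:x1]))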
Example #6
    def forward(self, emb_e, emb_g, vis_g, summ_writer):
        total_loss = torch.tensor(0.0).cuda()

        if torch.isnan(emb_e).any() or torch.isnan(emb_g).any():
            assert False, 'NaN in emb_e or emb_g'

        B, C, D, H, W = list(emb_e.shape)
        # put channels on the end
        emb_e_vec = emb_e.permute(0, 2, 3, 4, 1).reshape(B, D * H * W, C)
        emb_g_vec = emb_g.permute(0, 2, 3, 4, 1).reshape(B, D * H * W, C)
        # vis_e_vec = vis_e.permute(0,2,3,4,1).reshape(B, D*H*W, 1)
        vis_g_vec = vis_g.permute(0, 2, 3, 4, 1).reshape(B, D * H * W, 1)

        # ensure they are both nonzero, else we probably masked or warped something
        valid_vec_e = 1.0 - (emb_e_vec == 0).all(dim=2, keepdim=True).float()
        valid_vec_g = 1.0 - (emb_g_vec == 0).all(dim=2, keepdim=True).float()
        valid_vec = valid_vec_e * valid_vec_g
        # vis_e_vec *= valid_vec
        vis_g_vec *= valid_vec
        valid_g = 1.0 - (emb_g == 0).all(dim=1, keepdim=True).float()

        assert (self.num_samples < (B * D * H * W))
        # we will take num_samples from each one

        # ~18% of vis_e is on
        # ~25% of vis_g is on
        # print('it looks like %.2f of vis_e is 1' % (torch.sum(vis_e_vec).cpu()/len(vis_g_vec)))
        # print('it looks like %.2f of vis_g is 1' % (torch.sum(vis_g_vec).cpu()/len(vis_g_vec)))

        # where g is valid, we use it as reference and pull up e
        margin_loss = self.compute_margin_loss(B, C, D, H, W, emb_e_vec,
                                               emb_g_vec.detach(), vis_g_vec,
                                               'g', True, summ_writer)
        l2_loss = reduce_masked_mean(
            sql2_on_axis(emb_e - emb_g.detach(), 1, keepdim=True), vis_g)
        total_loss = utils_misc.add_loss('emb3D/emb_3D_ml_loss', total_loss,
                                         margin_loss, hyp.emb_3D_ml_coeff,
                                         summ_writer)
        total_loss = utils_misc.add_loss('emb3D/emb_3D_l2_loss', total_loss,
                                         l2_loss, hyp.emb_3D_l2_coeff,
                                         summ_writer)

        # # where e is valid, we use it as reference and pull up g
        # margin_loss_e = self.compute_margin_loss(B, C, D, H, W, emb_e_vec.detach(), emb_g_vec, vis_e_vec, 'e', True, summ_writer)
        # l2_loss_e = reduce_masked_mean(sql2_on_axis(emb_e.detach()-emb_g, 1, keepdim=True), vis_e)
        # # where g is valid, we use it as reference and pull up e
        # margin_loss_g = self.compute_margin_loss(B, C, D, H, W, emb_e_vec, emb_g_vec.detach(), vis_g_vec, 'g', True, summ_writer)
        # l2_loss_g = reduce_masked_mean(sql2_on_axis(emb_e-emb_g.detach(), 1, keepdim=True), vis_g)
        # # where both are valid OR neither is valid, we pull them together
        # vis_both_or_neither_vec = torch.clamp(vis_e_vec*vis_g_vec + (1.0-vis_e_vec)*(1.0-vis_g_vec), 0, 1)
        # vis_both_or_neither = torch.clamp(vis_e*vis_g + (1.0-vis_e)*(1.0-vis_g), 0, 1)
        # margin_loss_n = self.compute_margin_loss(B, C, D, H, W, emb_e_vec, emb_g_vec, vis_both_or_neither_vec, 'n', True, summ_writer)
        # l2_loss_n = reduce_masked_mean(sql2_on_axis(emb_e-emb_g, 1, keepdim=True), vis_both_or_neither)
        # margin_loss = (margin_loss_e + margin_loss_g + margin_loss_n)/3.0
        # l2_loss = (l2_loss_e + l2_loss_g + l2_loss_n)/3.0
        # total_loss = utils_misc.add_loss('emb3D/emb_3D_ml_loss', total_loss, margin_loss, hyp.emb_3D_ml_coeff, summ_writer)
        # total_loss = utils_misc.add_loss('emb3D/emb_3D_l2_loss', total_loss, l2_loss, hyp.emb_3D_l2_coeff, summ_writer)

        l2_loss_im = torch.mean(sql2_on_axis(emb_e - emb_g, 1, keepdim=True),
                                dim=3)
        summ_writer.summ_oned('emb3D/emb_3D_l2_loss', l2_loss_im)

        dz, dy, dx = utils_basic.gradient3D(emb_g, absolute=True)
        smooth_loss = torch.sum(dz + dy + dx, dim=1, keepdim=True)
        smooth_loss_im = torch.mean(smooth_loss, dim=3)
        summ_writer.summ_oned('emb3D/emb_3D_smooth_loss', smooth_loss_im)
        emb_smooth_loss = reduce_masked_mean(smooth_loss, valid_g)
        total_loss = utils_misc.add_loss('emb3D/emb_3D_smooth_loss',
                                         total_loss, emb_smooth_loss,
                                         hyp.emb_3D_smooth_coeff, summ_writer)

        summ_writer.summ_feats('emb3D/embs_3D', [emb_e, emb_g], pca=True)
        return total_loss
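
The embedding losses above reduce with sql2_on_axis, and Examples #4 and #7 use the l2_on_axis variant. As a hedged sketch of those reductions, with the stabilizing epsilon assumed:

    import torch

    def sql2_on_axis(x, axis, keepdim=True):
        # squared L2 norm along `axis`
        return torch.sum(x ** 2, axis, keepdim=keepdim)

    def l2_on_axis(x, axis, keepdim=True):
        # L2 norm along `axis`; the epsilon keeps sqrt differentiable at 0
        return torch.sqrt(sql2_on_axis(x, axis, keepdim=keepdim) + 1e-6)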
Example #7
    def forward(self, feat0, feat1, flow_g, mask_g, is_synth, summ_writer):
        total_loss = torch.tensor(0.0).cuda()

        B, C, D, H, W = list(feat0.shape)
        utils_basic.assert_same_shape(feat0, feat1)

        # feats = torch.cat([feat0, feat1], dim=0)
        # feats = self.compressor(feats)
        # feats = utils_basic.l2_normalize(feats, dim=1)
        # feat0, feat1 = feats[:B], feats[B:]

        flow_total_forw = torch.zeros_like(flow_g)
        flow_total_back = torch.zeros_like(flow_g)

        feat0_aligned = feat0.clone()
        feat1_aligned = feat1.clone()

        # cycle_losses = []
        # l1_losses = []

        # torch does not like it when we overwrite, so let's pre-allocate
        l1_loss_cumu = torch.tensor(0.0).cuda()
        l2_loss_cumu = torch.tensor(0.0).cuda()
        warp_loss_cumu = torch.tensor(0.0).cuda()

        summ_writer.summ_feats('flow/feats_aligned_%.2f' % 0.0,
                               [feat0, feat1_aligned])
        feat_diff = torch.mean(
            utils_basic.l2_on_axis((feat1_aligned - feat0), 1, keepdim=True))
        utils_misc.add_loss('flow/feat_align_diff_%.2f' % 0.0, 0, feat_diff, 0,
                            summ_writer)

        # print('feat0, feat1_aligned, mask')
        # print(feat0.shape)
        # print(feat1_aligned.shape)
        # print(mask_g.shape)

        hinge_loss_vox = utils_basic.l2_on_axis((feat1_aligned - feat0),
                                                1,
                                                keepdim=True)
        hinge_loss_vox = F.relu(0.2 - hinge_loss_vox)
        summ_writer.summ_oned('flow/hinge_loss',
                              torch.mean(hinge_loss_vox, dim=3))
        hinge_mask_vox = (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) >
                          1.0).float()
        summ_writer.summ_oned('flow/hinge_mask',
                              torch.mean(hinge_mask_vox, dim=3),
                              norm=False)
        # hinge_loss = torch.mean(hinge_loss_vox)
        hinge_loss = utils_basic.reduce_masked_mean(hinge_loss_vox,
                                                    hinge_mask_vox)
        total_loss = utils_misc.add_loss('flow/hinge', total_loss, hinge_loss,
                                         hyp.flow_hinge_coeff, summ_writer)

        for sc in self.scales:

            # flow_forw, new_feat0, new_feat1 = self.generate_flow(feat0, feat1_aligned, sc)
            # flow_back, new_feat1, new_feat1 = self.generate_flow(feat1, feat0_aligned, sc)
            # flow_forw, heat = self.generate_flow(feat0, feat1_aligned, sc)

            flow_forw = self.generate_flow(feat0, feat1_aligned, sc)
            flow_back = self.generate_flow(feat1, feat0_aligned, sc)

            flow_total_forw = flow_total_forw + flow_forw
            flow_total_back = flow_total_back + flow_back

            # compositional LK: warp the original thing using the cumulative flow
            feat1_aligned = utils_samp.backwarp_using_3D_flow(
                feat1, flow_total_forw)
            feat0_aligned = utils_samp.backwarp_using_3D_flow(
                feat0, flow_total_back)
            valid1_region = utils_samp.backwarp_using_3D_flow(
                torch.ones_like(feat1[:, 0:1]), flow_total_forw)
            # valid0 must be warped with the same (backward) flow as feat0_aligned
            valid0_region = utils_samp.backwarp_using_3D_flow(
                torch.ones_like(feat0[:, 0:1]), flow_total_back)

            summ_writer.summ_feats('flow/feats_aligned_%.2f' % sc,
                                   [feat0, feat1_aligned],
                                   valids=[valid0_region, valid1_region])
            # summ_writer.summ_oned('flow/mean_heat_%.2f' % sc, torch.mean(heat, dim=3))
            # feat_diff = torch.mean(utils_basic.l2_on_axis((feat1_aligned-feat0), 1, keepdim=True))
            feat_diff = utils_basic.reduce_masked_mean(
                utils_basic.l2_on_axis((feat1_aligned - feat0),
                                       1,
                                       keepdim=True),
                valid1_region * valid0_region)
            utils_misc.add_loss('flow/feat_align_diff_%.2f' % sc, 0, feat_diff,
                                0, summ_writer)

            if sc == 1.0:

                warp_loss_cumu = warp_loss_cumu + feat_diff * sc

                l1_diff_3chan = self.smoothl1(flow_total_forw, flow_g)
                l1_diff = torch.mean(l1_diff_3chan, dim=1, keepdim=True)
                l2_diff_3chan = self.mse(flow_total_forw, flow_g)
                l2_diff = torch.mean(l2_diff_3chan, dim=1, keepdim=True)

                nonzero_mask = (
                    (torch.sum(torch.abs(flow_g), dim=1, keepdim=True) >
                     0.01).float()) * mask_g
                yeszero_mask = (1.0 - nonzero_mask) * mask_g
                l1_loss_nonzero = utils_basic.reduce_masked_mean(
                    l1_diff, nonzero_mask)
                l1_loss_yeszero = utils_basic.reduce_masked_mean(
                    l1_diff, yeszero_mask)
                l1_loss_balanced = (l1_loss_nonzero + l1_loss_yeszero) * 0.5
                l2_loss_nonzero = utils_basic.reduce_masked_mean(
                    l2_diff, nonzero_mask)
                l2_loss_yeszero = utils_basic.reduce_masked_mean(
                    l2_diff, yeszero_mask)
                l2_loss_balanced = (l2_loss_nonzero + l2_loss_yeszero) * 0.5
                # l1_loss_cumu = l1_loss_cumu + l1_loss_balanced*sc

                l1_loss_cumu = l1_loss_cumu + l1_loss_balanced * sc
                l2_loss_cumu = l2_loss_cumu + l2_loss_balanced * sc

                # warp flow
                flow_back_aligned_to_forw = utils_samp.backwarp_using_3D_flow(
                    flow_total_back, flow_total_forw.detach())
                flow_forw_aligned_to_back = utils_samp.backwarp_using_3D_flow(
                    flow_total_forw, flow_total_back.detach())

                cancelled_flow_forw = flow_total_forw + flow_back_aligned_to_forw
                cancelled_flow_back = flow_total_back + flow_forw_aligned_to_back

                cycle_forw = self.smoothl1_mean(
                    cancelled_flow_forw, torch.zeros_like(cancelled_flow_forw))
                cycle_back = self.smoothl1_mean(
                    cancelled_flow_back, torch.zeros_like(cancelled_flow_back))
                cycle_loss = cycle_forw + cycle_back
                total_loss = utils_misc.add_loss('flow/cycle_loss', total_loss,
                                                 cycle_loss,
                                                 hyp.flow_cycle_coeff,
                                                 summ_writer)

                summ_writer.summ_3D_flow('flow/flow_e_forw_%.2f' % sc,
                                         flow_total_forw * mask_g,
                                         clip=0.0)
                summ_writer.summ_3D_flow('flow/flow_e_back_%.2f' % sc,
                                         flow_total_back,
                                         clip=0.0)
                summ_writer.summ_3D_flow('flow/flow_g_%.2f' % sc,
                                         flow_g,
                                         clip=0.0)

                utils_misc.add_loss('flow/l1_loss_nonzero', 0, l1_loss_nonzero,
                                    0, summ_writer)
                utils_misc.add_loss('flow/l1_loss_yeszero', 0, l1_loss_yeszero,
                                    0, summ_writer)
                utils_misc.add_loss('flow/l1_loss_balanced', 0,
                                    l1_loss_balanced, 0, summ_writer)
                utils_misc.add_loss('flow/l2_loss_balanced', 0,
                                    l2_loss_balanced, 0, summ_writer)

                # total_loss = utils_misc.add_loss('flow/l1_loss_balanced', total_loss, l1_loss_balanced, hyp.flow_l1_coeff, summ_writer)
                # total_loss = utils_misc.add_loss('flow/l1_loss_balanced', total_loss, l1_loss_balanced, hyp.flow_l1_coeff, summ_writer)
                # total_loss = utils_misc.add_loss('flow/l1_loss', total_loss, l1_loss, hyp.flow_l1_coeff*(sc==1.0), summ_writer)

        if is_synth:
            total_loss = utils_misc.add_loss('flow/synth_l1_cumu', total_loss,
                                             l1_loss_cumu,
                                             hyp.flow_synth_l1_coeff,
                                             summ_writer)
            total_loss = utils_misc.add_loss('flow/synth_l2_cumu', total_loss,
                                             l2_loss_cumu,
                                             hyp.flow_synth_l2_coeff,
                                             summ_writer)
        else:
            total_loss = utils_misc.add_loss('flow/l1_cumu', total_loss,
                                             l1_loss_cumu, hyp.flow_l1_coeff,
                                             summ_writer)
            total_loss = utils_misc.add_loss('flow/l2_cumu', total_loss,
                                             l2_loss_cumu, hyp.flow_l2_coeff,
                                             summ_writer)

        # total_loss = utils_misc.add_loss('flow/warp', total_loss, feat_diff, hyp.flow_warp_coeff, summ_writer)
        total_loss = utils_misc.add_loss('flow/warp_cumu', total_loss,
                                         warp_loss_cumu, hyp.flow_warp_coeff,
                                         summ_writer)

        # feat1_aligned = utils_samp.backwarp_using_3D_flow(feat1, flow_g)
        # valid_region = utils_samp.backwarp_using_3D_flow(torch.ones_like(feat1[:,0:1]), flow_g)
        # summ_writer.summ_feats('flow/feats_aligned_g', [feat0, feat1_aligned],
        #                        valids=[valid_region, valid_region])
        # feat_diff = utils_basic.reduce_masked_mean(utils_basic.l2_on_axis((feat1_aligned-feat0), 1, keepdim=True), valid_region)
        # total_loss = utils_misc.add_loss('flow/warp_g', total_loss, feat_diff, hyp.flow_warp_g_coeff, summ_writer)
        # # hinge_loss_vox = F.relu(0.2 - hinge_loss_vox)

        # total_loss = utils_misc.add_loss('flow/cycle_loss', total_loss, torch.sum(torch.stack(cycle_losses)), hyp.flow_cycle_coeff, summ_writer)
        # total_loss = utils_misc.add_loss('flow/l1_loss', total_loss, torch.sum(torch.stack(l1_losses)), hyp.flow_l1_coeff, summ_writer)

        # smooth loss
        dz, dy, dx = utils_basic.gradient3D(flow_total_forw, absolute=True)
        smooth_vox_forw = torch.mean(dz + dy + dx, dim=1, keepdim=True)
        dz, dy, dx = utils_basic.gradient3D(flow_total_back, absolute=True)
        smooth_vox_back = torch.mean(dz + dy + dx, dim=1, keepdim=True)
        summ_writer.summ_oned('flow/smooth_loss_forw',
                              torch.mean(smooth_vox_forw, dim=3))
        smooth_loss = torch.mean((smooth_vox_forw + smooth_vox_back) * 0.5)
        total_loss = utils_misc.add_loss('flow/smooth_loss', total_loss,
                                         smooth_loss, hyp.flow_smooth_coeff,
                                         summ_writer)

        # flow_e = F.sigmoid(flow_e_)
        # flow_e_binary = torch.round(flow_e)

        # # collect some accuracy stats
        # flow_match = flow_g*torch.eq(flow_e_binary, flow_g).float()
        # free_match = free_g*torch.eq(1.0-flow_e_binary, free_g).float()
        # either_match = torch.clamp(flow_match+free_match, 0.0, 1.0)
        # either_have = torch.clamp(flow_g+free_g, 0.0, 1.0)
        # acc_flow = reduce_masked_mean(flow_match, flow_g*valid)
        # acc_free = reduce_masked_mean(free_match, free_g*valid)
        # acc_total = reduce_masked_mean(either_match, either_have*valid)

        # summ_writer.summ_scalar('flow/acc_flow', acc_flow.cpu().item())
        # summ_writer.summ_scalar('flow/acc_free', acc_free.cpu().item())
        # summ_writer.summ_scalar('flow/acc_total', acc_total.cpu().item())

        # # vis
        # summ_writer.summ_flow('flow/flow_g', flow_g)
        # summ_writer.summ_flow('flow/free_g', free_g)
        # summ_writer.summ_flow('flow/flow_e', flow_e)
        # summ_writer.summ_flow('flow/valid', valid)

        # prob_loss = self.compute_loss(flow_e_, flow_g, free_g, valid, summ_writer)
        # total_loss = utils_misc.add_loss('flow/prob_loss', total_loss, prob_loss, hyp.flow_coeff, summ_writer)

        # return total_loss, flow_e
        return total_loss, flow_total_forw
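
Example #7 (like Example #4) relies on self.smoothl1, self.mse, and self.smoothl1_mean. Given that the first two are averaged per voxel afterward while the third is used directly as a scalar cycle term, a plausible setup, assumed rather than taken from the source, is the hypothetical stub below:

    import torch.nn as nn

    class FlowNetCriteria(nn.Module):
        # hypothetical stub showing one way the criteria above could be built
        def __init__(self):
            super().__init__()
            self.smoothl1 = nn.SmoothL1Loss(reduction='none')       # per-voxel L1 map
            self.mse = nn.MSELoss(reduction='none')                 # per-voxel L2 map
            self.smoothl1_mean = nn.SmoothL1Loss(reduction='mean')  # scalar cycle loss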