Beispiel #1
0
def crop_zoom_from_mem(mem, lrt, Z2, Y2, X2, additive_pad=0.1):
    """Resample a Z2 x Y2 x X2 "zoom" grid out of a memory tensor.

    Every voxel of the zoom grid is mapped zoom -> ref -> mem, and ``mem`` is
    sampled at the resulting mem coordinates.

    Args:
        mem: tensor unpacked as B x C x Z x Y x X.
        lrt: tensor of shape B x 19 (the box description consumed by
            ``Zoom2Ref``; its exact layout is not visible here).
        Z2, Y2, X2: resolution of the zoom grid.
        additive_pad: forwarded unchanged to ``Zoom2Ref``.

    Returns:
        The sampled values, reshaped to B x C x Z2 x Y2 x X2.
    """
    batch, chans, Z, Y, X = mem.shape
    lrt_batch, lrt_dim = lrt.shape

    assert lrt_dim == 19
    assert batch == lrt_batch

    # coordinates of every voxel in the zoom grid
    zoom_coords = utils_basic.gridcloud3D(batch, Z2, Y2, X2, norm=False)
    # express those coordinates in the ref frame, then in mem voxel units
    ref_coords = Zoom2Ref(zoom_coords,
                          lrt,
                          Z2,
                          Y2,
                          X2,
                          additive_pad=additive_pad)
    mem_coords = Ref2Mem(ref_coords, Z, Y, X)

    # pull one C-dim vector out of mem for each zoom voxel
    sampled = utils_samp.sample3D(mem, mem_coords, Z2, Y2, X2)
    return torch.reshape(sampled, [batch, chans, Z2, Y2, X2])
def get_synth_flow(occs,
                   unps,
                   summ_writer,
                   sometimes_zero=False,
                   do_vis=False):
    """Synthesize a second frame and the matching dense 3D flow from frame 0.

    A random translation-only camera motion is drawn (either a "large" one
    with ``t_amount=1.0`` or a "small" one with ``t_amount=0.1``), frame 0 of
    ``occs``/``unps`` is warped by it to create frame 1, and the per-voxel
    displacement in mem coordinates is returned as the flow.

    Args:
        occs: tensor unpacked as B x S x C x Z x Y x X; must have S == 2 and
            C == 1. Only ``occs[:, 0]`` is actually read.
        unps: tensor indexed as ``unps[:, 0]``; warped alongside ``occs``.
        summ_writer: summary writer, only used when ``do_vis`` is true.
        sometimes_zero: forwarded to ``utils_geom.get_random_rt`` for the
            small-motion candidate.
        do_vis: when true, write occupancy/unprojection/flow summaries and a
            backwarped "stabilized" visualization.

    Returns:
        ``(occs, unps, flow, cam1_T_cam0)`` where ``occs`` and ``unps`` are the
        original frame 0 and the synthesized frame 1 stacked on dim 1,
        ``flow`` is B x 3 x Z x Y x X, and ``cam1_T_cam0`` is the sampled
        transform.

    Raises:
        AssertionError: if ``occs`` does not have S == 2 and C == 1.
    """
    B, S, C, Z, Y, X = list(occs.shape)
    # The previous `assert (S == 2, C == 1)` asserted a non-empty tuple, which
    # is always truthy, so the shape check never actually ran.
    assert S == 2 and C == 1, (
        'expected occs to be B x 2 x 1 x Z x Y x X, got %s' %
        str(list(occs.shape)))

    # we do not sample any rotations here, to keep the distribution purely
    # uniform across all translations
    # (rotation ruins this, since the pivot point is at the camera)
    # Both candidates are generated up front (as before) and one is picked.
    candidates = [
        utils_geom.get_random_rt(B, r_amount=0.0,
                                 t_amount=1.0),  # large motion
        utils_geom.get_random_rt(
            B,
            r_amount=0.0,
            t_amount=0.1,  # small motion
            sometimes_zero=sometimes_zero)
    ]
    cam1_T_cam0 = random.choice(candidates)

    occ0 = occs[:, 0]
    unp0 = unps[:, 0]
    occ1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, occ0, binary_feat=True)
    unp1 = utils_vox.apply_4x4_to_vox(cam1_T_cam0, unp0)
    occs = [occ0, occ1]
    unps = [unp0, unp1]

    if do_vis:
        summ_writer.summ_occs('synth/occs', occs)
        summ_writer.summ_unps('synth/unps', unps, occs)

    # express the camera motion as a transform between mem (voxel) frames
    mem_T_cam = utils_vox.get_mem_T_ref(B, Z, Y, X)
    cam_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_cam, cam1_T_cam0, cam_T_mem)

    # flow = where each voxel lands minus where it started
    xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0)
    xyz_mem0 = xyz_mem0.reshape(B, Z, Y, X, 3)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    flow = xyz_mem1 - xyz_mem0
    # this is B x Z x Y x X x 3
    flow = flow.permute(0, 4, 1, 2, 3)
    # this is B x 3 x Z x Y x X

    if do_vis:
        summ_writer.summ_3D_flow('synth/flow', flow, clip=2.0)

        # warp frame 1 back with the flow, to check that it lines up with
        # frame 0
        occ0_e = utils_samp.backwarp_using_3D_flow(occ1,
                                                   flow,
                                                   binary_feat=True)
        unp0_e = utils_samp.backwarp_using_3D_flow(unp1, flow)
        summ_writer.summ_occs('synth/occs_stab', [occ0, occ0_e])
        summ_writer.summ_unps('synth/unps_stab', [unp0, unp0_e],
                              [occ0, occ0_e])

    occs = torch.stack(occs, dim=1)
    unps = torch.stack(unps, dim=1)

    return occs, unps, flow, cam1_T_cam0
Beispiel #3
0
def compute_mem1_T_mem0_from_object_flow(flow_mem, mask_mem, occ_mem):
    """Fit one rigid transform per batch element to a masked 3D flow field.

    For each batch element, the voxels where ``occ * mask > 0.5`` are taken
    as source points, the flow at those voxels is added to obtain target
    points, and ``rigid_transform_3D`` is fit to the resulting pairs.

    Args:
        flow_mem: tensor unpacked as B x 3 x Z x Y x X.
        mask_mem: per-element mask, multiplied with ``occ_mem`` and flattened.
        occ_mem: per-element occupancy, multiplied with ``mask_mem``.

    Returns:
        A batch of transforms, initialised from ``utils_geom.eye_4x4(B)`` and
        overwritten per element with the fitted result.
    """
    B, C, Z, Y, X = flow_mem.shape
    assert C == 3

    transforms = utils_geom.eye_4x4(B)
    grid = utils_basic.gridcloud3D(B, Z, Y, X, norm=False)

    # handled one batch element at a time, since the number of selected
    # voxels differs between elements
    for b in range(B):
        selected = torch.where((occ_mem[b] * mask_mem[b]).reshape(-1) > 0.5)

        # flatten the flow to (Z*Y*X) x 3 and keep the selected voxels
        flow_pts = flow_mem[b].reshape(3, -1).permute(1, 0)[selected]

        src_pts = grid[b][selected]
        dst_pts = src_pts + flow_pts

        # the fitted transform is stored as the b-th entry
        transforms[b] = rigid_transform_3D(src_pts, dst_pts)

    return transforms
Beispiel #4
0
    def forward(self,
                template_feat,
                search_feat,
                template_mask,
                template_lrt,
                search_lrt,
                vox_util,
                lrt_cam0s,
                summ_writer=None):
        """Estimate the template-to-search rigid motion by feature matching.

        Masked template voxels are correlated against every search voxel, a
        soft argmax gives each point's location in the search grid, and a
        rigid transform is fit between the two point sets (both expressed in
        cam coordinates). The estimate is supervised against the transform
        implied by ``lrt_cam0s`` with a cube-corner loss plus separate
        rotation (degrees) and translation losses.

        Args:
            template_feat: what we are searching for; B x C x ZZ x ZY x ZX.
            search_feat: where we are searching; B x C x Z x Y x X.
            template_mask: mask over the template voxels, viewed as B x med.
            template_lrt: box used to map template voxels to cam coords.
            search_lrt: box used to map search voxels to cam coords.
            vox_util: object providing ``Zoom2Ref``.
            lrt_cam0s: boxes indexed as ``[:, 0]`` and ``[:, 1]``, used as
                ground truth for the two frames.
            summ_writer: optional summary writer passed to ``add_loss``.

        Returns:
            ``(lrt_cam1_e, total_loss)``: the frame-0 box moved by the
            estimated transform, and the accumulated loss.
        """
        total_loss = torch.tensor(0.0).cuda()

        B, C, ZZ, ZY, ZX = template_feat.shape
        _, _, Z, Y, X = search_feat.shape

        # cam-frame position of every template voxel; B x med x 3
        template_grid = utils_basic.gridcloud3D(B, ZZ, ZY, ZX)
        template_xyz_cam = vox_util.Zoom2Ref(template_grid, template_lrt, ZZ,
                                             ZY, ZX)

        # flatten the spatial dims
        search_flat = search_feat.view(B, C, -1)  # B x C x huge
        template_flat = template_feat.view(B, C, -1)  # B x C x med
        mask_flat = template_mask.view(B, -1)  # B x med

        cam1_T_cam0_e = utils_geom.eye_4x4(B)

        # the number of masked points varies per element, so go one by one
        for b in range(B):
            # torch.where returns a tuple, which works directly as an index
            inside = torch.where(mask_flat[b] > 0)

            # after the permute, rows are points: pts x C and pts x 3
            feats_b = template_flat[b].permute(1, 0)[inside]
            xyz0_b = template_xyz_cam[b][inside]

            assert len(xyz0_b) > 8  # otw we should have returned early

            # we want exactly self.num_pts points every time: tile the set
            # when it is too small, then randomly trim it down
            if len(xyz0_b) < self.num_pts:
                reps = int(self.num_pts / len(xyz0_b)) + 1
                print('only have %d pts; repeating %d times...' %
                      (len(xyz0_b), reps))
                xyz0_b = xyz0_b.repeat(reps, 1)
                feats_b = feats_b.repeat(reps, 1)
            assert len(xyz0_b) >= self.num_pts

            chosen = np.random.permutation(len(xyz0_b))[:self.num_pts]
            xyz0_b = xyz0_b[chosen]
            feats_b = feats_b[chosen]

            # one heatmap over the search region per point: num_pts x huge
            heat = torch.matmul(feats_b, search_flat[b])

            # shift so each row's minimum is zero
            heat = heat - heat.min(dim=1).values.unsqueeze(1)
            # scale up by the number of search voxels, for numerical stability
            heat = heat * float(heat.shape[1])

            heat = heat.reshape(self.num_pts, 1, Z, Y, X)
            xyz1_search = utils_basic.argmax3D(heat, hard=False, stack=True)
            # this is self.num_pts x 3, in search-grid coordinates

            # bring the matched locations into cam coords
            xyz1_b = vox_util.Zoom2Ref(xyz1_search.unsqueeze(0),
                                       search_lrt[b:b + 1], Z, Y,
                                       X).squeeze(0)

            cam1_T_cam0_e[b] = utils_track.rigid_transform_3D(xyz0_b, xyz1_b)

        # ground-truth motion, from the two ground-truth box poses
        _, rt_cam0_g = utils_geom.split_lrt(lrt_cam0s[:, 0])
        _, rt_cam1_g = utils_geom.split_lrt(lrt_cam0s[:, 1])
        cam1_T_cam0_g = torch.matmul(rt_cam1_g, rt_cam0_g.inverse())

        # move the frame-0 box by the estimated motion
        lrt_cam1_e = utils_geom.apply_4x4_to_lrtlist(
            cam1_T_cam0_e, lrt_cam0s[:, 0:1]).squeeze(1)

        # cube loss: compare where the corners of a unit cube end up under
        # the estimated vs. ground-truth transform
        half = 0.5
        corners = np.array([
            [half, half, half],
            [half, half, -half],
            [-half, half, -half],
            [-half, half, half],
            [half, -half, half],
            [half, -half, -half],
            [-half, -half, -half],
            [-half, -half, half],
        ])
        # 8 x 3 -> 1 x 8 x 3
        corners = torch.from_numpy(corners).float().cuda().reshape(1, 8, 3)

        corners_e = utils_geom.apply_4x4(cam1_T_cam0_e, corners)
        corners_g = utils_geom.apply_4x4(cam1_T_cam0_g, corners)

        corner_loss = self.smoothl1(corners_e, corners_g)
        total_loss = utils_misc.add_loss('robust/corner_loss', total_loss,
                                         corner_loss, hyp.robust_corner_coeff,
                                         summ_writer)

        def euler_degrees(rot):
            # euler angles of a batch of rotations, stacked on dim 1
            rx, ry, rz = utils_geom.rotm2eul(rot)
            return utils_geom.rad2deg(torch.stack([rx, ry, rz], dim=1))

        # separate rotation and translation penalties
        rot_e, t_e = utils_geom.split_rt(cam1_T_cam0_e)
        rot_g, t_g = utils_geom.split_rt(cam1_T_cam0_g)

        r_loss = self.smoothl1(euler_degrees(rot_e), euler_degrees(rot_g))
        t_loss = self.smoothl1(t_e, t_g)

        total_loss = utils_misc.add_loss('robust/r_loss', total_loss, r_loss,
                                         hyp.robust_r_coeff, summ_writer)
        total_loss = utils_misc.add_loss('robust/t_loss', total_loss, t_loss,
                                         hyp.robust_t_coeff, summ_writer)

        return lrt_cam1_e, total_loss
Beispiel #5
0
    def forward(self,
                feat_cam0,
                feat_cam1,
                mask_mem0,
                pix_T_cam0,
                pix_T_cam1,
                cam1_T_cam0,
                vox_util,
                summ_writer=None):
        """Cross-view 2D embedding loss on 3D points sampled inside a mask.

        For every batch element, ``self.num_samples`` voxels are drawn at
        random from the voxels where ``mask_mem0 > 0``. Elements with fewer
        masked voxels than that are dropped. The surviving points are mapped
        to cam0 coordinates with ``vox_util.Mem2Ref``, moved to cam1 with
        ``cam1_T_cam0``, projected into both images, and the features sampled
        at the two projections are handed to ``self.compute_ce_loss`` (with
        the cam1 side detached).

        Args:
            feat_cam0: features of view 0; B x C x H x W.
            feat_cam1: features of view 1; indexed like ``feat_cam0``.
            mask_mem0: B x 1 x Z x Y x X mask selecting candidate voxels.
            pix_T_cam0: projection for view 0, one per batch element.
            pix_T_cam1: projection for view 1, one per batch element.
            cam1_T_cam0: transform from cam0 to cam1, one per batch element.
            vox_util: object providing ``Mem2Ref``.
            summ_writer: optional summary writer passed to ``add_loss``.

        Returns:
            The accumulated loss. This is a zero scalar when no batch element
            has at least ``self.num_samples`` masked voxels.
        """
        total_loss = torch.tensor(0.0).cuda()

        B, C, Z, Y, X = list(mask_mem0.shape)
        assert (C == 1)

        # from here on, C is the feature channel count
        B2, C, _, _ = list(feat_cam0.shape)
        assert (B == B2)

        # NOTE: an alternative per-element ("go_slow") code path used to live
        # here, but it was unconditionally disabled and therefore
        # unreachable; only the batched path remains.
        xyz_mem0 = utils_basic.gridcloud3D(B, Z, Y, X)
        mask_mem0 = mask_mem0.reshape(B, Z * Y * X)

        # Buffers for the batch elements that have enough masked voxels.
        # Valid elements are packed densely at the front (slots
        # 0..valid_batches-1) so that the trim below keeps exactly those.
        valid_batches = 0
        sampling_coords_mem0 = torch.zeros(B, self.num_samples,
                                           3).float().cuda()
        valid_feat_cam0 = torch.zeros_like(feat_cam0)
        valid_feat_cam1 = torch.zeros_like(feat_cam1)
        valid_pix_T_cam0 = torch.zeros_like(pix_T_cam0)
        valid_pix_T_cam1 = torch.zeros_like(pix_T_cam1)
        valid_cam1_T_cam0 = torch.zeros_like(cam1_T_cam0)

        for b in range(B):
            xyz_mem0_b = xyz_mem0[b][torch.where(mask_mem0[b] > 0)]
            # this is N x 3

            N = xyz_mem0_b.shape[0]
            if N < self.num_samples:
                # not enough points in this element; drop it
                continue

            perm = np.random.permutation(N)
            xyz_mem0_b = xyz_mem0_b[perm[:self.num_samples]]
            # this is num_samples x 3

            # Write into the next free slot rather than slot b. Previously
            # the data went to index b and was then trimmed with
            # [:valid_batches], which kept zero-filled rows and discarded
            # real ones whenever an invalid element preceded a valid one.
            slot = valid_batches
            sampling_coords_mem0[slot] = xyz_mem0_b
            valid_feat_cam0[slot] = feat_cam0[b]
            valid_feat_cam1[slot] = feat_cam1[b]
            valid_pix_T_cam0[slot] = pix_T_cam0[b]
            valid_pix_T_cam1[slot] = pix_T_cam1[b]
            valid_cam1_T_cam0[slot] = cam1_T_cam0[b]
            valid_batches += 1

        print('valid_batches:', valid_batches)
        if valid_batches == 0:
            # return early
            return total_loss

        # trim down to the packed, valid elements
        sampling_coords_mem0 = sampling_coords_mem0[:valid_batches]
        feat_cam0 = valid_feat_cam0[:valid_batches]
        feat_cam1 = valid_feat_cam1[:valid_batches]
        pix_T_cam0 = valid_pix_T_cam0[:valid_batches]
        pix_T_cam1 = valid_pix_T_cam1[:valid_batches]
        cam1_T_cam0 = valid_cam1_T_cam0[:valid_batches]

        xyz_cam0 = vox_util.Mem2Ref(sampling_coords_mem0, Z, Y, X)
        xyz_cam1 = utils_geom.apply_4x4(cam1_T_cam0, xyz_cam0)
        # these are valid_batches x num_samples x 3

        # project into both views, and sample the features there
        xy_cam0 = utils_geom.apply_pix_T_cam(pix_T_cam0, xyz_cam0)
        xy_cam1 = utils_geom.apply_pix_T_cam(pix_T_cam1, xyz_cam1)
        # these are valid_batches x num_samples x 2

        vec0 = utils_samp.bilinear_sample2D(feat_cam0, xy_cam0[:, :, 0],
                                            xy_cam0[:, :, 1])
        vec1 = utils_samp.bilinear_sample2D(feat_cam1, xy_cam1[:, :, 0],
                                            xy_cam1[:, :, 1])
        # these are valid_batches x C x num_samples

        # reshape (not view): the permuted tensor is not guaranteed to be
        # contiguous, and reshape gives the same result whenever view works
        vec0 = vec0.permute(0, 2, 1).reshape(
            valid_batches * self.num_samples, C)
        vec1 = vec1.permute(0, 2, 1).reshape(
            valid_batches * self.num_samples, C)

        print('vec0', vec0.shape)
        print('vec1', vec1.shape)
        # these are (valid_batches*num_samples) x C

        ce_loss = self.compute_ce_loss(vec0, vec1.detach())
        total_loss = utils_misc.add_loss('tri2D/emb_ce_loss', total_loss,
                                         ce_loss, hyp.tri_2D_ce_coeff,
                                         summ_writer)
        return total_loss
Beispiel #6
0
def assemble(bkg_feat0, obj_feat0, origin_T_camRs, camRs_T_zoom):
    """Compose a background and an object feature volume over a sequence.

    The single background volume is warped to every timestep (which
    effectively creates egomotion), the single object volume is resampled
    into every timestep's grid via ``camRs_T_zoom``, and the two are merged
    and l2-normalized along the channel dim.

    Args:
        bkg_feat0: background features, unpacked as B x C x Z x Y x X.
        obj_feat0: object features, unpacked as B x C x Z2 x Y2 x X2.
        origin_T_camRs: per-timestep transforms, unpacked as B x S x ? x ?.
        camRs_T_zoom: per-timestep transforms from the object's zoom grid to
            camR (not rigid, per the original note); packed along the
            sequence dim and inverted.

    Returns:
        ``(featRs, validRs, bkg_featRs, obj_featRs)``: the merged normalized
        features, a mask that is 0 where all channels of ``featRs`` are zero
        and 1 elsewhere, and the two warped inputs.
    """
    B, C, Z, Y, X = bkg_feat0.shape
    B2, C2, Z2, Y2, X2 = obj_feat0.shape
    assert B == B2
    assert C == C2

    B, S, _, _ = origin_T_camRs.shape

    # helpers for folding the seq dim into the batch dim and back
    def pack(x):
        return pack_seqdim(x, B)

    def unpack(x):
        return unpack_seqdim(x, B)

    # background: warp frame 0 into every timestep's camera frame
    cam0s_T_camRs = utils_geom.get_camM_T_camXs(origin_T_camRs, ind=0)
    camRs_T_cam0s = unpack(utils_geom.safe_inverse(pack(cam0s_T_camRs)))

    bkg_tiled = bkg_feat0.unsqueeze(1).repeat(1, S, 1, 1, 1, 1)
    bkg_featRs = apply_4x4s_to_voxs(camRs_T_cam0s, bkg_tiled)

    # object: for each location of the output grid, find the matching
    # location in the object's zoom grid
    grid_mem = utils_basic.gridcloud3D(B * S, Z, Y, X, norm=False)
    # this is B*S x Z*Y*X x 3
    grid_camR = Mem2Ref(grid_mem, Z, Y, X)
    zoom_T_camRs = pack(camRs_T_zoom).inverse()  # not a rigid transform
    grid_zoom = utils_geom.apply_4x4(zoom_T_camRs, grid_camR)

    # there is only one object feat for the whole traj, so tile it over S
    obj_tiled = pack(obj_feat0.unsqueeze(1).repeat(1, S, 1, 1, 1, 1))
    obj_featRs = unpack(utils_samp.sample3D(obj_tiled, grid_zoom, Z, Y, X))

    # merge: keep the object feats, and add the bkg wherever the mask is 0
    # NOTE(review): the mask is computed from the background features, not
    # the object features -- confirm this is the intended "overwrite the bkg
    # at the object" behaviour.
    bkg_positive = (bkg_featRs > 0).float()
    featRs = obj_featRs + (1.0 - bkg_positive) * bkg_featRs

    # l2 normalize on chans
    featRs = l2_normalize(featRs, dim=2)

    all_zero = (featRs == 0).all(dim=2, keepdim=True)
    validRs = 1.0 - all_zero.float().cuda()

    return featRs, validRs, bkg_featRs, obj_featRs
Beispiel #7
0
    def run_test(self, feed):
        results = dict()

        global_step = feed['global_step']
        total_loss = torch.tensor(0.0).cuda()

        __p = lambda x: utils_basic.pack_seqdim(x, self.B)
        __u = lambda x: utils_basic.unpack_seqdim(x, self.B)

        self.obj_clist_camX0 = utils_geom.get_clist_from_lrtlist(
            self.lrt_camX0s)

        self.original_centroid = self.scene_centroid.clone()

        obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(self.lrt_camX0s)
        obj_length = obj_lengths[:, 0]
        for b in list(range(self.B)):
            if self.score_s[b, 0] < 1.0:
                # we need the template to exist
                print('returning early, since score_s[%d,0] = %.1f' %
                      (b, self.score_s[b, 0].cpu().numpy()))
                return total_loss, results, True
            # if torch.sum(self.score_s[b]) < (self.S/2):
            if not (torch.sum(self.score_s[b]) == self.S):
                # the full traj should be valid
                print(
                    'returning early, since sum(score_s) = %d, while S = %d' %
                    (torch.sum(self.score_s).cpu().numpy(), self.S))
                return total_loss, results, True

        if hyp.do_feat3D:

            feat_memX0_input = torch.cat([
                self.occ_memX0s[:, 0],
                self.unp_memX0s[:, 0] * self.occ_memX0s[:, 0],
            ],
                                         dim=1)
            _, feat_memX0, valid_memX0 = self.featnet3D(feat_memX0_input)
            B, C, Z, Y, X = list(feat_memX0.shape)
            S = self.S

            obj_mask_memX0s = self.vox_util.assemble_padded_obj_masklist(
                self.lrt_camX0s, self.score_s, Z, Y, X).squeeze(1)
            # only take the occupied voxels
            occ_memX0 = self.vox_util.voxelize_xyz(self.xyz_camX0s[:, 0], Z, Y,
                                                   X)
            # obj_mask_memX0 = obj_mask_memX0s[:,0] * occ_memX0
            obj_mask_memX0 = obj_mask_memX0s[:, 0]

            # discard the known freespace
            _, free_memX0_, _, _ = self.vox_util.prep_occs_supervision(
                self.camX0s_T_camXs[:, 0:1],
                self.xyz_camXs[:, 0:1],
                Z,
                Y,
                X,
                agg=True)
            free_memX0 = free_memX0_.squeeze(1)
            obj_mask_memX0 = obj_mask_memX0 * (1.0 - free_memX0)

            for b in list(range(self.B)):
                if torch.sum(obj_mask_memX0[b] * occ_memX0[b]) <= 8:
                    print(
                        'returning early, since there are not enough valid object points'
                    )
                    return total_loss, results, True

            # for b in list(range(self.B)):
            #     sum_b = torch.sum(obj_mask_memX0[b])
            #     print('sum_b', sum_b.detach().cpu().numpy())
            #     if sum_b > 1000:
            #         obj_mask_memX0[b] *= occ_memX0[b]
            #         sum_b = torch.sum(obj_mask_memX0[b])
            #         print('reducing this to', sum_b.detach().cpu().numpy())

            feat0_vec = feat_memX0.view(B, hyp.feat3D_dim, -1)
            # this is B x C x huge
            feat0_vec = feat0_vec.permute(0, 2, 1)
            # this is B x huge x C

            obj_mask0_vec = obj_mask_memX0.reshape(B, -1).round()
            occ_mask0_vec = occ_memX0.reshape(B, -1).round()
            free_mask0_vec = free_memX0.reshape(B, -1).round()
            # these are B x huge

            orig_xyz = utils_basic.gridcloud3D(B, Z, Y, X)
            # this is B x huge x 3

            obj_lengths, cams_T_obj0 = utils_geom.split_lrtlist(
                self.lrt_camX0s)
            obj_length = obj_lengths[:, 0]
            cam0_T_obj = cams_T_obj0[:, 0]
            # this is B x S x 4 x 4

            mem_T_cam = self.vox_util.get_mem_T_ref(B, Z, Y, X)
            cam_T_mem = self.vox_util.get_ref_T_mem(B, Z, Y, X)

            lrt_camIs_g = self.lrt_camX0s.clone()
            lrt_camIs_e = torch.zeros_like(self.lrt_camX0s)
            # we will fill this up

            ious = torch.zeros([B, S]).float().cuda()
            point_counts = np.zeros([B, S])
            inb_counts = np.zeros([B, S])

            feat_vis = []
            occ_vis = []

            for s in range(self.S):
                if not (s == 0):
                    # remake the vox util and all the mem data
                    self.scene_centroid = utils_geom.get_clist_from_lrtlist(
                        lrt_camIs_e[:, s - 1:s])[:, 0]
                    delta = self.scene_centroid - self.original_centroid
                    self.vox_util = vox_util.Vox_util(
                        self.Z,
                        self.Y,
                        self.X,
                        self.set_name,
                        scene_centroid=self.scene_centroid,
                        assert_cube=True)
                    self.occ_memXs = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camXs), self.Z,
                                                   self.Y, self.X))
                    self.occ_memX0s = __u(
                        self.vox_util.voxelize_xyz(__p(self.xyz_camX0s),
                                                   self.Z, self.Y, self.X))

                    self.unp_memXs = __u(
                        self.vox_util.unproject_rgb_to_mem(
                            __p(self.rgb_camXs), self.Z, self.Y, self.X,
                            __p(self.pix_T_cams)))
                    self.unp_memX0s = self.vox_util.apply_4x4s_to_voxs(
                        self.camX0s_T_camXs, self.unp_memXs)
                    self.summ_writer.summ_occ('track/reloc_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                else:
                    self.summ_writer.summ_occ('track/init_occ_%d' % s,
                                              self.occ_memX0s[:, s])
                    delta = torch.zeros([B, 3]).float().cuda()
                # print('scene centroid:', self.scene_centroid.detach().cpu().numpy())
                occ_vis.append(
                    self.summ_writer.summ_occ('',
                                              self.occ_memX0s[:, s],
                                              only_return=True))

                # inb = __u(self.vox_util.get_inbounds(__p(self.xyz_camX0s), self.Z4, self.Y4, self.X, already_mem=False))
                inb = self.vox_util.get_inbounds(self.xyz_camX0s[:, s],
                                                 self.Z4,
                                                 self.Y4,
                                                 self.X,
                                                 already_mem=False)
                num_inb = torch.sum(inb.float(), axis=1)
                # print('num_inb', num_inb, num_inb.shape)
                inb_counts[:, s] = num_inb.cpu().numpy()

                feat_memI_input = torch.cat([
                    self.occ_memX0s[:, s],
                    self.unp_memX0s[:, s] * self.occ_memX0s[:, s],
                ],
                                            dim=1)
                _, feat_memI, valid_memI = self.featnet3D(feat_memI_input)

                self.summ_writer.summ_feat('3D_feats/feat_%d_input' % s,
                                           feat_memI_input,
                                           pca=True)
                self.summ_writer.summ_feat('3D_feats/feat_%d' % s,
                                           feat_memI,
                                           pca=True)
                feat_vis.append(
                    self.summ_writer.summ_feat('',
                                               feat_memI,
                                               pca=True,
                                               only_return=True))

                # collect freespace here, to discard bad matches
                _, free_memI_, _, _ = self.vox_util.prep_occs_supervision(
                    self.camX0s_T_camXs[:, s:s + 1],
                    self.xyz_camXs[:, s:s + 1],
                    Z,
                    Y,
                    X,
                    agg=True)
                free_memI = free_memI_.squeeze(1)

                feat_vec = feat_memI.view(B, hyp.feat3D_dim, -1)
                # this is B x C x huge
                feat_vec = feat_vec.permute(0, 2, 1)
                # this is B x huge x C

                memI_T_mem0 = utils_geom.eye_4x4(B)
                # we will fill this up

                # # put these on cpu, to save mem
                # feat0_vec = feat0_vec.detach().cpu()
                # feat_vec = feat_vec.detach().cpu()

                # to simplify the impl, we will iterate over the batch dim
                for b in list(range(B)):
                    feat_vec_b = feat_vec[b]
                    feat0_vec_b = feat0_vec[b]
                    obj_mask0_vec_b = obj_mask0_vec[b]
                    occ_mask0_vec_b = occ_mask0_vec[b]
                    free_mask0_vec_b = free_mask0_vec[b]
                    orig_xyz_b = orig_xyz[b]
                    # these are huge x C

                    careful = False
                    if careful:
                        # start with occ points, since these are definitely observed
                        obj_inds_b = torch.where(
                            (occ_mask0_vec_b * obj_mask0_vec_b) > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # also take random non-free non-occ points in the mask
                        ok_mask = obj_mask0_vec_b * (1.0 - occ_mask0_vec_b) * (
                            1.0 - free_mask0_vec_b)
                        alt_inds_b = torch.where(ok_mask > 0)
                        alt_vec_b = feat0_vec_b[alt_inds_b]
                        alt_xyz0 = orig_xyz_b[alt_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        num = len(alt_xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            # print('have %d pts; taking a random set of %d pts inside' % (num, max_pts))
                            perm = np.random.permutation(num)
                            alt_vec_b = alt_vec_b[perm[:max_pts]]
                            alt_xyz0 = alt_xyz0[perm[:max_pts]]

                        obj_vec_b = torch.cat([obj_vec_b, alt_vec_b], dim=0)
                        xyz0 = torch.cat([xyz0, alt_xyz0], dim=0)
                        if s == 0:
                            print('have %d pts in total' % (len(xyz0)))
                    else:
                        # take any points within the mask
                        obj_inds_b = torch.where(obj_mask0_vec_b > 0)
                        obj_vec_b = feat0_vec_b[obj_inds_b]
                        xyz0 = orig_xyz_b[obj_inds_b]
                        # these are med x C

                        # issues arise when "med" is too large
                        # trim down to max_pts
                        num = len(xyz0)
                        max_pts = 2000
                        if num > max_pts:
                            print(
                                'have %d pts; taking a random set of %d pts inside'
                                % (num, max_pts))
                            perm = np.random.permutation(num)
                            obj_vec_b = obj_vec_b[perm[:max_pts]]
                            xyz0 = xyz0[perm[:max_pts]]

                    obj_vec_b = obj_vec_b.permute(1, 0)
                    # this is is C x med

                    corr_b = torch.matmul(feat_vec_b, obj_vec_b)
                    # this is huge x med

                    heat_b = corr_b.permute(1, 0).reshape(-1, 1, Z, Y, X)
                    # this is med x 1 x Z4 x Y4 x X4

                    # # for numerical stability, we sub the max, and mult by the resolution
                    # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                    # heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                    # heat_b = heat_b - heat_b_max
                    # heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                    # # for numerical stability, we sub the max, and mult by the resolution
                    # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                    # heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                    # heat_b = heat_b - heat_b_max
                    # heat_b = heat_b * float(len(heat_b[0].reshape(-1)))
                    # heat_b_ = heat_b.reshape(-1, Z*Y*X)
                    # # heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(-1, 1, 1, 1, 1)
                    # heat_b_min = (torch.min(heat_b_).values)
                    # free_b = free_memI[b:b+1]
                    # print('free_b', free_b.shape)
                    # print('heat_b', heat_b.shape)
                    # heat_b[free_b > 0.0] = heat_b_min

                    # make the min zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_min = (torch.min(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_min
                    # zero out the freespace
                    heat_b = heat_b * (1.0 - free_memI[b:b + 1])
                    # make the max zero
                    heat_b_ = heat_b.reshape(-1, Z * Y * X)
                    heat_b_max = (torch.max(heat_b_, dim=1).values).reshape(
                        -1, 1, 1, 1, 1)
                    heat_b = heat_b - heat_b_max
                    # scale up, for numerical stability
                    heat_b = heat_b * float(len(heat_b[0].reshape(-1)))

                    xyzI = utils_basic.argmax3D(heat_b, hard=False, stack=True)
                    # xyzI = utils_basic.argmax3D(heat_b*float(Z*10), hard=False, stack=True)
                    # this is med x 3

                    xyzI_cam = self.vox_util.Mem2Ref(xyzI.unsqueeze(1), Z, Y,
                                                     X)
                    xyzI_cam += delta
                    xyzI = self.vox_util.Ref2Mem(xyzI_cam, Z, Y, X).squeeze(1)

                    memI_T_mem0[b] = utils_track.rigid_transform_3D(xyz0, xyzI)

                    # record #points, since ransac depends on this
                    point_counts[b, s] = len(xyz0)
                # done stepping through batch

                mem0_T_memI = utils_geom.safe_inverse(memI_T_mem0)
                cam0_T_camI = utils_basic.matmul3(cam_T_mem, mem0_T_memI,
                                                  mem_T_cam)

                # eval
                camI_T_obj = utils_basic.matmul4(cam_T_mem, memI_T_mem0,
                                                 mem_T_cam, cam0_T_obj)
                # this is B x 4 x 4
                lrt_camIs_e[:,
                            s] = utils_geom.merge_lrt(obj_length, camI_T_obj)
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    lrt_camIs_e[:, s:s + 1], lrt_camIs_g[:,
                                                         s:s + 1]).squeeze(1)
            results['ious'] = ious
            # if ious[0,-1] > 0.5:
            #     print('returning early, since acc is too high')
            #     return total_loss, results, True

            self.summ_writer.summ_rgbs('track/feats', feat_vis)
            self.summ_writer.summ_oneds('track/occs', occ_vis, norm=False)

            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())

            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())
            self.summ_writer.summ_scalar('track/point_counts',
                                         np.mean(point_counts))
            # self.summ_writer.summ_scalar('track/inb_counts', torch.mean(inb_counts).cpu().item())
            self.summ_writer.summ_scalar('track/inb_counts',
                                         np.mean(inb_counts))

            lrt_camX0s_e = lrt_camIs_e.clone()
            lrt_camXs_e = utils_geom.apply_4x4s_to_lrts(
                self.camXs_T_camX0s, lrt_camX0s_e)

            if self.include_vis:
                visX_e = []
                for s in list(range(self.S)):
                    visX_e.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_e' % s,
                                                      self.rgb_camXs[:, s],
                                                      lrt_camXs_e[:, s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_e', visX_e)
                visX_g = []
                for s in list(range(self.S)):
                    visX_g.append(
                        self.summ_writer.summ_lrtlist('track/box_camX%d_g' % s,
                                                      self.rgb_camXs[:, s],
                                                      self.lrt_camXs[:,
                                                                     s:s + 1],
                                                      self.score_s[:, s:s + 1],
                                                      self.tid_s[:, s:s + 1],
                                                      self.pix_T_cams[:, 0],
                                                      only_return=True))
                self.summ_writer.summ_rgbs('track/box_camXs_g', visX_g)

            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)

            dists = torch.norm(obj_clist_camX0_e - self.obj_clist_camX0, dim=2)
            # this is B x S
            mean_dist = utils_basic.reduce_masked_mean(dists, self.score_s)
            median_dist = utils_basic.reduce_masked_median(dists, self.score_s)
            # this is []
            self.summ_writer.summ_scalar('track/centroid_dist_mean',
                                         mean_dist.cpu().item())
            self.summ_writer.summ_scalar('track/centroid_dist_median',
                                         median_dist.cpu().item())

            # if self.include_vis:
            if (True):
                self.summ_writer.summ_traj_on_occ('track/traj_e',
                                                  obj_clist_camX0_e,
                                                  self.occ_memX0s[:, 0],
                                                  self.vox_util,
                                                  already_mem=False,
                                                  sigma=2)
                self.summ_writer.summ_traj_on_occ('track/traj_g',
                                                  self.obj_clist_camX0,
                                                  self.occ_memX0s[:, 0],
                                                  self.vox_util,
                                                  already_mem=False,
                                                  sigma=2)
            total_loss += mean_dist  # we won't backprop, but it's nice to plot and print this anyway

        else:
            ious = torch.zeros([self.B, self.S]).float().cuda()
            for s in list(range(self.S)):
                ious[:, s] = utils_geom.get_iou_from_corresponded_lrtlists(
                    self.lrt_camX0s[:, 0:1],
                    self.lrt_camX0s[:, s:s + 1]).squeeze(1)
            results['ious'] = ious
            for s in range(self.S):
                self.summ_writer.summ_scalar(
                    'track/mean_iou_%02d' % s,
                    torch.mean(ious[:, s]).cpu().item())
            self.summ_writer.summ_scalar('track/mean_iou',
                                         torch.mean(ious).cpu().item())

            lrt_camX0s_e = self.lrt_camX0s[:, 0:1].repeat(1, self.S, 1)
            obj_clist_camX0_e = utils_geom.get_clist_from_lrtlist(lrt_camX0s_e)
            self.summ_writer.summ_traj_on_occ('track/traj_e',
                                              obj_clist_camX0_e,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)
            self.summ_writer.summ_traj_on_occ('track/traj_g',
                                              self.obj_clist_camX0,
                                              self.occ_memX0s[:, 0],
                                              self.vox_util,
                                              already_mem=False,
                                              sigma=2)

        self.summ_writer.summ_scalar('loss', total_loss.cpu().item())
        return total_loss, results, False
def get_gt_flow(obj_lrtlist_camRs,
                obj_scorelist,
                camRs_T_camXs,
                Z,
                Y,
                X,
                K=2,
                mod='',
                vis=True,
                summ_writer=None):
    """Build a ground-truth 3D flow field from box trajectories and egomotion.

    For each of the first K objects, the voxels inside the object's mask at
    step 0 are displaced by the object's box motion (step 0 -> step 1)
    composed with the camera transforms. Everywhere outside all object masks,
    the flow induced purely by the camera transforms is used instead. The
    per-object flows and the background flow are then blended with
    `utils_basic.reduce_masked_mean`.

    Args:
        obj_lrtlist_camRs: 4-dim tensor unpacked as N x B x S x D; indexed
            along the first axis by object. S must be 2.
        obj_scorelist: per-object scores, indexed as [k, :, 0:1]
            (presumably N x B x S -- TODO confirm).
        camRs_T_camXs: per-step transforms, indexed as [:, 0] and [:, 1]
            (presumably B x S x 4 x 4 -- TODO confirm).
        Z, Y, X: voxel grid resolution.
        K: number of objects to use; the first axis of `obj_lrtlist_camRs`
            is indexed with 0..K-1, so K must not exceed its length. At least
            one object is required, since the masks are stacked.
        mod: suffix appended to the per-object summary names.
        vis: whether to write summaries. Summaries are only written when
            `summ_writer` is also provided.
        summ_writer: optional summary writer exposing `summ_oned` and
            `summ_3D_flow`.

    Returns:
        The blended flow, shaped B x 3 x Z x Y x X.
    """
    _, B, S, _ = list(obj_lrtlist_camRs.shape)
    assert (S == 2)  # as a flow util, this expects S=2

    # every summary call needs a writer; with the defaults (vis=True,
    # summ_writer=None) we simply skip visualization instead of crashing
    do_vis = vis and (summ_writer is not None)

    # none of these depend on the object index, so compute them once
    camR_T_cam0 = camRs_T_camXs[:, 0]
    camR_T_cam1 = camRs_T_camXs[:, 1]
    cam1_T_camR = utils_geom.safe_inverse(camR_T_cam1)

    mem_T_ref = utils_vox.get_mem_T_ref(B, Z, Y, X)
    ref_T_mem = utils_vox.get_ref_T_mem(B, Z, Y, X)

    # the flat cloud is what gets transformed; the reshaped copy is used for
    # the per-voxel arithmetic
    xyz_mem0_flat = utils_basic.gridcloud3D(B, Z, Y, X)
    xyz_mem0 = xyz_mem0_flat.reshape(B, Z, Y, X, 3)

    flows = []
    masks = []
    for k in range(K):
        obj_masklistR0 = utils_vox.assemble_padded_obj_masklist(
            obj_lrtlist_camRs[k, :, 0:1],
            obj_scorelist[k, :, 0:1],
            Z,
            Y,
            X,
            coeff=1.0)
        # this is B x 1(N) x 1(C) x Z x Y x X
        obj_mask0 = obj_masklistR0.squeeze(1)
        # this is B x 1 x Z x Y x X

        if do_vis:
            summ_writer.summ_oned('flow/obj%d_mask0_%s' % (k, mod),
                                  torch.mean(obj_mask0, 3))

        _, ref_T_objs_list = utils_geom.split_lrtlist(obj_lrtlist_camRs[k])
        # this is B x S x 4 x 4
        ref_T_obj0 = ref_T_objs_list[:, 0]
        ref_T_obj1 = ref_T_objs_list[:, 1]
        obj0_T_ref = utils_geom.safe_inverse(ref_T_obj0)
        # these are B x 4 x 4

        # object motion between the two steps, wrapped by the camera
        # transforms, then expressed in voxel (mem) coordinates
        ref1_T_ref0 = utils_basic.matmul2(ref_T_obj1, obj0_T_ref)
        cam1_T_cam0 = utils_basic.matmul3(cam1_T_camR, ref1_T_ref0,
                                          camR_T_cam0)
        mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

        xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0_flat)
        xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)

        # only use the displaced points within the obj mask; elsewhere the
        # point stays put, so the flow there is zero
        obj_mask0 = obj_mask0.view(B, Z, Y, X, 1)
        cond = (obj_mask0 > 0.0).float()
        xyz_mem1 = cond * xyz_mem1 + (1.0 - cond) * xyz_mem0

        flow = (xyz_mem1 - xyz_mem0).permute(0, 4, 1, 2, 3)
        obj_mask0 = obj_mask0.permute(0, 4, 1, 2, 3)

        if do_vis:
            summ_writer.summ_3D_flow('flow/gt_%d_%s' % (k, mod),
                                     flow,
                                     clip=4.0)

        masks.append(obj_mask0)
        flows.append(flow)

    # background: flow induced by the camera transforms alone
    cam1_T_cam0 = utils_basic.matmul2(cam1_T_camR, camR_T_cam0)
    mem1_T_mem0 = utils_basic.matmul3(mem_T_ref, cam1_T_cam0, ref_T_mem)

    xyz_mem1 = utils_geom.apply_4x4(mem1_T_mem0, xyz_mem0_flat)
    xyz_mem1 = xyz_mem1.reshape(B, Z, Y, X, 3)
    bkg_flow = (xyz_mem1 - xyz_mem0).permute(0, 4, 1, 2, 3)

    # the background flow applies wherever no object mask is active
    any_mask = torch.max(torch.stack(masks, dim=0), dim=0)[0]
    masks.append(1.0 - any_mask)
    flows.append(bkg_flow)

    flows = torch.stack(flows, dim=0)
    masks = torch.stack(masks, dim=0)
    # broadcast each 1-channel mask over the 3 flow channels
    masks = masks.repeat(1, 1, 3, 1, 1, 1)
    flow = utils_basic.reduce_masked_mean(flows, masks, dim=0)

    if do_vis:
        summ_writer.summ_3D_flow('flow/gt_complete', flow, clip=4.0)

    # flow is shaped B x 3 x Z x Y x X
    return flow