Exemple #1
0
    def forward(self, data_batch):
        """Run the 2D network, lift its features onto the points, run the 3D network.

        Args:
            data_batch: mapping that provides 'images', 'knn_indices',
                'image_xyz' and 'points'. Going by the shape notes in the
                code these are presumably (b, nv, 3, h, w), (b, np, k) and
                (b, nv, h, w, 3) for the first three -- TODO confirm against
                the data loader.

        Returns:
            The output of ``self.net_3d`` for the points and the aggregated
            2D-3D features, unchanged.
        """
        images = data_batch['images']
        batch_size, num_views, _, height, width = images.shape
        num_pixels = num_views * height * width

        # The 2D network sees every view as an independent sample.
        flat_images = images.reshape(-1, *images.shape[2:])
        feature_map = self.net_2d({'image': flat_images})['feature']

        # Lay all pixels of all views of a sample out along one axis so the
        # kNN indices can address them: (b, c, nv * h * w).
        neighbor_idx = data_batch['knn_indices']
        feature_map = feature_map.reshape(batch_size, num_views, -1, height,
                                          width)
        feature_map = feature_map.transpose(1, 2).contiguous()
        feature_map = feature_map.reshape(batch_size, -1, num_pixels)
        grouped_feature = group_points(feature_map, neighbor_idx)

        # The unprojected coordinates are plain inputs; no gradient needed.
        with torch.no_grad():
            pixel_xyz = data_batch['image_xyz'].permute(0, 4, 1, 2, 3)
            pixel_xyz = pixel_xyz.reshape(batch_size, 3, num_pixels)
            grouped_xyz = group_points(pixel_xyz, neighbor_idx)

        # Fuse the gathered 2D features with the 3D points.
        points = data_batch['points']
        fused_feature = self.feat_aggreg(grouped_xyz, points, grouped_feature)

        return self.net_3d({'points': points, 'feature': fused_feature})
Exemple #2
0
def test(b, c, n1, n2, k, profile):
    """Check `group_points` against `group_points_torch`, forward and backward.

    Args:
        b: batch size of the random feature tensor.
        c: number of feature channels.
        n1: number of source points; indices are drawn from ``[0, n1)``.
        n2: number of index rows per batch element.
        k: number of indices per row.
        profile: when truthy, additionally print autograd profiler output for
            one forward and one backward pass of `group_points`.

    Raises:
        AssertionError: if the outputs or the input gradients of the two
            implementations are not close.
    """
    torch.manual_seed(0)

    feature = torch.randn(b, c, n1).cuda()
    index = torch.randint(0, n1, [b, n2, k]).long().cuda()

    # Independent leaf copies so each implementation accumulates its own grad.
    reference_input = feature.clone()
    reference_input.requires_grad = True
    kernel_input = feature.clone()
    kernel_input.requires_grad = True

    # Forward pass must agree.
    reference_out = group_points_torch(reference_input, index)
    kernel_out = group_points(kernel_input, index)
    assert reference_out.allclose(kernel_out)

    # Backward pass with an all-ones upstream gradient must agree as well.
    reference_out.backward(torch.ones_like(reference_out))
    kernel_out.backward(torch.ones_like(kernel_out))
    assert reference_input.grad.allclose(kernel_input.grad)

    if not profile:
        return

    with torch.autograd.profiler.profile(
            use_cuda=torch.cuda.is_available()) as prof:
        kernel_out = group_points(kernel_input, index)
    print(prof)
    with torch.autograd.profiler.profile(
            use_cuda=torch.cuda.is_available()) as prof:
        kernel_out.backward(torch.ones_like(kernel_out))
    print(prof)
Exemple #3
0
def get_2dfeature_imagexyz(data_i, network_2d):
    """Gather per-point 2D features and unprojected pixel coordinates for one scene.

    The view count, image size, neighbour count and feature width are read
    from the input arrays instead of being hard-coded (previously 30 views of
    120x160, k=3, 64 channels), so other configurations no longer fail or
    silently scramble the output layout.

    Args:
        data_i: mapping with NumPy arrays under 'images', 'knn_indices' and
            'image_xyz'. Per the original shape notes these are presumably
            (nv, 3, h, w), (np, k) and (nv, h, w, 3) -- TODO confirm against
            the dataset class.
        network_2d: callable taking ``{'image': tensor}`` and returning a
            mapping whose 'feature' entry is the (nv, c, h, w) feature map.
            It is run on whatever device its weights and the CPU input
            tensor allow (the original ran it on CPU to save GPU memory).

    Returns:
        dict with
            'feature_2d': ndarray of shape (np, k * c), the k gathered feature
                vectors of every point laid out back to back;
            'image_xyz': ndarray of shape (np, k * 3), the matching gathered
                coordinates.
    """
    batch_size = 1  # exactly one scene is processed per call

    # Pure feature extraction: nothing here is trained, and tensors that do
    # not require grad can be converted to NumPy directly.
    with torch.no_grad():
        image_xyz = torch.from_numpy(data_i['image_xyz']).unsqueeze(0)
        _, num_views, height, width, _ = image_xyz.shape  # (1, nv, h, w, 3)
        num_pixels = num_views * height * width

        images = torch.from_numpy(data_i['images'])  # (nv, 3, h, w)
        feature_2d = network_2d({'image': images})['feature']

        # Because the kNN indices carry the point count, resampling the
        # points effectively limits the number of kNN indices and thereby
        # fixes how many points get unprojected.
        knn_indices = torch.from_numpy(data_i['knn_indices']).unsqueeze(0)
        knn_indices = knn_indices.cuda()  # (1, np, k)
        _, num_points, num_neighbors = knn_indices.shape

        # Unproject features: (nv, c, h, w) -> (1, c, nv * h * w) -> gather.
        feature_2d = feature_2d.reshape(batch_size, num_views, -1, height,
                                        width)
        feature_2d = feature_2d.transpose(1, 2).contiguous()
        num_channels = feature_2d.shape[1]
        feature_2d = feature_2d.reshape(batch_size, num_channels, num_pixels)
        feature_2d = group_points(feature_2d.cuda(), knn_indices)
        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # drop the point or neighbour axis whenever np == 1 or k == 1.
        feature_2d = feature_2d.squeeze(0).cpu().permute(1, 2, 0)  # (np, k, c)
        feature_2d_numpy = feature_2d.reshape(
            num_points, num_neighbors * num_channels).detach().numpy()

        # Unproject depth maps with the same indices.
        image_xyz = image_xyz.permute(0, 4, 1, 2, 3).reshape(
            batch_size, 3, num_pixels)
        image_xyz = group_points(image_xyz.cuda(), knn_indices)
        image_xyz = image_xyz.squeeze(0).cpu().permute(1, 2, 0)  # (np, k, 3)
        image_xyz_numpy = image_xyz.reshape(
            num_points, num_neighbors * 3).numpy()

    return {'feature_2d': feature_2d_numpy, 'image_xyz': image_xyz_numpy}
Exemple #4
0
    def forward(self, new_xyz, xyz, feature, use_xyz):
        """Group neighbours around each centroid and centre their coordinates.

        Args:
            new_xyz: centroid coordinates; presumably (batch_size, 3,
                num_centroids) given how they are broadcast below -- TODO
                confirm.
            xyz: coordinates of the points to group from.
            feature: optional per-point features, or ``None``.
            use_xyz: when true (and ``feature`` is given) the centred
                coordinates are concatenated to the grouped features along
                the channel axis.

        Returns:
            ``(group_feature, group_xyz)``. When ``feature`` is ``None`` both
            entries are the same centred-coordinate tensor.
        """
        # Neighbour search is index selection only; keep it out of autograd.
        with torch.no_grad():
            neighbor_idx = ball_query(new_xyz, xyz, self.radius,
                                      self.max_neighbors)

        # (batch_size, 3, num_centroids, num_neighbors)
        grouped_xyz = group_points(xyz, neighbor_idx)
        # Translation normalisation: express neighbours relative to their
        # centroid (in place, as before).
        grouped_xyz -= new_xyz.unsqueeze(-1)

        if feature is None:
            # Without input features the coordinates are the features.
            return grouped_xyz, grouped_xyz

        # (batch_size, channels, num_centroids, num_neighbors)
        grouped_feature = group_points(feature, neighbor_idx)
        if use_xyz:
            grouped_feature = torch.cat([grouped_feature, grouped_xyz], dim=1)
        return grouped_feature, grouped_xyz
Exemple #5
0
    def forward(self, data_batch):
        """Predict per-point segmentation logits by averaging unprojected 2D logits.

        Args:
            data_batch: mapping that provides 'images' and 'knn_indices';
                presumably (b, nv, 3, h, w) and (b, np, k) per the shape
                notes in the code -- TODO confirm against the data loader.

        Returns:
            dict with a single key 'seg_logit': the 2D logits gathered for
            every point and averaged over its neighbours (last axis).
        """
        images = data_batch['images']
        batch_size, num_views, _, height, width = images.shape

        # Every view is an independent sample for the 2D network.
        flat_images = images.reshape(-1, *images.shape[2:])
        logits_2d = self.net_2d({'image': flat_images})['seg_logit']

        # (b * nv, nc, h, w) -> (b, nc, nv * h * w) so the kNN indices can
        # address every pixel of every view.
        neighbor_idx = data_batch['knn_indices']
        logits = logits_2d.reshape(batch_size, num_views, -1, height, width)
        logits = logits.transpose(1, 2).contiguous()
        logits = logits.reshape(batch_size, -1, num_views * height * width)

        # (b, nc, np, k) -> mean over the k neighbours -> (b, nc, np)
        point_logits = group_points(logits, neighbor_idx).mean(-1)

        return {'seg_logit': point_logits}
Exemple #6
0
    def forward(self, batch, config):
        """Feed aggregated 2D-3D features through the encoder/decoder and the head.

        Args:
            batch: object exposing ``images``, ``image_xyz``, ``knn_list`` and
                ``feat_aggre_points``; it is also handed to every block.
                Shapes are presumably (b, nv, 3, h, w), (b, nv, h, w, 3), a
                per-scene sequence of NumPy index arrays and (1, np, 3) --
                TODO confirm against the batch class.
            config: not used in this method.

        Returns:
            The output of ``self.head_softmax``.
        """
        images = batch.images
        num_scenes, num_views, _, height, width = images.size()
        num_pixels = num_views * height * width

        # 2D network over all views at once.
        flat_images = images.reshape(-1, *images.shape[2:])
        feature_map = self.net_2d({'image': flat_images})['feature']
        # (b * nv, c, h, w) -> (b, c, nv * h * w)
        feature_map = feature_map.reshape(num_scenes, num_views, -1, height,
                                          width)
        feature_map = feature_map.transpose(1, 2).contiguous()
        feature_map = feature_map.reshape(num_scenes, -1, num_pixels)

        # (b, nv, h, w, 3) -> (b, 3, nv * h * w)
        pixel_xyz = batch.image_xyz.permute(0, 4, 1, 2, 3).reshape(
            num_scenes, 3, num_pixels)

        # Scenes have different point counts, so gather scene by scene.
        knn_list = batch.knn_list
        grouped_features = []
        grouped_xyz = []
        for scene_idx in range(num_scenes):
            scene_knn = torch.from_numpy(knn_list[scene_idx]).long().cuda()

            scene_feature = feature_map[scene_idx].unsqueeze(0)
            grouped_features.append(group_points(scene_feature, scene_knn))

            # Coordinates are inputs only; keep them out of autograd.
            with torch.no_grad():
                scene_xyz = pixel_xyz[scene_idx].unsqueeze(0)
                grouped_xyz.append(group_points(scene_xyz, scene_knn))

        # Stack along the point axis to match the single stacked point batch.
        input_feature_2d = torch.cat(grouped_features, dim=2)  # (1, c, np, k)
        input_image_xyz = torch.cat(grouped_xyz, dim=2)  # (1, 3, np, k)

        # 2D-3D aggregation.
        aggre_points = batch.feat_aggre_points
        feature_2d3d = self.feat_aggreg(input_image_xyz,
                                        aggre_points.transpose(1, 2),
                                        input_feature_2d)  # (1, 64, np)
        feature_2d3d = feature_2d3d.permute(0, 2, 1).reshape(-1, 64)

        # Prepend a constant 1 so black/dark points are not ignored.
        ones = torch.ones_like(aggre_points[:, :, :1].squeeze(0))  # (np, 1)
        stacked_features = torch.cat((ones, feature_2d3d), dim=1)  # (np, 65)

        # Detached copy: the blocks below do not backpropagate into the
        # 2D network or the aggregation module.
        x = stacked_features.clone().detach()

        # Encoder, remembering the inputs of the skip blocks.
        skips = []
        for block_idx, block in enumerate(self.encoder_blocks):
            if block_idx in self.encoder_skips:
                skips.append(x)
            x = block(x, batch)

        # Decoder, consuming the skips in reverse order.
        for block_idx, block in enumerate(self.decoder_blocks):
            if block_idx in self.decoder_concats:
                x = torch.cat([x, skips.pop()], dim=1)
            x = block(x, batch)

        # Head. (Concatenating feature_2d3d again here -- late fusion -- was
        # tried and judged not a good idea.)
        x = self.head_mlp(x, batch)
        return self.head_softmax(x, batch)
Exemple #7
0
def get_2d3dfeature(
        data_i,
        network_2d,
        checkpoint_path='/home/dchangyu/mvpnet/outputs_use/scannet/mvpnet_3d_unet_resnet34_pn2ssg/model_040000.pth'):
    """Compute aggregated 2D-3D point features for one scene.

    Changes relative to the earlier version:
      * view count and image size are read from ``image_xyz`` instead of
        being hard-coded to 30 views of 120x160;
      * only the batch axis is squeezed from the result, so a scene with a
        single point still yields an (np, 64) array;
      * checkpoint keys are cut right after their 'feat_aggreg' component,
        wherever it sits, instead of always dropping just the first
        component (which broke for prefixed keys such as 'module.feat_aggreg.*');
      * the checkpoint path is a parameter (default unchanged) and the
        checkpoint is loaded to CPU before the module is moved to the GPU.

    Args:
        data_i: mapping with NumPy arrays under 'images', 'knn_indices',
            'image_xyz' and 'points'. Per the original shape notes these are
            presumably (nv, 3, h, w), (np, k), (nv, h, w, 3) and (3, np) --
            TODO confirm against the dataset class.
        network_2d: callable taking ``{'image': tensor}`` and returning a
            mapping whose 'feature' entry is the (nv, c, h, w) feature map.
        checkpoint_path: file readable by ``torch.load`` whose 'model' entry
            is a state dict containing the 'feat_aggreg' weights.

    Returns:
        dict with 'feature_2d3d': ndarray of shape (np, 64).
    """
    batch_size = 1  # exactly one scene is processed per call
    submodule = 'feat_aggreg'

    # Inference only: nothing in this function is trained.
    with torch.no_grad():
        image_xyz = torch.from_numpy(data_i['image_xyz']).unsqueeze(0)
        _, num_views, height, width, _ = image_xyz.shape  # (1, nv, h, w, 3)
        num_pixels = num_views * height * width

        images = torch.from_numpy(data_i['images'])  # (nv, 3, h, w)
        feature_2d = network_2d({'image': images})['feature']

        # Because the kNN indices carry the point count, resampling the
        # points effectively limits the number of kNN indices and thereby
        # fixes how many points get unprojected.
        knn_indices = torch.from_numpy(data_i['knn_indices']).unsqueeze(0)
        knn_indices = knn_indices.cuda()  # (1, np, k)

        # Unproject features: (nv, c, h, w) -> (1, c, nv * h * w) -> gather.
        feature_2d = feature_2d.reshape(batch_size, num_views, -1, height,
                                        width)
        feature_2d = feature_2d.transpose(1, 2).contiguous()
        feature_2d = feature_2d.reshape(batch_size, -1, num_pixels)
        feature_2d = group_points(feature_2d.cuda(), knn_indices)

        # Unproject depth maps with the same indices.
        image_xyz = image_xyz.permute(0, 4, 1, 2, 3).reshape(
            batch_size, 3, num_pixels)
        image_xyz = group_points(image_xyz.cuda(), knn_indices)

        points = torch.from_numpy(data_i['points']).unsqueeze(0)  # (1, 3, np)

        # Restore the aggregation module from the pretrained full model.
        feat_aggreg = FeatureAggregation(64)
        model_state = torch.load(checkpoint_path, map_location='cpu')['model']
        feat_aggreg_weights = {}
        for key, value in model_state.items():
            parts = key.split('.')
            if submodule in parts:
                start = parts.index(submodule) + 1
                feat_aggreg_weights['.'.join(parts[start:])] = value
        feat_aggreg.load_state_dict(feat_aggreg_weights)
        feat_aggreg.eval()
        feat_aggreg.cuda()

        feature_2d3d = feat_aggreg(image_xyz, points.cuda(),
                                   feature_2d)  # (1, 64, np)

    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop
    # the point axis when np == 1 and break the (np, 64) layout.
    feature_2d3d_numpy = feature_2d3d.squeeze(0).cpu().detach().numpy().T

    return {'feature_2d3d': feature_2d3d_numpy}
Exemple #8
0
def test():
    """Smoke-test the 2D -> 3D feature pipeline on the ScanNet validation split.

    For every scene the function prints the contents of the sample, runs the
    pretrained 2D network, unprojects its features and the depth-map
    coordinates with `group_points`, aggregates them with a freshly
    initialised `FeatureAggregation` module and prints shapes, dtypes and
    devices along the way. It needs the hard-coded dataset/checkpoint paths
    and a CUDA device; it returns nothing.

    The 2D network and its checkpoint are loaded once before the loop (they
    do not depend on the sample) instead of once per scene, and the number
    of views and the image size are read from each sample rather than being
    hard-coded to 20 views of 120x160.
    """
    import os.path as osp

    import torch

    import mvpnet.data.scannet_2d3d as scannet
    import mvpnet.models.unet_resnet34 as net2d
    from common.nn.freezer import Freezer
    from mvpnet.ops.group_points import group_points

    cache_dir = osp.join('/home/dchangyu/mvpnet/ScanNet/cache_rgbd')
    image_dir = osp.join('/home/dchangyu/mvpnet/ScanNet/scans_resize_160x120')

    np.random.seed(0)
    dataset = scannet.ScanNet2D3DWhole(cache_dir=cache_dir,
                                       image_dir=image_dir,
                                       split='val',
                                       num_rgbd_frames=20,
                                       to_tensor=True)
    print(dataset)

    # 2D network: built and restored once; kept on CPU to save CUDA memory.
    net_2d = net2d.UNetResNet34(20, p=0.5, pretrained=True)
    checkpoint = torch.load('/home/dchangyu/mvpnet/outputs/scannet/unet_resnet34/model_080000.pth', map_location=torch.device("cpu"))
    net_2d.load_state_dict(checkpoint['model'])

    freezer = Freezer(net_2d, ("module:net_2d", "net_2d"))
    freezer.freeze(verbose=True)  # sanity check

    b = 1  # one scene per iteration
    for i in range(len(dataset)):
        data = dataset[i]
        for k, v in data.items():
            if isinstance(v, np.ndarray):
                print(k, v.shape, v.dtype)
            else:
                print('below not ndarray')
                print(k, v)

        print('store pc data'+str(i))

        image_xyz = torch.from_numpy(data['image_xyz']).unsqueeze(0)
        _, nv, h, w, _ = image_xyz.shape  # (b, nv, h, w, 3)

        images = torch.from_numpy(data['images'])  # (nv, 3, h, w)
        preds_2d = net_2d({'image': images})
        feature_2d = preds_2d['feature']  # (b * nv, c, h, w)

        print('feature_2d from net_2d', feature_2d.shape, feature_2d.dtype, feature_2d.device)

        # Unproject features. Because the kNN indices carry the point count,
        # resampling the points effectively limits the number of kNN indices
        # and thereby fixes how many points get unprojected.
        knn_indices = torch.from_numpy(data['knn_indices']).unsqueeze(0)
        knn_indices = knn_indices.cuda()  # (b, np, k)
        feature_2d = feature_2d.reshape(b, nv, -1, h, w).transpose(1, 2).contiguous()  # (b, c, nv, h, w)
        feature_2d = feature_2d.reshape(b, -1, nv * h * w)
        feature_2d = group_points(feature_2d.cuda(), knn_indices)  # (b, c, np, k)

        print('feature_2d', feature_2d.shape, feature_2d.dtype, feature_2d.device)

        # Unproject depth maps with the same indices.
        with torch.no_grad():
            image_xyz = image_xyz.permute(0, 4, 1, 2, 3).reshape(b, 3, nv * h * w)
            image_xyz = group_points(image_xyz.cuda(), knn_indices)  # (b, 3, np, k)

        print('image_xyz_group_points',image_xyz.shape, image_xyz.dtype, image_xyz.device)

        # 2D-3D aggregation with a freshly initialised module per scene.
        points = torch.from_numpy(data['points']).unsqueeze(0)  # (b, 3, np)
        feat_aggreg = FeatureAggregation(64)
        feat_aggreg.cuda()
        feature_2d3d = feat_aggreg(image_xyz, points.cuda(), feature_2d)  # (b, 64, np)

        print('feature_2d3d', feature_2d3d.shape, feature_2d3d.dtype, feature_2d3d.device)
        print(feature_2d3d)

        torch.cuda.empty_cache()
        print('cuda empty after one loop')
    def forward(self, batch, config):
        """Middle-fusion forward pass with separate 3D-colour and 2D-feature encoders.

        Args:
            batch: object exposing ``images``, ``image_xyz``, ``knn_list``,
                ``feat_aggre_points`` and ``colors``; it is also handed to
                every block. Shapes are presumably (b, nv, 3, h, w),
                (b, nv, h, w, 3), a per-scene sequence of NumPy index arrays,
                (1, np, 3) and (np, 3) -- TODO confirm against the batch
                class.
            config: not used in this method.

        Returns:
            The output of ``self.head_softmax``.
        """
        images = batch.images
        num_scenes, num_views, _, height, width = images.size()
        num_pixels = num_views * height * width

        # 2D network over all views at once.
        flat_images = images.reshape(-1, *images.shape[2:])
        feature_map = self.net_2d({'image': flat_images})['feature']
        # (b * nv, c, h, w) -> (b, c, nv * h * w)
        feature_map = feature_map.reshape(num_scenes, num_views, -1, height,
                                          width)
        feature_map = feature_map.transpose(1, 2).contiguous()
        feature_map = feature_map.reshape(num_scenes, -1, num_pixels)

        # (b, nv, h, w, 3) -> (b, 3, nv * h * w)
        pixel_xyz = batch.image_xyz.permute(0, 4, 1, 2, 3).reshape(
            num_scenes, 3, num_pixels)

        # Scenes have different point counts, so gather scene by scene.
        knn_list = batch.knn_list
        grouped_features = []
        grouped_xyz = []
        for scene_idx in range(num_scenes):
            scene_knn = torch.from_numpy(knn_list[scene_idx]).long().cuda()

            scene_feature = feature_map[scene_idx].unsqueeze(0)
            grouped_features.append(group_points(scene_feature, scene_knn))

            # Coordinates are inputs only; keep them out of autograd.
            with torch.no_grad():
                scene_xyz = pixel_xyz[scene_idx].unsqueeze(0)
                grouped_xyz.append(group_points(scene_xyz, scene_knn))

        # Stack along the point axis to match the single stacked point batch.
        input_feature_2d = torch.cat(grouped_features, dim=2)  # (1, c, np, k)
        input_image_xyz = torch.cat(grouped_xyz, dim=2)  # (1, 3, np, k)

        # 2D-3D aggregation.
        aggre_points = batch.feat_aggre_points
        feature_2d3d = self.feat_aggreg(input_image_xyz,
                                        aggre_points.transpose(1, 2),
                                        input_feature_2d)  # (1, 64, np)
        feature_2d3d = feature_2d3d.permute(0, 2, 1).reshape(-1, 64)

        # 3D branch input: constant 1 (so black/dark points are not ignored)
        # followed by the point colours -> (np, 4).
        ones = torch.ones_like(aggre_points[:, :, :1].squeeze(0))  # (np, 1)
        color_features = torch.cat((ones, batch.colors), dim=1)

        # Detached copies: the encoders do not backpropagate into the 2D
        # network or the aggregation module. The 2D branch uses the 64
        # aggregated channels as they are (no constant prepended).
        x_2d = feature_2d3d.clone().detach()
        x_3d = color_features.clone().detach()

        # 3D encoder first, remembering the inputs of the skip blocks.
        skips = []
        for block_idx, block in enumerate(self.encoder_blocks_3d):
            if block_idx in self.encoder_skips:
                skips.append(x_3d)
            x_3d = block(x_3d, batch)

        # 2D encoder second; at every skip block its input is concatenated
        # onto the matching 3D skip (middle fusion).
        skip_slot = 0
        for block_idx, block in enumerate(self.encoder_blocks_2d):
            if block_idx in self.encoder_skips:
                skips[skip_slot] = torch.cat([skips[skip_slot], x_2d], dim=1)
                skip_slot += 1
            x_2d = block(x_2d, batch)

        # Fuse the two bottlenecks by averaging them element-wise.
        x = torch.mean(torch.stack([x_3d, x_2d]), 0)

        # Decoder, consuming the fused skips in reverse order.
        for block_idx, block in enumerate(self.decoder_blocks):
            if block_idx in self.decoder_concats:
                x = torch.cat([x, skips.pop()], dim=1)
            x = block(x, batch)

        # Head of the network.
        x = self.head_mlp(x, batch)
        return self.head_softmax(x, batch)