Exemplo n.º 1
0
    def forward(self, p):
        """Encode a point cloud into per-plane and/or volumetric feature maps.

        Args:
            p: point tensor, unpacked as three sizes (batch, num_points, dim).

        Returns:
            dict keyed by each enabled entry of ``self.plane_type`` among
            'grid', 'xz', 'xy', 'yz'. Each value comes from
            ``generate_grid_features`` or ``generate_plane_features``.
        """
        # Unpacking doubles as a rank-3 check on ``p``; the sizes are unused.
        _, _, _ = p.size()

        # Per-point normalised coordinates and flat cell indices, consumed by
        # pool_local. Insertion order (xz, xy, yz, grid) matches the original.
        coord, index = {}, {}
        for plane_name in ('xz', 'xy', 'yz'):
            if plane_name in self.plane_type:
                coord[plane_name] = normalize_coordinate(p.clone(),
                                                         plane=plane_name,
                                                         padding=self.padding)
                index[plane_name] = coordinate2index(coord[plane_name],
                                                     self.reso_plane)
        if 'grid' in self.plane_type:
            coord['grid'] = normalize_3d_coordinate(p.clone(),
                                                    padding=self.padding)
            index['grid'] = coordinate2index(coord['grid'],
                                             self.reso_grid,
                                             coord_type='3d')

        # Optional positional encoding before the first linear layer.
        encoder_input = self.pe(p) if self.pos_encoding else p
        net = self.fc_pos(encoder_input)

        # Every block after the first sees its input concatenated with the
        # locally pooled features.
        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            net = block(torch.cat([net, self.pool_local(coord, index, net)],
                                  dim=2))

        c = self.fc_c(net)

        fea = {}
        if 'grid' in self.plane_type:
            fea['grid'] = self.generate_grid_features(p, c)
        for plane_name in ('xz', 'xy', 'yz'):
            if plane_name in self.plane_type:
                fea[plane_name] = self.generate_plane_features(p, c,
                                                               plane=plane_name)
        return fea
Exemplo n.º 2
0
    def generate_grid_features(self, p, c):
        """Average per-point features into a dense 3D feature volume.

        Args:
            p: point tensor; ``p.size(0)`` is the batch size.
            c: per-point features. They are permuted with (0, 2, 1) before
                scattering into a (batch, c_dim, cells) buffer, so presumably
                (batch, num_points, c_dim) on input; confirm with the caller.

        Returns:
            Tensor reshaped to (batch, c_dim, reso_grid, reso_grid, reso_grid),
            passed through ``self.unet3d`` when it is not None.
        """
        batch = p.size(0)
        res = self.reso_grid
        voxel_index = coordinate2index(
            normalize_3d_coordinate(p.clone(), padding=self.padding),
            res,
            coord_type='3d')

        # Flat zero-initialised volume; cells that receive no point keep zero.
        volume = c.new_zeros(batch, self.c_dim, res**3)
        volume = scatter_mean(c.permute(0, 2, 1), voxel_index, out=volume)
        volume = volume.reshape(batch, self.c_dim, res, res, res)

        return volume if self.unet3d is None else self.unet3d(volume)
Exemplo n.º 3
0
    def generate_plane_features(self, p, c, plane='xz'):
        """Average per-point features onto one canonical 2D plane.

        Args:
            p: point tensor; ``p.size(0)`` is the batch size.
            c: per-point features, permuted with (0, 2, 1) before scattering
                (presumably (batch, num_points, c_dim); confirm with caller).
            plane: which canonical plane to project onto, forwarded to
                ``normalize_coordinate`` (e.g. 'xz', 'xy', 'yz').

        Returns:
            Tensor reshaped to (batch, c_dim, reso_plane, reso_plane), passed
            through ``self.unet`` when it is not None.
        """
        batch = p.size(0)
        res = self.reso_plane
        cell_index = coordinate2index(
            normalize_coordinate(p.clone(), plane=plane, padding=self.padding),
            res)

        # Flat zero-initialised canvas; empty cells stay zero.
        canvas = c.new_zeros(batch, self.c_dim, res**2)
        canvas = scatter_mean(c.permute(0, 2, 1), cell_index, out=canvas)
        canvas = canvas.reshape(batch, self.c_dim, res, res)

        return canvas if self.unet is None else self.unet(canvas)
    def generate_dynamic_plane_features(self, p, c, normal_feature,
                                        basis_normalizer_matrix):
        """Average per-point features onto one learned (dynamic) plane.

        Args:
            p: point tensor; ``p.size(0)`` is the batch size.
            c: per-point features, permuted with (0, 2, 1) before scattering
                (presumably (batch, num_points, c_dim); confirm with caller).
            normal_feature: embedding of the plane normal. It is unsqueezed on
                dim 2 and added to every point's feature, so it is expected to
                be (batch, c_dim); TODO confirm.
            basis_normalizer_matrix: change-of-basis / normaliser matrix for
                this plane, forwarded to ``normalize_dynamic_plane_coordinate``.

        Returns:
            Tensor reshaped to (batch, c_dim, reso_plane, reso_plane), passed
            through ``self.unet`` when it is not None.
        """
        batch = p.size(0)
        res = self.reso_plane

        # Plane-local 2D coordinates (the original comment states they are
        # normalised to (0, 1)), then flattened to cell indices.
        plane_xy = normalize_dynamic_plane_coordinate(
            p.clone(), basis_normalizer_matrix, padding=self.padding)
        cell_index = coordinate2index(plane_xy, res)

        # Buffer dtype/device come from the un-shifted features ``c``.
        canvas = c.new_zeros(batch, self.c_dim, res**2)

        # Condition every point feature on this plane's normal embedding.
        point_feats = c.permute(0, 2, 1) + normal_feature.unsqueeze(2)
        canvas = scatter_mean(point_feats, cell_index, out=canvas)
        # Sparse map: cells without points remain zero.
        canvas = canvas.reshape(batch, self.c_dim, res, res)

        if self.unet is not None:
            canvas = self.unet(canvas)
        return canvas
    def forward(self, p):
        """Encode points onto the three canonical planes plus learned planes.

        Planes 0, 1 and 2 are the canonical 'xz', 'xy' and 'yz' planes. Planes
        3 .. num_channels-1 are dynamic planes whose normals are predicted
        from the input.

        Args:
            p: point tensor, unpacked as (batch_size, num_points, dim).

        Returns:
            dict with:
            - 'plane0' .. 'plane{num_channels-1}': feature maps from
              ``generate_plane_features`` (canonical) or
              ``generate_dynamic_plane_features`` (dynamic).
            - 'c_mat': output of ``ChangeBasis``, indexed per dynamic plane
              along dim 1.

        Side effects:
            Sets ``self.device`` to ``p.device``. Sets ``self.plane_parameters``
            to unit-length plane normals of shape (batch_size, num_channels, 3):
            the identity basis for the canonical planes followed by the
            predicted dynamic normals. This is used for the similarity loss.

        Raises:
            ValueError: if ``self.num_channels == 3`` (no dynamic planes).
        """
        batch_size, _, _ = p.size()
        # Follow the input's device rather than hard-coding 'cuda', so the
        # encoder also works on CPU and on any specific GPU.
        self.device = p.device

        canonical_planes = ('xz', 'xy', 'yz')
        num_planes = self.num_channels
        num_dynamic_planes = num_planes - len(canonical_planes)

        # Fail fast, before running any network layers.
        if self.num_channels == 3:
            raise ValueError(
                f"Number of channels is {self.num_channels}, no point in using Hybrid approach"
            )

        encoder_input = self.pe(p) if self.pos_encoding else p
        net = self.fc_pos(encoder_input)
        net_pl = self.fc_plane_net(encoder_input)

        # Predict one normal per dynamic plane, plus a c_dim-wide embedding of
        # it that later conditions that plane's point features.
        activated = self.actvn(net_pl)  # loop-invariant, compute once
        normal_fea = []
        normal_fea_hdim = {}
        for l in range(num_dynamic_planes):
            normal = self.plane_params[l](activated)
            normal_fea.append(normal)
            normal_fea_hdim[f'dynamic_plane{l}'] = self.plane_params_hdim[l](
                normal)

        # Plane parameters: (batch_size, num_dynamic_planes, 3).
        self.plane_parameters = torch.stack(normal_fea, dim=1)
        # Change-of-basis and normaliser matrices (concatenated).
        C_mat = ChangeBasis(self.plane_parameters, device=self.device)

        # Normalised per-plane coordinates and flat cell indices per point.
        coord = {}
        index = {}
        for l in range(num_planes):
            key = f'plane{l}'
            if l < len(canonical_planes):
                coord[key] = normalize_coordinate(p.clone(),
                                                  plane=canonical_planes[l],
                                                  padding=self.padding)
            else:
                coord[key] = normalize_dynamic_plane_coordinate(
                    p.clone(),
                    C_mat[:, l - len(canonical_planes)],
                    padding=self.padding)
            index[key] = coordinate2index(coord[key], self.reso_plane)

        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)

        c = self.fc_c(net)

        fea = {}
        for l in range(num_planes):
            key = f'plane{l}'
            if l < len(canonical_planes):
                fea[key] = self.generate_plane_features(
                    p, c, plane=canonical_planes[l])
            else:
                dynamic_plane_id = l - len(canonical_planes)
                fea[key] = self.generate_dynamic_plane_features(
                    p, c, normal_fea_hdim[f'dynamic_plane{dynamic_plane_id}'],
                    C_mat[:, dynamic_plane_id])

        fea['c_mat'] = C_mat

        # Unit-normalise the predicted normals for the similarity loss.
        dynamic_normals = self.plane_parameters / torch.norm(
            self.plane_parameters, p=2, dim=2, keepdim=True)
        # Canonical planes are represented by the identity basis, per sample.
        canonical_normals = torch.eye(
            3, dtype=dynamic_normals.dtype,
            device=dynamic_normals.device).unsqueeze(0).repeat(batch_size, 1, 1)
        # (batch_size, 3 + num_dynamic_planes, 3)
        self.plane_parameters = torch.cat([canonical_normals, dynamic_normals],
                                          dim=1)

        return fea
    def forward(self, p):
        """Encode points onto ``self.num_channels`` learned (dynamic) planes.

        Args:
            p: point tensor, unpacked as (batch_size, num_points, dim).

        Returns:
            dict with:
            - 'plane0' .. 'plane{L-1}': feature maps from
              ``generate_dynamic_plane_features``, where L is ``C_mat.size(1)``.
            - 'c_mat': output of ``ChangeBasis``.

        Side effects:
            Sets ``self.device`` to ``p.device``. Sets ``self.plane_parameters``
            to the predicted plane normals scaled to unit length, shaped
            (batch_size, L, 3). This is used for the similarity loss.
        """
        batch_size, _, _ = p.size()
        # Follow the input's device rather than hard-coding 'cpu'; otherwise
        # ChangeBasis builds its tensors on CPU while p/net may live on GPU.
        self.device = p.device

        encoder_input = self.pe(p) if self.pos_encoding else p
        net = self.fc_pos(encoder_input)
        net_pl = self.fc_plane_net(encoder_input)

        # Predict one normal per plane, plus a c_dim-wide embedding of it that
        # later conditions that plane's point features.
        activated = self.actvn(net_pl)  # loop-invariant, compute once
        normal_fea = []
        normal_fea_hdim = {}
        for l in range(self.num_channels):
            normal = self.plane_params[l](activated)
            normal_fea.append(normal)
            normal_fea_hdim[f'plane{l}'] = self.plane_params_hdim[l](normal)

        # Plane parameters: (batch_size, L, 3).
        self.plane_parameters = torch.stack(normal_fea, dim=1)
        # Change-of-basis and normaliser matrices (concatenated).
        C_mat = ChangeBasis(self.plane_parameters, device=self.device)
        num_planes = C_mat.size(1)

        # Normalised per-plane coordinates and flat cell indices per point.
        coord = {}
        index = {}
        for l in range(num_planes):
            key = f'plane{l}'
            coord[key] = normalize_dynamic_plane_coordinate(
                p.clone(), C_mat[:, l], padding=self.padding)
            index[key] = coordinate2index(coord[key], self.reso_plane)

        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)

        c = self.fc_c(net)

        fea = {}
        for l in range(num_planes):
            key = f'plane{l}'
            fea[key] = self.generate_dynamic_plane_features(
                p, c, normal_fea_hdim[key], C_mat[:, l])

        fea['c_mat'] = C_mat

        # Unit-normalise each predicted normal for the similarity loss;
        # the shape stays (batch_size, L, 3).
        self.plane_parameters = self.plane_parameters / torch.norm(
            self.plane_parameters, p=2, dim=2, keepdim=True)
        return fea
Exemplo n.º 7
0
    def forward(self, p):
        """Encode paired scans separately and jointly into 3D feature grids.

        The batch is read as interleaved pairs. Even-indexed items (``p[::2]``)
        are first scans, and odd-indexed items (``p[1::2]``) are the matching
        second scans. Each pair's points are also concatenated along the point
        axis to form a combined scan. All three are run through the same
        shared encoder.

        Args:
            p: point tensor, unpacked as (batch_size, num_points, dim), with
                an even batch size.

        Returns:
            dict with 'unet3d' (the ``self.unet3d`` module). When 'grid' is in
            ``self.plane_type``, it also has 'latent_1', 'latent_2' and
            'latent_12', the grid features of the first, second and combined
            scans.

        Raises:
            ValueError: if the batch size is odd, so scans cannot be paired.
            KeyError: as before, if 'grid' is not in ``self.plane_type`` and
                ``self.blocks`` has more than one block (no pooling indices).
        """
        batch_size, _, _ = p.size()
        if batch_size % 2 != 0:
            raise ValueError(
                f"Expected an even batch of interleaved scan pairs, got batch size {batch_size}"
            )

        p1 = p[::2]  # first scan of each pair
        p2 = p[1::2]  # second scan of each pair
        # Combined scan: both scans' points in one cloud (point axis).
        p12 = torch.cat([p1, p2], dim=1)
        scans = {'1': p1, '2': p2, '12': p12}

        # Acquire the voxel index for each point of each scan.
        coord = {}
        index = {}
        if 'grid' in self.plane_type:
            for suffix, points in scans.items():
                key = f'grid_{suffix}'
                coord[key] = normalize_3d_coordinate(points.clone(),
                                                     padding=self.padding)
                index[key] = coordinate2index(coord[key],
                                              self.reso_grid,
                                              coord_type='3d')

        def encode(points, key):
            """Run the shared point encoder on one scan, pooling via index[key]."""
            net = self.blocks[0](self.fc_pos(points))
            for block in self.blocks[1:]:
                pooled = self.pool_local(index[key], net)
                net = block(torch.cat([net, pooled], dim=2))
            return self.fc_c(net)

        c1 = encode(p1, 'grid_1')
        c2 = encode(p2, 'grid_2')
        c12 = encode(p12, 'grid_12')

        fea = {'unet3d': self.unet3d}
        if 'grid' in self.plane_type:
            fea['latent_1'] = self.generate_grid_features(p1, c1)
            fea['latent_2'] = self.generate_grid_features(p2, c2)
            fea['latent_12'] = self.generate_grid_features(p12, c12)

        return fea