Exemplo n.º 1
0
Arquivo: unet.py Projeto: jgwak/GSDN
    def forward(self, x):
        """Run the three-level U-Net over the sparse tensor ``x``.

        A stem (``conv1`` -> ``bn1`` -> ``relu``) is followed by three encoder
        levels. Each level runs ``block{i}``, keeps its output as a skip
        connection and then downsamples with ``down{i}``. After the bottleneck
        ``block4``, each decoder level upsamples with ``up{i}``, concatenates
        the matching skip (skip first, upsampled features second) with
        ``me.cat`` and refines the result with ``block{i}up``.
        ``self.final`` produces the returned tensor.
        """
        out = self.relu(self.bn1(self.conv1(x)))

        skips = {}
        for level in (1, 2, 3):
            skips[level] = getattr(self, f"block{level}")(out)
            out = getattr(self, f"down{level}")(skips[level])

        out = self.block4(out)

        for level in (3, 2, 1):
            upsampled = getattr(self, f"up{level}")(out)
            fused = me.cat((skips[level], upsampled))
            out = getattr(self, f"block{level}up")(fused)

        return self.final(out)
    def forward(self, x):
        """Encode ``x`` over four stride-2 stages, then decode back to pixel_dist=1.

        Every stage applies conv -> batch norm -> ``self.relu``. The encoder
        keeps its activations at pixel_dist 1, 2, 4 and 8 as skip connections.
        Each decoder stage upsamples with a transposed conv, concatenates the
        upsampled features (first) with the matching skip (second) and
        refines them with the next block. ``self.final`` maps the
        pixel_dist=1 features to the output.
        """
        def conv_bn_relu(conv, norm, tensor):
            return self.relu(norm(conv(tensor)))

        # Encoder; the suffix of each name is the pixel distance of that skip.
        out_p1 = conv_bn_relu(self.conv0p1s1, self.bn0, x)
        out_b1p2 = self.block1(conv_bn_relu(self.conv1p1s2, self.bn1, out_p1))
        out_b2p4 = self.block2(
            conv_bn_relu(self.conv2p2s2, self.bn2, out_b1p2))
        out_b3p8 = self.block3(
            conv_bn_relu(self.conv3p4s2, self.bn3, out_b2p4))

        # Bottleneck at pixel_dist=16.
        out = self.block4(conv_bn_relu(self.conv4p8s2, self.bn4, out_b3p8))

        # Decoder: pixel_dist 16 -> 8 -> 4 -> 2 -> 1.
        decoder_stages = (
            (self.convtr4p16s2, self.bntr4, out_b3p8, self.block5),
            (self.convtr5p8s2, self.bntr5, out_b2p4, self.block6),
            (self.convtr6p4s2, self.bntr6, out_b1p2, self.block7),
            (self.convtr7p2s2, self.bntr7, out_p1, self.block8),
        )
        for upconv, norm, skip, block in decoder_stages:
            out = block(me.cat(conv_bn_relu(upconv, norm, out), skip))

        return self.final(out)
Exemplo n.º 3
0
    def forward(self, x):
        """Run the U-Net and optionally also return mask features.

        Three encoder stages (conv -> bn -> ``self.relu`` -> block) lead to the
        ``block4`` bottleneck at pixel_dist=8. Three transposed-conv stages
        bring the features back up. At each of them the upsampled features
        are concatenated (first) with the matching encoder output.

        Returns:
            ``self.final(concat)``, or ``(self.mask_feat(concat),
            self.final(concat))`` when ``self.return_feat`` is truthy. Here
            ``concat`` is the last decoder output concatenated with
            ``block1``'s output.
        """
        def conv_bn_relu(conv, norm, tensor):
            return self.relu(norm(conv(tensor)))

        out_b1p1 = self.block1(conv_bn_relu(self.conv1p1s1, self.bn1, x))
        out_b2p2 = self.block2(
            conv_bn_relu(self.conv2p1s2, self.bn2, out_b1p1))
        out_b3p4 = self.block3(
            conv_bn_relu(self.conv3p2s2, self.bn3, out_b2p2))

        # pixel_dist=8 bottleneck
        out = self.block4(conv_bn_relu(self.conv4p4s2, self.bn4, out_b3p4))

        out = conv_bn_relu(self.convtr4p8s2, self.bntr4, out)
        out = self.block5(me.cat((out, out_b3p4)))

        out = conv_bn_relu(self.convtr5p4s2, self.bntr5, out)
        out = self.block6(me.cat((out, out_b2p2)))

        out = conv_bn_relu(self.convtr6p2s2, self.bntr6, out)
        out_feat = me.cat((out, out_b1p1))

        prediction = self.final(out_feat)
        if not self.return_feat:
            return prediction
        return self.mask_feat(out_feat), prediction
Exemplo n.º 4
0
    def forward(self, x):
        """Run the U-Net with extra side branches from the decoder.

        The structure is three encoder stages (conv -> bn -> ``self.relu`` ->
        block), the ``block4`` bottleneck, and three transposed-conv decoder
        stages. In addition to the usual skips, the outputs of ``block5`` and
        ``block6`` are fed through ``pool_tr5``/``pool_tr6``. The final
        concatenation is ordered as: last decoder output, ``block1`` output,
        ``pool_tr6`` branch, ``pool_tr5`` branch. It is then passed to
        ``self.final``.
        """
        def conv_bn_relu(conv, norm, tensor):
            return self.relu(norm(conv(tensor)))

        out_b1p1 = self.block1(conv_bn_relu(self.conv1p1s1, self.bn1, x))
        out_b2p2 = self.block2(
            conv_bn_relu(self.conv2p1s2, self.bn2, out_b1p1))
        out_b3p4 = self.block3(
            conv_bn_relu(self.conv3p2s2, self.bn3, out_b2p2))

        # pixel_dist=8 bottleneck
        bottleneck = self.block4(
            conv_bn_relu(self.conv4p4s2, self.bn4, out_b3p4))

        up5 = conv_bn_relu(self.convtr4p8s2, self.bntr4, bottleneck)
        dec5 = self.block5(me.cat(up5, out_b3p4))
        side5 = self.pool_tr5(dec5)

        up6 = conv_bn_relu(self.convtr5p4s2, self.bntr5, dec5)
        dec6 = self.block6(me.cat(up6, out_b2p2))
        side6 = self.pool_tr6(dec6)

        up7 = conv_bn_relu(self.convtr6p2s2, self.bntr6, dec6)
        return self.final(me.cat(up7, out_b1p1, side6, side5))
Exemplo n.º 5
0
 def forward(self, x):
     """Apply an hourglass-shaped unit to the sparse tensor ``x``.

     ``x`` goes through ``block``, is downsampled and normalised, runs
     through ``intermediate``, and is upsampled and normalised again. The
     result is concatenated (processed features first) with the untouched
     input ``x`` and refined by ``self.reps`` sequential ``end_blocks{i}``
     modules.
     """
     hidden = self.down_norm(self.down(self.block(x)))
     hidden = self.up_norm(self.up(self.intermediate(hidden)))
     out = MinkowskiOps.cat((hidden, x))
     for index in range(self.reps):
         end_block = getattr(self, f'end_blocks{index}')
         out = end_block(out)
     return out
Exemplo n.º 6
0
Arquivo: unet.py Projeto: jgwak/GSDN
    def forward(self, x):
        """Run the six-level UNet and optionally also return mask features.

        Encoder level ``i`` computes
        ``relu(bn_down{i}(conv_down{i}(.)))``, keeps it as a skip and
        downsamples with ``down{i}``. The bottleneck is
        ``relu(bn7(conv7(.)))``. Decoder level ``i`` upsamples with ``up{i}``,
        concatenates the skip (first) with the upsampled features, and applies
        ``relu(bn_up{i}(conv_up{i}(.)))``.

        Returns:
            ``self.final(feat)``, or ``(self.mask_feat(feat), self.final(feat))``
            when ``self.return_feat`` is truthy. ``feat`` is the output of
            decoder level 1.
        """
        skips = {}
        out = x
        for level in range(1, 7):
            conv = getattr(self, f"conv_down{level}")
            norm = getattr(self, f"bn_down{level}")
            skips[level] = self.relu(norm(conv(out)))
            out = getattr(self, f"down{level}")(skips[level])

        out = self.relu(self.bn7(self.conv7(out)))

        for level in range(6, 0, -1):
            upsampled = getattr(self, f"up{level}")(out)
            fused = me.cat((skips[level], upsampled))
            conv = getattr(self, f"conv_up{level}")
            norm = getattr(self, f"bn_up{level}")
            out = self.relu(norm(conv(fused)))
        out_feat = out  # decoder level-1 features

        prediction = self.final(out_feat)
        if self.return_feat:
            return self.mask_feat(out_feat), prediction
        return prediction
Exemplo n.º 7
0
    def forward(self, x):
        """Run the sparse U-Net, optionally L2-normalising the output features.

        The encoder has four stride-2 conv -> bn -> ``self.relu`` stages, the
        first three followed by ``block1``..``block3`` and the last by the
        ``block4`` bottleneck. The decoder has four transposed-conv stages.
        Each decoder stage concatenates its upsampled features (first) with
        the matching encoder activation and refines them with
        ``block5``..``block8``. ``self.final`` produces the output.

        Returns:
            The output sparse tensor. When ``self.normalize_feature`` is set,
            a new ``SparseTensor`` on the same coordinates is returned whose
            feature rows have unit L2 norm (all-zero rows stay zero).
        """
        def conv_bn_relu(conv, norm, tensor):
            return self.relu(norm(conv(tensor)))

        out_p1 = conv_bn_relu(self.conv0p1s1, self.bn0, x)
        out_b1p2 = self.block1(conv_bn_relu(self.conv1p1s2, self.bn1, out_p1))
        out_b2p4 = self.block2(
            conv_bn_relu(self.conv2p2s2, self.bn2, out_b1p2))
        out_b3p8 = self.block3(
            conv_bn_relu(self.conv3p4s2, self.bn3, out_b2p4))
        encoder_out = self.block4(
            conv_bn_relu(self.conv4p8s2, self.bn4, out_b3p8))

        out = encoder_out
        for upconv, norm, skip, block in (
                (self.convtr4p16s2, self.bntr4, out_b3p8, self.block5),
                (self.convtr5p8s2, self.bntr5, out_b2p4, self.block6),
                (self.convtr6p4s2, self.bntr6, out_b1p2, self.block7),
                (self.convtr7p2s2, self.bntr7, out_p1, self.block8),
        ):
            out = block(me.cat(conv_bn_relu(upconv, norm, out), skip))

        out = self.final(out)

        if not self.normalize_feature:
            return out

        # normalize() divides by max(||row||_2, eps). That matches the plain
        # division for rows with norm >= eps, but an all-zero row stays zero
        # instead of turning into NaN (0 / 0).
        normalized = torch.nn.functional.normalize(out.F, p=2, dim=1)
        return SparseTensor(normalized,
                            coords_key=out.coords_key,
                            coords_manager=out.coords_man)
    def forward(self, x_a: ME.SparseTensor, x_b: ME.SparseTensor):
        """Propagate the features of ``x_a`` onto the coordinates of ``x_b``.

        Modes:

        * ``self.POINT_TR_LIKE``: both inputs are projected with
          ``linear_a`` / ``linear_b`` and scattered into zero-padded per-batch
          tensors (via ``separate_batch``). For every ``x_b`` point, the three
          nearest ``x_a`` points are found with ``three_nn``. Their features
          are blended with inverse-distance weights ``1 / (d + 0.1)``,
          normalised to sum to one, via ``three_interpolate``. The blend is
          added to the projected ``x_b`` features and gathered back into
          ``x_b``'s sparse row order.
        * otherwise, with ``self.SUM_FEATURE``: ``conv_a(x_a) + conv_b(x_b)``.
        * otherwise: ``x_a`` goes through ``conv``/``bn``/``relu``, is
          concatenated with ``x_b`` and passes through
          ``out_conv``/``out_bn``/``out_relu``.

        Args:
            x_a: sparse tensor whose features are propagated. The original
                notes expect it to have fewer points than ``x_b``.
            x_b: sparse tensor whose coordinates define the output.

        Returns:
            ME.SparseTensor. In ``POINT_TR_LIKE`` mode it reuses ``x_b``'s
            coordinate map key and coordinate manager.

        TODO: For POINT_TR_LIKE, add support for no x_b is fed, simply
        upsample the x_a.
        """
        if self.POINT_TR_LIKE:
            dim = x_b.F.shape[1]
            assert dim == self.out_dim

            # Dense, zero-padded [B, N, dim] views of both point sets.
            x_ac, x_af, _ = self._scatter_to_batch(
                x_a.C, self.linear_a(x_a.F), dim)
            x_bc, x_bf, idx_b = self._scatter_to_batch(
                x_b.C, self.linear_b(x_b.F), dim)
            B, N_b = x_bf.shape[0], x_bf.shape[1]

            # Three nearest x_a neighbours of every x_b point.
            dists, idx = three_nn(x_bc.float(), x_ac.float())

            # Points whose three neighbour distances sum to zero get zero
            # weight ("mask the zeros part"; presumably padding entries).
            mask = (dists.sum(dim=-1) > 0).unsqueeze(-1).repeat(1, 1, 3)

            dist_recip = 1.0 / (dists + 1e-1)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm * mask

            interpolated_points = three_interpolate(
                x_af.transpose(1, 2).contiguous(), idx,
                weight).transpose(1, 2)  # [B, N_b, dim]
            out = interpolated_points + x_bf

            # Undo the scatter: row i of the result corresponds to row i of
            # x_b.F, so it lines up with x_b's coordinate map.
            out = torch.gather(out.reshape(B * N_b, dim), dim=0, index=idx_b)
            x = ME.SparseTensor(features=out,
                                coordinate_map_key=x_b.coordinate_map_key,
                                coordinate_manager=x_b.coordinate_manager)

        else:
            if self.SUM_FEATURE:
                x_a = self.conv_a(x_a)
                x_b = self.conv_b(x_b)
                x = x_a + x_b
            else:
                x_a = self.conv(x_a)
                x_a = self.bn(x_a)
                x_a = self.relu(x_a)
                x = me.cat(x_a, x_b)
                x = self.out_conv(x)
                x = self.out_bn(x)
                x = self.out_relu(x)

        return x

    @staticmethod
    def _scatter_to_batch(coords, feats, dim):
        """Scatter flat sparse features into a zero-padded ``[B, N, dim]`` tensor.

        Returns ``(batch_coords, dense_feats, index)``. ``batch_coords`` and
        the flat per-point positions come from ``separate_batch``. ``index``
        is those positions expanded to ``[num_points, dim]`` so the caller
        can reuse it with ``torch.gather``. The buffer is allocated like
        ``feats`` (same device and dtype) rather than hard-coded to
        CUDA/float32.
        """
        batch_coords, _, flat_idx = separate_batch(coords)
        B, N = batch_coords.shape[0], batch_coords.shape[1]
        dense = feats.new_zeros(B * N, dim)
        index = flat_idx.reshape(-1, 1).repeat(1, dim)
        dense.scatter_(dim=0, index=index, src=feats)
        return batch_coords, dense.reshape([B, N, dim]), index
Exemplo n.º 9
0
    def forward(self,
                x,
                save_anchor=False,
                iter_=None,
                aux=None,
                enable_point_branch=False):
        """Run the sparse U-Net, optionally fusing in a point-wise MLP branch.

        Args:
            x: input sparse tensor.
            save_anchor: if true, ``self.anchors`` is reset and filled with
                the outputs of ``block1``, ``block2``, ``block6`` and
                ``block7`` (in that order). The list is returned alongside
                the prediction.
            iter_, aux: forwarded unchanged to every ``block{i}`` call.
            enable_point_branch: if true, the stem activation's features are
                carried through ``self.point_transform_mlp[0..2]``. After
                ``block4``, ``block6`` and ``block8`` they are fused with the
                current features interpolated onto the stem coordinates (see
                ``_fuse_point_branch``).

        Returns:
            The ``self.final`` output, or ``(output, self.anchors)`` when
            ``save_anchor`` is true.

        Raises:
            FloatingPointError: if the output features contain NaN.
        """
        nonlinearity = self.config.nonlinearity

        def conv_bn_act(conv, norm, tensor):
            return get_nonlinearity_fn(nonlinearity, norm(conv(tensor)))

        if save_anchor:
            self.anchors = []

        # mapped to transformer.stem1
        out_p1 = conv_bn_act(self.conv0p1s1, self.bn0, x)

        if enable_point_branch:
            # Point-branch state lives at the stem (pixel_dist=1) resolution.
            point_feats = out_p1.F
            point_coords = out_p1.C

        # mapped to transformer.stem2
        out = conv_bn_act(self.conv1p1s2, self.bn1, out_p1)

        # mapped to transformer.PTBlock1
        out_b1p2 = self.block1(out, iter_, aux)
        if save_anchor:
            self.anchors.append(out_b1p2)

        # mapped to transformer.PTBlock2
        out = conv_bn_act(self.conv2p2s2, self.bn2, out_b1p2)
        out_b2p4 = self.block2(out, iter_, aux)
        if save_anchor:
            self.anchors.append(out_b2p4)

        # mapped to transformer.PTBlock3 (not recorded as an anchor)
        out = conv_bn_act(self.conv3p4s2, self.bn3, out_b2p4)
        out_b3p8 = self.block3(out, iter_, aux)

        # pixel_dist=16, mapped to transformer.PTBlock4 (not recorded)
        out = conv_bn_act(self.conv4p8s2, self.bn4, out_b3p8)
        out = self.block4(out, iter_, aux)

        if enable_point_branch:
            # Convert the stem coordinates once. Using .to(device, dtype)
            # avoids the GPU->CPU->GPU round trip of .type(torch.FloatTensor).
            point_coords_f = point_coords.to(device=out.device,
                                             dtype=torch.float32)
            point_feats, out = self._fuse_point_branch(
                out, point_coords, point_coords_f, point_feats, 0,
                self.downsample16x)

        # pixel_dist=8, mapped to transformer.PTBlock5 (not recorded)
        out = conv_bn_act(self.convtr4p16s2, self.bntr4, out)
        out = self.block5(me.cat(out, out_b3p8), iter_, aux)

        # pixel_dist=4, mapped to transformer.PTBlock6
        out = conv_bn_act(self.convtr5p8s2, self.bntr5, out)
        out = self.block6(me.cat(out, out_b2p4), iter_, aux)
        if save_anchor:
            # Recorded before the point-branch fusion below.
            self.anchors.append(out)

        if enable_point_branch:
            point_feats, out = self._fuse_point_branch(
                out, point_coords, point_coords_f, point_feats, 1,
                self.downsample4x)

        # pixel_dist=2, mapped to transformer.PTBlock7
        out = conv_bn_act(self.convtr6p4s2, self.bntr6, out)
        out = self.block7(me.cat(out, out_b1p2), iter_, aux)
        if save_anchor:
            self.anchors.append(out)

        # pixel_dist=1, mapped to transformer.final_conv
        out = conv_bn_act(self.convtr7p2s2, self.bntr7, out)
        out = self.block8(me.cat(out, out_p1), iter_, aux)

        if enable_point_branch:
            # Already at the stem resolution: no downsampling needed.
            _, out = self._fuse_point_branch(
                out, point_coords, point_coords_f, point_feats, 2)

        out = self.final(out)

        # Fail loudly instead of dropping into an interactive debugger, which
        # hangs or crashes unattended training runs.
        if torch.isnan(out.F).any():
            raise FloatingPointError("NaN detected in network output features")

        if save_anchor:
            return out, self.anchors
        return out

    def _fuse_point_branch(self, out, point_coords, point_coords_f,
                           point_feats, stage, downsample=None):
        """Fuse the point branch into the sparse tensor ``out``.

        Interpolates ``out`` onto the stem coordinates (float copy
        ``point_coords_f``) and adds
        ``self.point_transform_mlp[stage](point_feats)``. The sum is wrapped
        in a SparseTensor on ``point_coords``, optionally pooled with
        ``downsample``, and passed through ``self.dropout``.

        Returns:
            ``(fused_point_features, new_out)``. ``new_out`` carries the
            dropped-out fused features on ``out``'s coordinate map.
        """
        interpolated_out = self.interpolate(out, point_coords_f)
        fused_feats = interpolated_out + self.point_transform_mlp[stage](
            point_feats)
        out_fused = ME.SparseTensor(features=fused_feats,
                                    coordinates=point_coords)
        if downsample is not None:
            out_fused = downsample(out_fused)
        new_out = ME.SparseTensor(features=self.dropout(out_fused.F),
                                  coordinate_map_key=out.coordinate_map_key,
                                  coordinate_manager=out.coordinate_manager)
        return fused_feats, new_out
Exemplo n.º 10
0
    def forward(self, x, out_feat_keys=None):
        """Run the U-Net and return pooled features for the requested taps.

        Activations are recorded under ``en0``..``en4`` and
        ``plane4``..``plane7``. ``en0``..``en3`` hold each stage's
        conv/bn/relu activation *before* that stage's residual block. ``en4``
        is the first decoder upsampling output. ``plane4``..``plane6`` are the
        upsampling outputs before concatenation, and ``plane7`` is
        ``block8``'s output.

        Args:
            x: input sparse tensor.
            out_feat_keys: iterable of tap names without the ``_features``
                suffix, e.g. ``["en3", "plane7"]``. Required despite the
                ``None`` default.

        Returns:
            A list with one ``.F`` feature matrix per requested key, in
            request order. Each is ``self.maxpool`` of the tap, followed by
            ``self.head`` when ``self.use_mlp`` is set.

        Raises:
            TypeError: if ``out_feat_keys`` is None (checked before running
                the network).
            KeyError: if a key does not name a recorded tap.
        """
        if out_feat_keys is None:
            raise TypeError("out_feat_keys must list the feature taps to "
                            "return, got None")

        end_points = {}

        out = self.conv0p1s1(x)
        out = self.bn0(out)
        out_p1 = self.relu(out)

        out = self.conv1p1s2(out_p1)
        out = self.bn1(out)
        out = self.relu(out)
        out_b1p2 = self.block1(out)

        end_points["en0_features"] = out  ## 32

        out = self.conv2p2s2(out_b1p2)
        out = self.bn2(out)
        out = self.relu(out)
        out_b2p4 = self.block2(out)

        end_points["en1_features"] = out  ## 32

        out = self.conv3p4s2(out_b2p4)
        out = self.bn3(out)
        out = self.relu(out)
        out_b3p8 = self.block3(out)

        end_points["en2_features"] = out  ## 64

        # pixel_dist=16
        out = self.conv4p8s2(out_b3p8)
        out = self.bn4(out)
        out = self.relu(out)
        end_points["en3_features"] = out  ## 128
        out = self.block4(out)

        # pixel_dist=8
        out = self.convtr4p16s2(out)
        out = self.bntr4(out)
        out = self.relu(out)

        end_points["en4_features"] = out  ## 256

        out = me.cat(out, out_b3p8)
        out = self.block5(out)

        # pixel_dist=4
        out = self.convtr5p8s2(out)
        out = self.bntr5(out)
        out = self.relu(out)

        end_points["plane4_features"] = out
        out = me.cat(out, out_b2p4)
        out = self.block6(out)

        # pixel_dist=2
        out = self.convtr6p4s2(out)
        out = self.bntr6(out)
        out = self.relu(out)

        end_points["plane5_features"] = out
        out = me.cat(out, out_b1p2)
        out = self.block7(out)

        # pixel_dist=1
        out = self.convtr7p2s2(out)
        out = self.bntr7(out)
        out = self.relu(out)

        end_points["plane6_features"] = out
        out = me.cat(out, out_p1)
        out = self.block8(out)

        end_points["plane7_features"] = out

        # Build the result positionally. The previous
        # out_feat_keys.index(key) lookup was quadratic and left None slots
        # when a key was requested twice.
        out_feats = []
        for key in out_feat_keys:
            feat = self.maxpool(end_points[key + "_features"])
            if self.use_mlp:
                feat = self.head(feat)
            out_feats.append(feat.F)  # Just use smlp
        return out_feats
  def forward(self, x, save_anchor=False):
    """Run the sparse U-Net, optionally recording intermediate anchors.

    Every conv stage is conv -> bn -> ``get_nonlinearity_fn`` using
    ``self.config.nonlinearity``. When ``save_anchor`` is true,
    ``self.anchors`` is reset and receives the outputs of ``block1`` through
    ``block7`` in execution order. It is then returned as
    ``(self.final(out), self.anchors)``. Otherwise only ``self.final(out)``
    is returned.
    """
    nonlinearity = self.config.nonlinearity

    def conv_bn_act(conv, norm, tensor):
        return get_nonlinearity_fn(nonlinearity, norm(conv(tensor)))

    if save_anchor:
        self.anchors = []

    def record(tensor):
        if save_anchor:
            self.anchors.append(tensor)
        return tensor

    # mapped to transformer.stem1 / transformer.stem2
    out_p1 = conv_bn_act(self.conv0p1s1, self.bn0, x)
    out = conv_bn_act(self.conv1p1s2, self.bn1, out_p1)

    # mapped to transformer.PTBlock1 .. PTBlock3
    out_b1p2 = record(self.block1(out))
    out_b2p4 = record(
        self.block2(conv_bn_act(self.conv2p2s2, self.bn2, out_b1p2)))
    out_b3p8 = record(
        self.block3(conv_bn_act(self.conv3p4s2, self.bn3, out_b2p4)))

    # pixel_dist=16, mapped to transformer.PTBlock4
    out = record(self.block4(conv_bn_act(self.conv4p8s2, self.bn4, out_b3p8)))

    # pixel_dist 8 -> 4 -> 2, mapped to transformer.PTBlock5 .. PTBlock7
    for upconv, norm, skip, block in (
            (self.convtr4p16s2, self.bntr4, out_b3p8, self.block5),
            (self.convtr5p8s2, self.bntr5, out_b2p4, self.block6),
            (self.convtr6p4s2, self.bntr6, out_b1p2, self.block7),
    ):
        out = record(block(me.cat(conv_bn_act(upconv, norm, out), skip)))

    # pixel_dist=1, mapped to transformer.final_conv
    out = conv_bn_act(self.convtr7p2s2, self.bntr7, out)
    prediction = self.final(self.block8(me.cat(out, out_p1)))

    if save_anchor:
        return prediction, self.anchors
    return prediction