def forward(self, x):
    """Run the three-level encoder/decoder and return ``self.final`` of it.

    The stem is ``conv1 -> bn1 -> relu``. Each encoder level keeps its block
    output as a skip connection before the matching ``down*`` module;
    ``block4`` is the bottleneck. Every decoder level up-samples with
    ``up*``, concatenates the skip in front of the up-sampled tensor and
    runs the matching ``block*up`` module.

    NOTE(review): ``me.cat`` is called with a single tuple argument here,
    whereas other forward methods in this file pass the tensors as separate
    arguments; confirm which MinkowskiEngine ``cat`` signature is targeted.
    """
    stem = self.relu(self.bn1(self.conv1(x)))

    # Encoder: keep every block output as a skip connection.
    skip1 = self.block1(stem)
    skip2 = self.block2(self.down1(skip1))
    skip3 = self.block3(self.down2(skip2))
    bottleneck = self.block4(self.down3(skip3))

    # Decoder: up-sample, prepend the skip, refine.
    decoded = self.block3up(me.cat((skip3, self.up3(bottleneck))))
    decoded = self.block2up(me.cat((skip2, self.up2(decoded))))
    decoded = self.block1up(me.cat((skip1, self.up1(decoded))))
    return self.final(decoded)
def forward(self, x):
    """Encode ``x`` through four strided stages and decode with skip concatenation.

    Every stage is ``conv -> bn -> relu`` followed (in the encoder) by a
    residual block. Each decoder stage is ``convtr -> bntr -> relu``. Its
    output is concatenated (up-sampled tensor first) with the encoder tensor
    of matching resolution and refined by ``block5``..``block8``. The result
    goes through ``self.final``.
    """
    def conv_bn_relu(conv, bn, tensor):
        # Shared conv -> batch-norm -> ReLU pattern used at every stage.
        return self.relu(bn(conv(tensor)))

    # Encoder (suffix pN = pixel distance of the tensor).
    out_p1 = conv_bn_relu(self.conv0p1s1, self.bn0, x)
    out_b1p2 = self.block1(conv_bn_relu(self.conv1p1s2, self.bn1, out_p1))
    out_b2p4 = self.block2(conv_bn_relu(self.conv2p2s2, self.bn2, out_b1p2))
    out_b3p8 = self.block3(conv_bn_relu(self.conv3p4s2, self.bn3, out_b2p4))
    # pixel_dist=16
    decoded = self.block4(conv_bn_relu(self.conv4p8s2, self.bn4, out_b3p8))

    # Decoder stages: (transposed conv, its norm, skip tensor, refining block).
    decoder_stages = (
        (self.convtr4p16s2, self.bntr4, out_b3p8, self.block5),  # pixel_dist=8
        (self.convtr5p8s2, self.bntr5, out_b2p4, self.block6),   # pixel_dist=4
        (self.convtr6p4s2, self.bntr6, out_b1p2, self.block7),   # pixel_dist=2
        (self.convtr7p2s2, self.bntr7, out_p1, self.block8),     # pixel_dist=1
    )
    for convtr, bntr, skip, block in decoder_stages:
        decoded = block(me.cat(conv_bn_relu(convtr, bntr, decoded), skip))
    return self.final(decoded)
def forward(self, x):
    """Three-level encoder/decoder with an optional mask-feature output.

    The encoder applies ``conv -> bn -> relu -> block`` four times. The
    decoder applies ``convtr -> bntr -> relu`` and concatenates (up-sampled
    tensor first) with the matching encoder block output. The last
    concatenation, ``out_feat``, is fed to ``self.final``.

    Returns:
        ``self.final(out_feat)``; when ``self.return_feat`` is set, the pair
        ``(self.mask_feat(out_feat), self.final(out_feat))``.

    NOTE(review): ``me.cat`` is called with a single tuple argument here,
    whereas other forward methods in this file pass separate arguments;
    confirm which MinkowskiEngine ``cat`` signature is targeted.
    """
    def stage(conv, bn, tensor):
        # conv -> batch-norm -> ReLU.
        return self.relu(bn(conv(tensor)))

    skip_p1 = self.block1(stage(self.conv1p1s1, self.bn1, x))
    skip_p2 = self.block2(stage(self.conv2p1s2, self.bn2, skip_p1))
    skip_p4 = self.block3(stage(self.conv3p2s2, self.bn3, skip_p2))
    # pixel_dist=8
    deepest = self.block4(stage(self.conv4p4s2, self.bn4, skip_p4))

    rising = self.block5(me.cat((stage(self.convtr4p8s2, self.bntr4, deepest), skip_p4)))
    rising = self.block6(me.cat((stage(self.convtr5p4s2, self.bntr5, rising), skip_p2)))
    out_feat = me.cat((stage(self.convtr6p2s2, self.bntr6, rising), skip_p1))

    result = self.final(out_feat)
    if not self.return_feat:
        return result
    return self.mask_feat(out_feat), result
def forward(self, x):
    """Encoder/decoder whose head also sees pooled intermediate decoder outputs.

    Like a plain three-level U-Net, but the outputs of ``block5`` and
    ``block6`` are additionally passed through ``pool_tr5`` / ``pool_tr6``.
    They are concatenated, after the last decoder tensor and the first
    encoder skip, into the input of ``self.final``.
    """
    def stage(conv, bn, tensor):
        # conv -> batch-norm -> ReLU.
        return self.relu(bn(conv(tensor)))

    skip_p1 = self.block1(stage(self.conv1p1s1, self.bn1, x))
    skip_p2 = self.block2(stage(self.conv2p1s2, self.bn2, skip_p1))
    skip_p4 = self.block3(stage(self.conv3p2s2, self.bn3, skip_p2))
    # pixel_dist=8
    bottom = self.block4(stage(self.conv4p4s2, self.bn4, skip_p4))

    dec5 = self.block5(me.cat(stage(self.convtr4p8s2, self.bntr4, bottom), skip_p4))
    pooled5 = self.pool_tr5(dec5)
    dec6 = self.block6(me.cat(stage(self.convtr5p4s2, self.bntr5, dec5), skip_p2))
    pooled6 = self.pool_tr6(dec6)

    head_input = me.cat(stage(self.convtr6p2s2, self.bntr6, dec6), skip_p1, pooled6, pooled5)
    return self.final(head_input)
def forward(self, x):
    """One recursive U-Net level: down, inner module, up, concat, refine.

    ``x`` passes through ``block -> down -> down_norm -> intermediate -> up
    -> up_norm``. The result is concatenated (first) with the original
    ``x`` and then run through ``end_blocks0`` .. ``end_blocks{reps-1}`` in
    order.
    """
    inner = self.intermediate(self.down_norm(self.down(self.block(x))))
    merged = MinkowskiOps.cat((self.up_norm(self.up(inner)), x))

    for i in range(self.reps):
        end_block = getattr(self, f'end_blocks{i}')
        merged = end_block(merged)
    return merged
def forward(self, x):
    """Six-level encoder/decoder with skip concatenation and optional mask features.

    Encoder level ``i`` (1..6) computes ``relu(bn_down{i}(conv_down{i}(.)))``,
    keeps it as a skip and applies ``down{i}``. The bottleneck is
    ``relu(bn7(conv7(.)))``. Decoder level ``i`` (6..1) applies ``up{i}``,
    concatenates the skip *in front* of it and computes
    ``relu(bn_up{i}(conv_up{i}(.)))``. The final decoder output
    (``out_feat``) feeds ``self.final``.

    Returns:
        ``self.final(out_feat)``; when ``self.return_feat`` is set, the pair
        ``(self.mask_feat(out_feat), self.final(out_feat))``.

    NOTE(review): ``me.cat`` is called with a single tuple argument here,
    whereas other forward methods in this file pass separate arguments;
    confirm which MinkowskiEngine ``cat`` signature is targeted.
    """
    def conv_bn_relu(side, level, tensor):
        # side is 'down' or 'up'; resolves conv_<side><level> / bn_<side><level>.
        conv = getattr(self, f'conv_{side}{level}')
        bn = getattr(self, f'bn_{side}{level}')
        return self.relu(bn(conv(tensor)))

    skips = []
    current = x
    for level in range(1, 7):
        skip = conv_bn_relu('down', level, current)
        skips.append(skip)
        current = getattr(self, f'down{level}')(skip)

    current = self.relu(self.bn7(self.conv7(current)))

    for level in range(6, 0, -1):
        upsampled = getattr(self, f'up{level}')(current)
        current = conv_bn_relu('up', level, me.cat((skips[level - 1], upsampled)))
    out_feat = current

    result = self.final(out_feat)
    if self.return_feat:
        return self.mask_feat(out_feat), result
    return result
def forward(self, x):
    """Run the residual U-Net and optionally L2-normalise the output features.

    Four strided encoder stages (``conv -> bn -> relu -> block``) are
    followed by four decoder stages (``convtr -> bntr -> relu``, concatenated
    with the matching encoder tensor, then ``block5``..``block8``) and
    ``self.final``.

    Returns:
        The output of ``self.final``. When ``self.normalize_feature`` is set,
        it instead returns a new ``SparseTensor`` on the same coordinates
        whose feature rows have unit L2 norm. All-zero rows stay zero rather
        than becoming NaN.
    """
    out = self.conv0p1s1(x)
    out = self.bn0(out)
    out_p1 = self.relu(out)

    out = self.conv1p1s2(out_p1)
    out = self.bn1(out)
    out = self.relu(out)
    out_b1p2 = self.block1(out)

    out = self.conv2p2s2(out_b1p2)
    out = self.bn2(out)
    out = self.relu(out)
    out_b2p4 = self.block2(out)

    out = self.conv3p4s2(out_b2p4)
    out = self.bn3(out)
    out = self.relu(out)
    out_b3p8 = self.block3(out)

    out = self.conv4p8s2(out_b3p8)
    out = self.bn4(out)
    out = self.relu(out)
    encoder_out = self.block4(out)

    out = self.convtr4p16s2(encoder_out)
    out = self.bntr4(out)
    out = self.relu(out)
    out = me.cat(out, out_b3p8)
    out = self.block5(out)

    out = self.convtr5p8s2(out)
    out = self.bntr5(out)
    out = self.relu(out)
    out = me.cat(out, out_b2p4)
    out = self.block6(out)

    out = self.convtr6p4s2(out)
    out = self.bntr6(out)
    out = self.relu(out)
    out = me.cat(out, out_b1p2)
    out = self.block7(out)

    out = self.convtr7p2s2(out)
    out = self.bntr7(out)
    out = self.relu(out)
    out = me.cat(out, out_p1)
    out = self.block8(out)

    out = self.final(out)

    if not self.normalize_feature:
        return out
    # normalize() divides by max(||row||_2, eps). The previous plain division
    # by torch.norm yielded NaN/inf for any all-zero feature row, which then
    # poisons downstream losses.
    unit_feats = torch.nn.functional.normalize(out.F, p=2, dim=1)
    return SparseTensor(
        unit_feats,
        coords_key=out.coords_key,
        coords_manager=out.coords_man)
def forward(self, x_a: ME.SparseTensor, x_b: ME.SparseTensor):
    """Fuse the features of ``x_a`` with ``x_b``.

    Three modes, selected by flags on ``self``:

    * ``POINT_TR_LIKE``: project ``x_a.F`` / ``x_b.F`` with ``linear_a`` /
      ``linear_b`` and scatter both into dense per-batch buffers. For every
      ``x_b`` coordinate, interpolate ``x_a``'s projected features from its
      three nearest ``x_a`` coordinates with inverse-distance weights. Add
      ``x_b``'s projection, and return a tensor on ``x_b``'s coordinate map.
    * ``SUM_FEATURE``: ``conv_a(x_a) + conv_b(x_b)``.
    * otherwise: ``relu(bn(conv(x_a)))`` concatenated with ``x_b``, then
      ``out_conv -> out_bn -> out_relu``.

    TODO: For POINT_TR_LIKE, add support for no x_b is fed, simply upsample the x_a

    Args:
        x_a: sparse tensor whose features are propagated.
        x_b: sparse tensor whose coordinates receive the fused features
            (explicitly so in the POINT_TR_LIKE branch; in the other
            branches the output coordinates depend on conv/conv_a/conv_b —
            confirm).

    Returns:
        ME.SparseTensor with the fused features.
    """
    if self.POINT_TR_LIKE:
        dim = x_b.F.shape[1]
        assert dim == self.out_dim

        # Scatter x_a's projected features into a [B, N_a, dim] buffer using
        # the flat row indices from separate_batch (presumably zero-padding
        # shorter batches — confirm against separate_batch).
        x_ac, _, idx_a = separate_batch(x_a.C)
        B = x_ac.shape[0]
        N_a = x_ac.shape[1]
        feat_a = self.linear_a(x_a.F)
        # Allocate on the features' device/dtype rather than a hard-coded
        # .cuda()/float32 buffer: that failed on CPU and made scatter_ raise
        # a dtype mismatch under mixed precision.
        x_af = torch.zeros(B * N_a, dim, device=feat_a.device, dtype=feat_a.dtype)
        idx_a = idx_a.reshape(-1, 1).repeat(1, dim)
        x_af.scatter_(dim=0, index=idx_a, src=feat_a)
        x_af = x_af.reshape([B, N_a, dim])

        # Same dense layout for x_b.
        x_bc, _, idx_b = separate_batch(x_b.C)
        B = x_bc.shape[0]
        N_b = x_bc.shape[1]
        feat_b = self.linear_b(x_b.F)
        x_bf = torch.zeros(B * N_b, dim, device=feat_b.device, dtype=feat_b.dtype)
        idx_b = idx_b.reshape(-1, 1).repeat(1, dim)
        x_bf.scatter_(dim=0, index=idx_b, src=feat_b)
        x_bf = x_bf.reshape([B, N_b, dim])

        # Three nearest x_a coordinates for each x_b coordinate (assumes the
        # pointnet2-style three_nn(unknown, known) convention — confirm).
        dists, idx = three_nn(x_bc.float(), x_ac.float())
        # Rows whose three distances sum to zero get zero weight
        # (presumably padding rows).
        mask = (dists.sum(dim=-1) > 0).unsqueeze(-1).repeat(1, 1, 3)
        dist_recip = 1.0 / (dists + 1e-1)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm
        weight = weight * mask  # mask the zeros part
        interpolated_points = three_interpolate(
            x_af.transpose(1, 2).contiguous(), idx, weight).transpose(1, 2)  # [B, N_b, dim]

        out = interpolated_points + x_bf
        # Pull the rows back out of the dense buffer in x_b's original order;
        # the result has idx_b.shape[0] rows, i.e. the same size as x_b.F.
        out = torch.gather(out.reshape(B * N_b, dim), dim=0, index=idx_b)
        x = ME.SparseTensor(features=out,
                            coordinate_map_key=x_b.coordinate_map_key,
                            coordinate_manager=x_b.coordinate_manager)
    elif self.SUM_FEATURE:
        x_a = self.conv_a(x_a)
        x_b = self.conv_b(x_b)
        x = x_a + x_b
    else:
        x_a = self.conv(x_a)
        x_a = self.bn(x_a)
        x_a = self.relu(x_a)
        x = me.cat(x_a, x_b)
        # NOTE(review): the original indentation was lost. The out_* layers
        # are placed here because they must map the concatenated channels
        # back down; confirm they were not applied to the other branches.
        x = self.out_conv(x)
        x = self.out_bn(x)
        x = self.out_relu(x)
    return x
def forward(self, x, save_anchor=False, iter_=None, aux=None, enable_point_branch=False):
    """Run the U-Net, optionally fusing a point branch and recording anchors.

    Args:
        x: input sparse tensor.
        save_anchor: if true, ``self.anchors`` is reset and receives the
            outputs of block1, block2, block6 and block7. It is returned
            alongside the prediction.
        iter_: forwarded unchanged to every ``block*`` call.
        aux: forwarded unchanged to every ``block*`` call.
        enable_point_branch: if true, the outputs of block4, block6 and
            block8 are interpolated onto the stem (``out_p1``) coordinates.
            Each is added to an MLP-transformed running point feature; for
            block4/block6 the sum is pooled back with
            ``downsample16x``/``downsample4x``. Dropout of the result then
            replaces the features of the current tensor.

    Returns:
        ``self.final(out)``, or ``(self.final(out), self.anchors)`` when
        ``save_anchor`` is true.

    Raises:
        RuntimeError: if the final output features contain NaN.
    """
    if save_anchor:
        self.anchors = []

    # mapped to transformer.stem1
    out = self.conv0p1s1(x)
    out = self.bn0(out)
    out_p1 = get_nonlinearity_fn(self.config.nonlinearity, out)
    if enable_point_branch:
        out_p1_point = out_p1.F
        out_p1_coord = out_p1.C
        # Float query coordinates for self.interpolate, converted once.
        # .float() keeps the tensor on its device; the former
        # .type(torch.FloatTensor) copied it to the CPU (and back) per call.
        query_coords = out_p1_coord.float()

    # mapped to transformer.stem2
    out = self.conv1p1s2(out_p1)
    out = self.bn1(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)

    # mapped to transformer.PTBlock1
    out_b1p2 = self.block1(out, iter_, aux)
    if save_anchor:
        self.anchors.append(out_b1p2)

    # mapped to transformer.PTBlock2
    out = self.conv2p2s2(out_b1p2)
    out = self.bn2(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)
    out_b2p4 = self.block2(out, iter_, aux)
    if save_anchor:
        self.anchors.append(out_b2p4)

    # mapped to transformer.PTBlock3
    out = self.conv3p4s2(out_b2p4)
    out = self.bn3(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)
    out_b3p8 = self.block3(out, iter_, aux)

    # pixel_dist=16, mapped to transformer.PTBlock4
    out = self.conv4p8s2(out_b3p8)
    out = self.bn4(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)
    out = self.block4(out, iter_, aux)

    if enable_point_branch:
        interpolated_out = self.interpolate(out, query_coords.to(out.device))
        # fused feature at stem resolution
        block4_features = interpolated_out + self.point_transform_mlp[0](out_p1_point)
        out_fused = ME.SparseTensor(features=block4_features, coordinates=out_p1_coord)
        out_fused = self.downsample16x(out_fused)
        out = ME.SparseTensor(features=self.dropout(out_fused.F),
                              coordinate_map_key=out.coordinate_map_key,
                              coordinate_manager=out.coordinate_manager)

    # pixel_dist=8, mapped to transformer.PTBlock5
    out = self.convtr4p16s2(out)
    out = self.bntr4(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)
    out = me.cat(out, out_b3p8)
    out = self.block5(out, iter_, aux)

    # pixel_dist=4, mapped to transformer.PTBlock6
    out = self.convtr5p8s2(out)
    out = self.bntr5(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)
    out = me.cat(out, out_b2p4)
    out = self.block6(out, iter_, aux)
    if save_anchor:
        self.anchors.append(out)

    if enable_point_branch:
        interpolated_out = self.interpolate(out, query_coords.to(out.device))
        block6_features = interpolated_out + self.point_transform_mlp[1](block4_features)
        out_fused = ME.SparseTensor(features=block6_features, coordinates=out_p1_coord)
        out_fused = self.downsample4x(out_fused)
        out = ME.SparseTensor(features=self.dropout(out_fused.F),
                              coordinate_map_key=out.coordinate_map_key,
                              coordinate_manager=out.coordinate_manager)

    # pixel_dist=2, mapped to transformer.PTBlock7
    out = self.convtr6p4s2(out)
    out = self.bntr6(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)
    out = me.cat(out, out_b1p2)
    out = self.block7(out, iter_, aux)
    if save_anchor:
        self.anchors.append(out)

    # pixel_dist=1, mapped to transformer.final_conv
    out = self.convtr7p2s2(out)
    out = self.bntr7(out)
    out = get_nonlinearity_fn(self.config.nonlinearity, out)
    out = me.cat(out, out_p1)
    out = self.block8(out, iter_, aux)

    if enable_point_branch:
        interpolated_out = self.interpolate(out, query_coords.to(out.device))
        block8_features = interpolated_out + self.point_transform_mlp[2](block6_features)
        out_fused = ME.SparseTensor(features=block8_features, coordinates=out_p1_coord)
        out = ME.SparseTensor(features=self.dropout(out_fused.F),
                              coordinate_map_key=out.coordinate_map_key,
                              coordinate_manager=out.coordinate_manager)

    out = self.final(out)
    # Fail loudly instead of the former `import ipdb; ipdb.set_trace()`,
    # which hung non-interactive jobs and required ipdb to be installed.
    if torch.isnan(out.F).any():
        raise RuntimeError("NaN detected in network output")

    if save_anchor:
        return out, self.anchors
    return out
def forward(self, x, out_feat_keys=None):
    """Run the U-Net and return pooled features for the requested stages.

    Stage tensors recorded in ``end_points``:

    * ``en0``..``en2``: tensor entering block1/block2/block3 (after relu).
    * ``en3``: tensor entering block4.
    * ``en4``: output of ``convtr4p16s2 -> bntr4 -> relu`` (before concat).
    * ``plane4``..``plane6``: output of the next three transposed-conv
      stages (before concat).
    * ``plane7``: output of block8.

    Args:
        x: input sparse tensor.
        out_feat_keys: sequence of stage names from the list above. Although
            it defaults to ``None`` for signature compatibility, it must be
            provided.

    Returns:
        A list with one entry per key, in the same order (duplicates
        allowed). Each entry is ``self.maxpool`` of the stage tensor,
        passed through ``self.head`` when ``self.use_mlp`` is set, and
        returned as its ``.F`` features.

    Raises:
        ValueError: if ``out_feat_keys`` is None (checked before the network
            runs).
        KeyError: if a key names an unknown stage.
    """
    if out_feat_keys is None:
        raise ValueError("out_feat_keys must list the stages to return, got None")

    end_points = {}

    out = self.conv0p1s1(x)
    out = self.bn0(out)
    out_p1 = self.relu(out)

    out = self.conv1p1s2(out_p1)
    out = self.bn1(out)
    out = self.relu(out)
    out_b1p2 = self.block1(out)
    end_points["en0_features"] = out  ## 32

    out = self.conv2p2s2(out_b1p2)
    out = self.bn2(out)
    out = self.relu(out)
    out_b2p4 = self.block2(out)
    end_points["en1_features"] = out  ## 32

    out = self.conv3p4s2(out_b2p4)
    out = self.bn3(out)
    out = self.relu(out)
    out_b3p8 = self.block3(out)
    end_points["en2_features"] = out  ## 64

    # pixel_dist=16
    out = self.conv4p8s2(out_b3p8)
    out = self.bn4(out)
    out = self.relu(out)
    end_points["en3_features"] = out  ## 128
    out = self.block4(out)

    # pixel_dist=8
    out = self.convtr4p16s2(out)
    out = self.bntr4(out)
    out = self.relu(out)
    end_points["en4_features"] = out  ## 256
    out = me.cat(out, out_b3p8)
    out = self.block5(out)

    # pixel_dist=4
    out = self.convtr5p8s2(out)
    out = self.bntr5(out)
    out = self.relu(out)
    end_points["plane4_features"] = out
    out = me.cat(out, out_b2p4)
    out = self.block6(out)

    # pixel_dist=2
    out = self.convtr6p4s2(out)
    out = self.bntr6(out)
    out = self.relu(out)
    end_points["plane5_features"] = out
    out = me.cat(out, out_b1p2)
    out = self.block7(out)

    # pixel_dist=1
    out = self.convtr7p2s2(out)
    out = self.bntr7(out)
    out = self.relu(out)
    end_points["plane6_features"] = out
    out = me.cat(out, out_p1)
    out = self.block8(out)
    end_points["plane7_features"] = out

    # Build the result positionally. The former
    # out_feats[out_feat_keys.index(key)] = ... was O(k^2) and left None
    # entries whenever a key appeared more than once.
    out_feats = []
    for key in out_feat_keys:
        feat = self.maxpool(end_points[key + "_features"])
        if self.use_mlp:
            feat = self.head(feat)
        out_feats.append(feat.F)  ### Just use smlp
    return out_feats
def forward(self, x, save_anchor=False):
    """Run the U-Net, optionally collecting intermediate block outputs.

    Args:
        x: input sparse tensor.
        save_anchor: when true, ``self.anchors`` is reset. It then collects
            the outputs of block1 through block7 in order and is returned
            next to the prediction.

    Returns:
        ``self.final(...)`` of the last decoder tensor, or that value
        paired with ``self.anchors`` when ``save_anchor`` is true.
    """
    if save_anchor:
        self.anchors = []

    def activate(tensor):
        # Nonlinearity chosen by the config, looked up on every call.
        return get_nonlinearity_fn(self.config.nonlinearity, tensor)

    def remember(tensor):
        # Record the tensor as an anchor when requested; pass it through.
        if save_anchor:
            self.anchors.append(tensor)
        return tensor

    # transformer.stem1 / transformer.stem2
    out_p1 = activate(self.bn0(self.conv0p1s1(x)))
    stem = activate(self.bn1(self.conv1p1s2(out_p1)))

    # transformer.PTBlock1 .. PTBlock3
    out_b1p2 = remember(self.block1(stem))
    out_b2p4 = remember(self.block2(activate(self.bn2(self.conv2p2s2(out_b1p2)))))
    out_b3p8 = remember(self.block3(activate(self.bn3(self.conv3p4s2(out_b2p4)))))

    # pixel_dist=16, transformer.PTBlock4
    deep = remember(self.block4(activate(self.bn4(self.conv4p8s2(out_b3p8)))))

    # pixel_dist=8 / 4 / 2, transformer.PTBlock5 .. PTBlock7
    rising = remember(self.block5(me.cat(activate(self.bntr4(self.convtr4p16s2(deep))), out_b3p8)))
    rising = remember(self.block6(me.cat(activate(self.bntr5(self.convtr5p8s2(rising))), out_b2p4)))
    rising = remember(self.block7(me.cat(activate(self.bntr6(self.convtr6p4s2(rising))), out_b1p2)))

    # pixel_dist=1, transformer.final_conv
    rising = self.block8(me.cat(activate(self.bntr7(self.convtr7p2s2(rising))), out_p1))

    prediction = self.final(rising)
    return (prediction, self.anchors) if save_anchor else prediction