def forward(self, x):
    """Two conv/BN/ReLU stages, global pooling, then the linear classifier head."""
    h = F.relu(self.bn1(self.conv1(x)))
    h = F.relu(self.bn2(self.conv2(h)))
    h = self.pooling(h)
    return self.linear(h)
def forward(self, x):
    """ResUNet forward pass producing per-point features and an attention map.

    Returns a dict with:
      - "features": output of the feature head; L2-normalized row-wise when
        ``self.normalize_feature`` is set.
      - "attention": attention head output passed through ``self.final_att_act``.

    NOTE(review): uses MinkowskiEngine 0.4-style ``coords_key``/``coords_manager``
    SparseTensor keywords — confirm the pinned ME version.
    """
    # Encoder: four conv/norm/residual-block stages; each stage's pre-ReLU
    # output (out_sN) is kept for the decoder skip connections.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)
    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)
    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)
    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)
    # Decoder: transposed convs, concatenating the matching encoder skip.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)
    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)
    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)
    # Feature head.
    out_feat = self.conv1_tr(out)
    out_feat = MEF.relu(out_feat)
    out_feat = self.final(out_feat)
    # Attention head; activation is applied to the raw feature matrix (.F)
    # and rewrapped on the same coordinate map.
    out_att = self.conv1_tr_att(out)
    out_att = MEF.relu(out_att)
    out_att = self.final_att(out_att)
    out_att = ME.SparseTensor(self.final_att_act(out_att.F), coords_key=out_att.coords_key, coords_manager=out_att.coords_man)
    if self.normalize_feature:
        # Row-wise L2 normalization of the descriptor features.
        out_feat = ME.SparseTensor(
            out_feat.F / torch.norm(out_feat.F, p=2, dim=1, keepdim=True),
            coords_key=out_feat.coords_key,
            coords_manager=out_feat.coords_man)
    return dict(features=out_feat, attention=out_att)
def forward(self, x, x_skip):
    """One decoder stage: optionally fuse a skip tensor, then conv.

    When ``self.final`` is None this is an intermediate stage
    (conv -> norm -> block -> ReLU); otherwise it is the output stage
    (conv -> ReLU -> final projection).
    """
    if x_skip is not None:
        x = ME.cat(x, x_skip)
    h = self.conv(x)
    if self.final is None:
        return MEF.relu(self.block(self.norm(h)))
    return self.final(MEF.relu(h))
def forward(self, flow):
    """Predict a residual correction through four ReLU conv stages and add it
    back onto the input flow."""
    h = flow
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        h = MEF.relu(conv(h))
    return flow + self.final(h)
def get_nonlinearity_fn(nonlinearity_type, input, *args, **kwargs):
    """Apply the nonlinearity named by ``nonlinearity_type`` to ``input``.

    Args:
        nonlinearity_type: key into ``str_to_nonlinearity_dict``.
        input: sparse tensor handed to the MinkowskiEngine functional op.
        *args, **kwargs: forwarded to the underlying functional.

    Raises:
        ValueError: if the resolved type has no matching branch.
    """
    nonlinearity_type = str_to_nonlinearity_dict[nonlinearity_type]
    if nonlinearity_type == NonlinearityType.ReLU:
        return MEF.relu(input, *args, **kwargs)
    elif nonlinearity_type == NonlinearityType.LeakyReLU:
        # BUG FIX: this branch previously re-compared against
        # NonlinearityType.ReLU (copy-paste), so leaky_relu was unreachable.
        # NOTE(review): assumes the enum defines a LeakyReLU member — confirm.
        return MEF.leaky_relu(input, *args, **kwargs)
    elif nonlinearity_type == NonlinearityType.PReLU:
        return MEF.prelu(input, *args, **kwargs)
    elif nonlinearity_type == NonlinearityType.CELU:
        return MEF.celu(input, *args, **kwargs)
    elif nonlinearity_type == NonlinearityType.SELU:
        return MEF.selu(input, *args, **kwargs)
    else:
        # Message previously said "Norm type" — another copy-paste artifact.
        raise ValueError(f'Nonlinearity type: {nonlinearity_type} not supported')
def forward(self, x):
    """Five-level sparse U-Net; returns the decoder output, L2-normalized
    row-wise when ``self.normalize_feature`` is set.

    NOTE(review): uses MinkowskiEngine 0.4-style ``coords_key``/``coords_manager``
    keywords and the tuple form of ``ME.cat`` — confirm the pinned ME version.
    """
    # Encoder: conv/norm stages at strides 1, 2, 4, 8, 16; pre-ReLU outputs
    # are kept as skip connections.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out = MEF.relu(out_s1)
    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out = MEF.relu(out_s2)
    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out = MEF.relu(out_s4)
    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out = MEF.relu(out_s8)
    out_s16 = self.conv5(out)
    out_s16 = self.norm5(out_s16)
    out = MEF.relu(out_s16)
    # Decoder: transposed convs with skip concatenation at each level.
    out = self.conv5_tr(out)
    out = self.norm5_tr(out)
    out_s8_tr = MEF.relu(out)
    out = ME.cat((out_s8_tr, out_s8))
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat((out_s4_tr, out_s4))
    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat((out_s2_tr, out_s2))
    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat((out_s1_tr, out_s1))
    out = self.conv1_tr(out)
    if self.normalize_feature:
        # Row-wise L2 normalization of the output descriptors.
        return ME.SparseTensor(
            out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
            coords_key=out.coords_key,
            coords_manager=out.coords_man)
    else:
        return out
def forward(self, x):
    """ResUNet forward pass with version-dependent SparseTensor construction.

    Returns the final feature tensor; when ``self.normalize_feature`` is set
    the features are L2-normalized row-wise, using the keyword style matching
    the installed MinkowskiEngine minor version (0.5 vs 0.4).

    NOTE(review): if ``normalize_feature`` is set and the ME minor version is
    neither "4" nor "5", no branch returns and the method yields None —
    likely unintended; confirm supported ME versions.
    """
    # Encoder: four conv/norm/residual-block stages; pre-ReLU outputs kept
    # as skip connections.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)
    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)
    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)
    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)
    # Decoder with skip concatenation at each level.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)
    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)
    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)
    if self.normalize_feature:
        if ME.__version__.split(".")[1] == "5":
            # ME 0.5+ keyword style.
            return ME.SparseTensor(
                out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
                coordinate_map_key=out.coordinate_map_key,
                coordinate_manager=out.coordinate_manager)
        elif ME.__version__.split(".")[1] == "4":
            # Legacy ME 0.4 keyword style.
            return ME.SparseTensor(
                out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
                coords_key=out.coords_key,
                coords_manager=out.coords_man)
    else:
        return out
def forward(self, x, skip_features):
    """Decode ``x`` through three transposed-conv stages, concatenating the
    encoder skips from deepest (``skip_features[-1]``) to shallowest
    (``skip_features[-3]``), then apply the output head."""
    stages = (
        (self.conv4_tr, self.norm4_tr, self.block4_tr, skip_features[-1]),
        (self.conv3_tr, self.norm3_tr, self.block3_tr, skip_features[-2]),
        (self.conv2_tr, self.norm2_tr, self.block2_tr, skip_features[-3]),
    )
    h = x
    for conv_tr, norm_tr, block_tr, skip in stages:
        h = ME.cat(block_tr(norm_tr(conv_tr(h))), skip)
    return self.final(MEF.relu(self.conv1_tr(h)))
def forward(self, x):
    """Single stage: conv -> norm -> residual block -> ReLU."""
    h = self.conv(x)
    h = self.norm(h)
    h = self.block(h)
    return MEF.relu(h)
def forward(self, x):
    """Compact three-level sparse U-Net: two encoder skips, two transposed
    stages with concatenation, then the output conv."""
    skip1 = self.block1(x)
    h = MF.relu(skip1)
    skip2 = self.block2(h)
    h = MF.relu(skip2)
    h = MF.relu(self.block3(h))
    h = ME.cat(MF.relu(self.block3_tr(h)), skip2)
    h = ME.cat(MF.relu(self.block2_tr(h)), skip1)
    return self.conv1_tr(h)
def forward(self, x):
    """ResUNet backbone followed by a small fully-connected head.

    After the U-Net decoder (optionally L2-normalizing the features when
    ``self.normalize_feature`` is set), the result is passed through
    fc1 -> bn -> ReLU -> fc2 and returned.

    NOTE(review): uses ME 0.4-style ``coords_key``/``coords_manager`` keywords
    and the tuple form of ``ME.cat`` — confirm the pinned ME version.
    """
    # Encoder: four conv/norm/residual-block stages; pre-ReLU outputs kept
    # as skip connections.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)
    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)
    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)
    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)
    # Decoder with skip concatenation at each level.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat((out_s4_tr, out_s4))
    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat((out_s2_tr, out_s2))
    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat((out_s1_tr, out_s1))
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)
    if self.normalize_feature:
        # Row-wise L2 normalization before the FC head.
        out = ME.SparseTensor(
            out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
            coords_key=out.coords_key,
            coords_manager=out.coords_man)
    # Fully-connected head.
    out = self.fc1(out)
    out = F.relu(self.bn(out))
    out = self.fc2(out)
    return out
def forward(self, x):
    """Three-level conv/BN encoder with a two-stage decoder: skip tensors are
    concatenated back in (tuple form of ME.cat) before the final conv."""
    skip1 = self.bn1(self.conv1(x))
    h = MF.relu(skip1)
    skip2 = self.bn2(self.conv2(h))
    h = MF.relu(skip2)
    h = MF.relu(self.bn3(self.conv3(h)))
    h = MF.relu(self.bn4(self.conv4(h)))
    h = ME.cat((h, skip2))
    h = MF.relu(self.bn5(self.conv5(h)))
    h = ME.cat((h, skip1))
    return self.conv6(h)
def forward(self, x):
    """Standard residual block: two conv/norm stages plus an identity (or
    downsampled) shortcut, with ReLU after the addition."""
    h = MEF.relu(self.norm1(self.conv1(x)))
    h = self.norm2(self.conv2(h))
    shortcut = x if self.downsample is None else self.downsample(x)
    h += shortcut
    return MEF.relu(h)
def forward(self, x):
    """ResUNet variant that downsamples with explicit pooling layers before
    each encoder conv (stages 2-4), then decodes with skip connections.

    Returns the output head features; when ``self.normalize_feature`` is set
    they are L2-normalized row-wise (with a 1e-8 epsilon to avoid division
    by zero on all-zero rows). Uses ME 0.5-style
    ``coordinate_map_key``/``coordinate_manager`` keywords.
    """
    # Encoder.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)
    out_s2 = self.pool2(out)
    out_s2 = self.conv2(out_s2)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)
    out_s4 = self.pool3(out)
    out_s4 = self.conv3(out_s4)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)
    out_s8 = self.pool4(out)
    out_s8 = self.conv4(out_s8)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)
    # Decoder with skip concatenation at each level.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)
    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)
    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)
    if self.normalize_feature:
        return ME.SparseTensor(
            out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8),
            coordinate_map_key=out.coordinate_map_key,
            coordinate_manager=out.coordinate_manager)
    else:
        return out
def forward(self, x):
    """Segmentation head: conv -> norm -> ReLU -> output conv."""
    h = MEF.relu(self.norm_1(self.seg_head_1(x)))
    return self.seg_head_2(h)
def forward(self, x):
    """ResUNet forward pass; inline comments track the receptive-field size
    after each layer. Note ReLU is applied BEFORE the residual block in each
    stage here (unlike sibling networks in this file).

    Returns the output features, L2-normalized row-wise (1e-8 epsilon) when
    ``self.normalize_feature`` is set.

    NOTE(review): uses ME 0.4-style ``coords_key``/``coords_manager``
    keywords — confirm the pinned ME version.
    """
    # Receptive field size
    out_s1 = self.conv1(x)  # 7
    out_s1 = self.norm1(out_s1)
    out_s1 = MEF.relu(out_s1)
    out_s1 = self.block1(out_s1)
    out_s2 = self.conv2(out_s1)  # 7 + 2 * 2 = 11
    out_s2 = self.norm2(out_s2)
    out_s2 = MEF.relu(out_s2)
    out_s2 = self.block2(out_s2)  # 11 + 2 * (2 + 2) = 19
    out_s4 = self.conv3(out_s2)  # 19 + 4 * 2 = 27
    out_s4 = self.norm3(out_s4)
    out_s4 = MEF.relu(out_s4)
    out_s4 = self.block3(out_s4)  # 27 + 4 * (2 + 2) = 43
    out_s8 = self.conv4(out_s4)  # 43 + 8 * 2 = 59
    out_s8 = self.norm4(out_s8)
    out_s8 = MEF.relu(out_s8)
    out_s8 = self.block4(out_s8)  # 59 + 8 * (2 + 2) = 91
    # Decoder with skip concatenation.
    out = self.conv4_tr(out_s8)  # 91 + 4 * 2 = 99
    out = self.norm4_tr(out)
    out = MEF.relu(out)
    out = self.block4_tr(out)  # 99 + 4 * (2 + 2) = 115
    out = ME.cat(out, out_s4)
    out = self.conv3_tr(out)  # 115 + 2 * 2 = 119
    out = self.norm3_tr(out)
    out = MEF.relu(out)
    out = self.block3_tr(out)  # 119 + 2 * (2 + 2) = 127
    out = ME.cat(out, out_s2)
    out = self.conv2_tr(out)  # 127 + 2 = 129
    out = self.norm2_tr(out)
    out = MEF.relu(out)
    out = self.block2_tr(out)  # 129 + 1 * (2 + 2) = 133
    out = ME.cat(out, out_s1)
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)
    if self.normalize_feature:
        return ME.SparseTensor(
            out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8),
            coords_key=out.coords_key,
            coords_manager=out.coords_man)
    else:
        return out
def forward(self, x):
    """ResUNet backbone followed by global max-pooling over all points and a
    single fully-connected + ReLU head.

    Returns a dense tensor of shape [1, C] (``unsqueeze(0)`` restores a
    batch dimension after the channel-wise max over ``out.F``).
    """
    # Encoder: four conv/norm/residual-block stages; pre-ReLU outputs kept
    # as skip connections.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)
    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)
    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)
    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)
    # Decoder with skip concatenation at each level.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)
    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)
    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)
    ################################################################
    # Global readout: max over all points per channel, then the FC head.
    out = torch.max(out.F, 0)[0]
    out = F.relu(self.fc1(out))
    return out.unsqueeze(0)
def forward(self, x):
    """Encode ``x`` into local/global feature maps plus an optional latent.

    Always returns a dict with "x", "lf", "gf". Depending on
    ``config.encoder_conditional_type`` it adds a latent "z" (plus "mu" and
    "logvar" for the gaussian case); ``config.use_attention`` adds a
    normalized attention map "af".
    """
    local_feat = self.lf(x)
    global_feat = self.gf(local_feat)
    outputs = {"x": x, "lf": local_feat, "gf": global_feat}

    cond_type = self.config.encoder_conditional_type
    if cond_type == "gaussian":
        # Variational latent: split pooled features into mean / log-variance
        # and sample via the reparameterization trick.
        pooled = self.global_pool(self.ef(global_feat))
        mu, logvar = torch.chunk(pooled.F, 2, dim=1)
        zf = self.reparameterize(mu, logvar)
        stride = torch.tensor(
            [2**self.config.n_conv_layers] * self.spatial_dims,
            device=zf.device)
        outputs["z"] = ME.SparseTensor(
            features=zf,
            coordinates=pooled.C,
            tensor_stride=stride,
            coordinate_manager=pooled.coordinate_manager)
        outputs["mu"] = mu
        outputs["logvar"] = logvar
    elif cond_type == "deterministic":
        # Deterministic latent: pooled features used as-is.
        pooled = self.global_pool(self.ef(global_feat))
        stride = torch.tensor(
            [2**self.config.n_conv_layers] * self.spatial_dims,
            device=pooled.device)
        outputs["z"] = ME.SparseTensor(
            features=pooled.F,
            coordinates=pooled.C,
            tensor_stride=stride,
            coordinate_manager=pooled.coordinate_manager)

    if self.config.use_attention:
        outputs["af"] = F.normalize(self.af(global_feat), p=2)
    return outputs
def forward(self, stensor_src, stensor_tgt):
    """Joint forward pass for a pair of point clouds with overlap attention.

    Both inputs are encoded by the SAME shared encoder, their bottleneck
    features exchange information through a GNN with cross-attention-derived
    overlap scores, and each is decoded with its own skip connections. The
    last two decoded channels are interpreted as overlap / saliency logits.

    Returns:
        (src_feats, tgt_feats, scores_overlap, scores_saliency) where the
        feature matrices are row-wise L2-normalized and the score vectors
        concatenate src then tgt.
    """
    ################################
    # encode src (shared weights with the tgt branch below)
    src_s1 = self.conv1(stensor_src)
    src_s1 = self.norm1(src_s1)
    src_s1 = self.block1(src_s1)
    src = MEF.relu(src_s1)
    src_s2 = self.conv2(src)
    src_s2 = self.norm2(src_s2)
    src_s2 = self.block2(src_s2)
    src = MEF.relu(src_s2)
    src_s4 = self.conv3(src)
    src_s4 = self.norm3(src_s4)
    src_s4 = self.block3(src_s4)
    src = MEF.relu(src_s4)
    src_s8 = self.conv4(src)
    src_s8 = self.norm4(src_s8)
    src_s8 = self.block4(src_s8)
    src = MEF.relu(src_s8)
    ################################
    # encode tgt
    tgt_s1 = self.conv1(stensor_tgt)
    tgt_s1 = self.norm1(tgt_s1)
    tgt_s1 = self.block1(tgt_s1)
    tgt = MEF.relu(tgt_s1)
    tgt_s2 = self.conv2(tgt)
    tgt_s2 = self.norm2(tgt_s2)
    tgt_s2 = self.block2(tgt_s2)
    tgt = MEF.relu(tgt_s2)
    tgt_s4 = self.conv3(tgt)
    tgt_s4 = self.norm3(tgt_s4)
    tgt_s4 = self.block3(tgt_s4)
    tgt = MEF.relu(tgt_s4)
    tgt_s8 = self.conv4(tgt)
    tgt_s8 = self.norm4(tgt_s8)
    tgt_s8 = self.block4(tgt_s8)
    tgt = MEF.relu(tgt_s8)
    ################################
    # overlap attention module
    # empirically, when batch_size = 1, out.C[:,1:] == out.coordinates_at(0)
    src_feats = src.F.transpose(0,1)[None,:]  #[1, C, N]
    tgt_feats = tgt.F.transpose(0,1)[None,:]  #[1, C, N]
    # Recover metric coordinates from voxel indices (column 0 is the batch index).
    src_pcd, tgt_pcd = src.C[:,1:] * self.voxel_size, tgt.C[:,1:] * self.voxel_size
    # 1. project the bottleneck feature
    src_feats, tgt_feats = self.bottle(src_feats), self.bottle(tgt_feats)
    # 2. apply GNN to communicate the features and get overlap scores
    src_feats, tgt_feats= self.gnn(src_pcd.transpose(0,1)[None,:], tgt_pcd.transpose(0,1)[None,:],src_feats, tgt_feats)
    # NOTE: tuple RHS is evaluated before assignment, so proj_score sees the
    # pre-proj_gnn features — this ordering is intentional; do not reorder.
    src_feats, src_scores = self.proj_gnn(src_feats), self.proj_score(src_feats)[0].transpose(0,1)
    tgt_feats, tgt_scores = self.proj_gnn(tgt_feats), self.proj_score(tgt_feats)[0].transpose(0,1)
    # 3. get cross-overlap scores (softmax-weighted scores of the other cloud)
    src_feats_norm = F.normalize(src_feats, p=2, dim=1)[0].transpose(0,1)
    tgt_feats_norm = F.normalize(tgt_feats, p=2, dim=1)[0].transpose(0,1)
    inner_products = torch.matmul(src_feats_norm, tgt_feats_norm.transpose(0,1))
    # Learnable temperature (self.epsilon) with a small floor.
    temperature = torch.exp(self.epsilon) + 0.03
    src_scores_x = torch.matmul(F.softmax(inner_products / temperature ,dim=1) ,tgt_scores)
    tgt_scores_x = torch.matmul(F.softmax(inner_products.transpose(0,1) / temperature,dim=1),src_scores)
    # 4. update sparse tensor: append [overlap, cross-overlap] score columns.
    src_feats = torch.cat([src_feats[0].transpose(0,1), src_scores, src_scores_x], dim=1)
    tgt_feats = torch.cat([tgt_feats[0].transpose(0,1), tgt_scores, tgt_scores_x], dim=1)
    src = ME.SparseTensor(src_feats, coordinate_map_key=src.coordinate_map_key, coordinate_manager=src.coordinate_manager)
    tgt = ME.SparseTensor(tgt_feats, coordinate_map_key=tgt.coordinate_map_key, coordinate_manager=tgt.coordinate_manager)
    ################################
    # decoder src (shared weights with the tgt decoder below)
    src = self.conv4_tr(src)
    src = self.norm4_tr(src)
    src = self.block4_tr(src)
    src_s4_tr = MEF.relu(src)
    src = ME.cat(src_s4_tr, src_s4)
    src = self.conv3_tr(src)
    src = self.norm3_tr(src)
    src = self.block3_tr(src)
    src_s2_tr = MEF.relu(src)
    src = ME.cat(src_s2_tr, src_s2)
    src = self.conv2_tr(src)
    src = self.norm2_tr(src)
    src = self.block2_tr(src)
    src_s1_tr = MEF.relu(src)
    src = ME.cat(src_s1_tr, src_s1)
    src = self.conv1_tr(src)
    src = MEF.relu(src)
    src = self.final(src)
    ################################
    # decoder tgt
    tgt = self.conv4_tr(tgt)
    tgt = self.norm4_tr(tgt)
    tgt = self.block4_tr(tgt)
    tgt_s4_tr = MEF.relu(tgt)
    tgt = ME.cat(tgt_s4_tr, tgt_s4)
    tgt = self.conv3_tr(tgt)
    tgt = self.norm3_tr(tgt)
    tgt = self.block3_tr(tgt)
    tgt_s2_tr = MEF.relu(tgt)
    tgt = ME.cat(tgt_s2_tr, tgt_s2)
    tgt = self.conv2_tr(tgt)
    tgt = self.norm2_tr(tgt)
    tgt = self.block2_tr(tgt)
    tgt_s1_tr = MEF.relu(tgt)
    tgt = ME.cat(tgt_s1_tr, tgt_s1)
    tgt = self.conv1_tr(tgt)
    tgt = MEF.relu(tgt)
    tgt = self.final(tgt)
    ################################
    # output features and scores: the last two channels are overlap/saliency
    # logits; squash them with sigmoid (clamp is redundant after sigmoid but
    # kept for safety).
    sigmoid = nn.Sigmoid()
    src_feats, src_overlap, src_saliency = src.F[:,:-2], src.F[:,-2], src.F[:,-1]
    tgt_feats, tgt_overlap, tgt_saliency = tgt.F[:,:-2], tgt.F[:,-2], tgt.F[:,-1]
    src_overlap= torch.clamp(sigmoid(src_overlap.view(-1)),min=0,max=1)
    src_saliency = torch.clamp(sigmoid(src_saliency.view(-1)),min=0,max=1)
    tgt_overlap = torch.clamp(sigmoid(tgt_overlap.view(-1)),min=0,max=1)
    tgt_saliency = torch.clamp(sigmoid(tgt_saliency.view(-1)),min=0,max=1)
    src_feats = F.normalize(src_feats, p=2, dim=1)
    tgt_feats = F.normalize(tgt_feats, p=2, dim=1)
    scores_overlap = torch.cat([src_overlap, tgt_overlap], dim=0)
    scores_saliency = torch.cat([src_saliency, tgt_saliency], dim=0)
    return src_feats, tgt_feats, scores_overlap, scores_saliency