def forward(self, x, **kwargs):
    """U-Net style forward pass over a sparse tensor.

    Args:
        x: input ``ME.SparseTensor``.
        **kwargs: if ``mink`` is truthy, return the sparse tensor itself;
            otherwise return its feature matrix ``.F``.

    Returns:
        ``ME.SparseTensor`` or its feature tensor, depending on ``mink``.
    """
    # Encoder; the trailing numbers are example voxel counts per stage.
    x0 = self.stem(x)     # 30504
    x1 = self.stage1(x0)  # 8039
    x2 = self.stage2(x1)  # 2029
    x3 = self.stage3(x2)  # 489
    x4 = self.stage4(x3)  # 119

    # Decoder: upsample, concatenate the matching encoder skip, refine.
    y1 = self.up1[0](x4)
    y1 = ME.cat([y1, x3])
    y1 = self.up1[1](y1)

    y2 = self.up2[0](y1)
    y2 = ME.cat([y2, x2])
    y2 = self.up2[1](y2)

    y3 = self.up3[0](y2)
    y3 = ME.cat([y3, x1])
    y3 = self.up3[1](y3)

    y4 = self.up4[0](y3)
    y4 = ME.cat([y4, x0])
    y4 = self.up4[1](y4)

    out = self.classifier(y4)
    # kwargs.get replaces the two-step "'mink' in kwargs and kwargs['mink']"
    # test with identical semantics (absent or falsy -> features only).
    return out if kwargs.get('mink') else out.F
def forward(self, x, skip_features):
    """Decode ``x`` back toward full resolution, fusing encoder skips.

    ``skip_features`` is consumed from the end: ``[-1]`` at the coarsest
    decoder level, then ``[-2]``, then ``[-3]``.
    """
    h = self.block4_tr(self.norm4_tr(self.conv4_tr(x)))
    h = ME.cat(h, skip_features[-1])

    h = self.block3_tr(self.norm3_tr(self.conv3_tr(h)))
    h = ME.cat(h, skip_features[-2])

    h = self.block2_tr(self.norm2_tr(self.conv2_tr(h)))
    h = ME.cat(h, skip_features[-3])

    h = MEF.relu(self.conv1_tr(h))
    return self.final(h)
def forward(self, x, **kwargs):
    """Encoder-decoder forward; always returns the classifier's
    SparseTensor output (extra keyword args are accepted but unused)."""
    # Run the stem and the four encoder stages, keeping every output
    # for the decoder's skip connections.
    feats = [self.stem(x)]
    for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
        feats.append(stage(feats[-1]))
    x0, x1, x2, x3, x4 = feats

    # Each decoder level: upsample, concat the matching skip, refine.
    y = x4
    for up, skip in ((self.up1, x3), (self.up2, x2),
                     (self.up3, x1), (self.up4, x0)):
        y = up[0](y)
        y = ME.cat([y, skip])
        y = up[1](y)

    return self.classifier(y)
def forward(self, x: ME.TensorField):
    """Classify a TensorField: four conv/pool stages, multi-scale feature
    slicing back onto the original field, then global max+avg pooling.

    Returns the dense feature matrix (``.F``) of the final head.
    """
    x = self.mlp1(x)
    # Quantize the field into a sparse tensor for convolutions.
    y = x.sparse()
    y = self.conv1(y)
    y1 = self.pool(y)
    y = self.conv2(y1)
    y2 = self.pool(y)
    y = self.conv3(y2)
    y3 = self.pool(y)
    y = self.conv4(y3)
    y4 = self.pool(y)
    # Slice each pooled sparse tensor back onto the field's points,
    # producing per-point features at four receptive-field sizes.
    x1 = y1.slice(x)
    x2 = y2.slice(x)
    x3 = y3.slice(x)
    x4 = y4.slice(x)
    x = ME.cat(x1, x2, x3, x4)
    y = self.conv5(x.sparse())
    # Global pooling collapses each batch item to one feature vector;
    # max and average statistics are concatenated for the final head.
    x1 = self.global_max_pool(y)
    x2 = self.global_avg_pool(y)
    return self.final(ME.cat(x1, x2)).F
def forward(self, x):
    """ResUNet forward producing per-point features plus an attention map.

    Returns:
        dict with keys ``features`` (optionally L2-normalized sparse
        tensor) and ``attention`` (sparse tensor of activated scores).
    """
    # ---- encoder: stride 1 -> 2 -> 4 -> 8; skips kept pre-activation.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    # ---- decoder: transpose conv, then concat the matching skip tensor.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)

    # ---- feature head.
    out_feat = self.conv1_tr(out)
    out_feat = MEF.relu(out_feat)
    out_feat = self.final(out_feat)

    # ---- attention head; final_att_act squashes the raw scores.
    out_att = self.conv1_tr_att(out)
    out_att = MEF.relu(out_att)
    out_att = self.final_att(out_att)
    # NOTE(review): coords_key/coords_man is the MinkowskiEngine 0.4-era
    # SparseTensor API — confirm the pinned ME version.
    out_att = ME.SparseTensor(self.final_att_act(out_att.F),
                              coords_key=out_att.coords_key,
                              coords_manager=out_att.coords_man)

    if self.normalize_feature:
        # L2-normalize each feature row so descriptors lie on the unit sphere.
        out_feat = ME.SparseTensor(
            out_feat.F / torch.norm(out_feat.F, p=2, dim=1, keepdim=True),
            coords_key=out_feat.coords_key,
            coords_manager=out_feat.coords_man)
    return dict(features=out_feat, attention=out_att)
def forward(self, x):
    """MinkowskiNet U-Net forward: 4 strided encoder stages, 4 transposed
    decoder stages with skip concatenations, then the final head.

    Layer names encode geometry, e.g. ``conv2p2s2`` = conv at pixel-dist 2
    with stride 2; the inline comments track the tensor stride entering
    the next group.
    """
    out = self.conv0p1s1(x)
    out = self.bn0(out)
    out_p1 = self.relu(out)

    out = self.conv1p1s2(out_p1)
    out = self.bn1(out)
    out = self.relu(out)
    out_b1p2 = self.block1(out)

    out = self.conv2p2s2(out_b1p2)
    out = self.bn2(out)
    out = self.relu(out)
    out_b2p4 = self.block2(out)

    out = self.conv3p4s2(out_b2p4)
    out = self.bn3(out)
    out = self.relu(out)
    out_b3p8 = self.block3(out)

    # tensor_stride=16
    out = self.conv4p8s2(out_b3p8)
    out = self.bn4(out)
    out = self.relu(out)
    out = self.block4(out)

    # tensor_stride=8
    out = self.convtr4p16s2(out)
    out = self.bntr4(out)
    out = self.relu(out)
    out = ME.cat(out, out_b3p8)
    out = self.block5(out)

    # tensor_stride=4
    out = self.convtr5p8s2(out)
    out = self.bntr5(out)
    out = self.relu(out)
    out = ME.cat(out, out_b2p4)
    out = self.block6(out)

    # tensor_stride=2
    out = self.convtr6p4s2(out)
    out = self.bntr6(out)
    out = self.relu(out)
    out = ME.cat(out, out_b1p2)
    out = self.block7(out)

    # tensor_stride=1
    out = self.convtr7p2s2(out)
    out = self.bntr7(out)
    out = self.relu(out)
    out = ME.cat(out, out_p1)
    out = self.block8(out)

    return self.final(out)
def forward(self, x):
    """ResUNet feature extractor; optionally L2-normalizes the output.

    Handles both the MinkowskiEngine 0.4 and 0.5 SparseTensor keyword
    APIs when rebuilding the normalized tensor.
    """
    # ---- encoder: stride 1 -> 2 -> 4 -> 8; skips kept pre-activation.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    # ---- decoder with skip concatenations.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)

    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if self.normalize_feature:
        # Dispatch on the installed ME minor version: the SparseTensor
        # kwargs were renamed between 0.4 and 0.5.
        # NOTE(review): for any other minor version this branch falls
        # through and implicitly returns None — confirm this is intended.
        if ME.__version__.split(".")[1] == "5":
            return ME.SparseTensor(
                out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
                coordinate_map_key=out.coordinate_map_key,
                coordinate_manager=out.coordinate_manager)
        elif ME.__version__.split(".")[1] == "4":
            return ME.SparseTensor(
                out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
                coords_key=out.coords_key,
                coords_manager=out.coords_man)
    else:
        return out
def forward(self, x):
    """ResUNet backbone followed by a small FC head (fc1 -> BN+ReLU -> fc2)."""
    # ---- encoder: stride 1 -> 2 -> 4 -> 8; skips kept pre-activation.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    # ---- decoder with skip concatenations.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat((out_s4_tr, out_s4))

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat((out_s2_tr, out_s2))

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat((out_s1_tr, out_s1))

    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if self.normalize_feature:
        # L2-normalize each row; NOTE(review): coords_key/coords_man is
        # the MinkowskiEngine 0.4-era API.
        out = ME.SparseTensor(
            out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
            coords_key=out.coords_key,
            coords_manager=out.coords_man)

    # FC head applied on top of the (possibly normalized) features.
    out = self.fc1(out)
    out = F.relu(self.bn(out))
    out = self.fc2(out)
    return out
def forward(self, x):
    """ResUNet variant that downsamples with explicit pooling layers
    (pool2/pool3/pool4) before each stride's conv block.

    Optionally returns an L2-normalized sparse tensor (eps-stabilized),
    using the MinkowskiEngine 0.5 constructor keywords.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.pool2(out)
    out_s2 = self.conv2(out_s2)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.pool3(out)
    out_s4 = self.conv3(out_s4)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.pool4(out)
    out_s8 = self.conv4(out_s8)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    # ---- decoder with skip concatenations.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)

    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if self.normalize_feature:
        # eps (1e-8) guards against division by zero for all-zero rows.
        return ME.SparseTensor(
            out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8),
            coordinate_map_key=out.coordinate_map_key,
            coordinate_manager=out.coordinate_manager)
    else:
        return out
def forward(self, x):
    """Five-level hourglass without residual blocks: conv+norm+ReLU per
    level down to stride 16, then transposed convs with skip concats.

    The last transposed conv's raw output is returned (no final head);
    optionally L2-normalized via the 0.4-era SparseTensor API.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out = MEF.relu(out_s8)

    out_s16 = self.conv5(out)
    out_s16 = self.norm5(out_s16)
    out = MEF.relu(out_s16)

    # ---- decoder: each level upsamples and concats the matching skip.
    out = self.conv5_tr(out)
    out = self.norm5_tr(out)
    out_s8_tr = MEF.relu(out)
    out = ME.cat((out_s8_tr, out_s8))

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat((out_s4_tr, out_s4))

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat((out_s2_tr, out_s2))

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat((out_s1_tr, out_s1))

    out = self.conv1_tr(out)

    if self.normalize_feature:
        # NOTE(review): no eps here — an all-zero feature row would divide
        # by zero; other variants in this file add 1e-8.
        return ME.SparseTensor(
            out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
            coords_key=out.coords_key,
            coords_manager=out.coords_man)
    else:
        return out
def forward(self, x):
    """ResUNet forward; inline comments track the receptive field size
    in voxels after each layer (stride * kernel growth).

    Unlike the sibling variants, ReLU is applied *before* each residual
    block and the skip tensors are taken post-block.
    """
    # Receptive field size
    out_s1 = self.conv1(x)  # 7
    out_s1 = self.norm1(out_s1)
    out_s1 = MEF.relu(out_s1)
    out_s1 = self.block1(out_s1)

    out_s2 = self.conv2(out_s1)  # 7 + 2 * 2 = 11
    out_s2 = self.norm2(out_s2)
    out_s2 = MEF.relu(out_s2)
    out_s2 = self.block2(out_s2)  # 11 + 2 * (2 + 2) = 19

    out_s4 = self.conv3(out_s2)  # 19 + 4 * 2 = 27
    out_s4 = self.norm3(out_s4)
    out_s4 = MEF.relu(out_s4)
    out_s4 = self.block3(out_s4)  # 27 + 4 * (2 + 2) = 43

    out_s8 = self.conv4(out_s4)  # 43 + 8 * 2 = 59
    out_s8 = self.norm4(out_s8)
    out_s8 = MEF.relu(out_s8)
    out_s8 = self.block4(out_s8)  # 59 + 8 * (2 + 2) = 91

    out = self.conv4_tr(out_s8)  # 91 + 4 * 2 = 99
    out = self.norm4_tr(out)
    out = MEF.relu(out)
    out = self.block4_tr(out)  # 99 + 4 * (2 + 2) = 115
    out = ME.cat(out, out_s4)

    out = self.conv3_tr(out)  # 115 + 2 * 2 = 119
    out = self.norm3_tr(out)
    out = MEF.relu(out)
    out = self.block3_tr(out)  # 119 + 2 * (2 + 2) = 127
    out = ME.cat(out, out_s2)

    out = self.conv2_tr(out)  # 127 + 2 = 129
    out = self.norm2_tr(out)
    out = MEF.relu(out)
    out = self.block2_tr(out)  # 129 + 1 * (2 + 2) = 133
    out = ME.cat(out, out_s1)

    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if self.normalize_feature:
        # eps-stabilized L2 normalization (0.4-era SparseTensor API).
        return ME.SparseTensor(
            out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8),
            coords_key=out.coords_key,
            coords_manager=out.coords_man)
    else:
        return out
def forward(self, x):
    """ResUNet backbone followed by global max pooling over points and a
    single FC layer, producing one vector of shape (1, fc1_out)."""
    # ---- encoder: stride 1 -> 2 -> 4 -> 8; skips kept pre-activation.
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    # ---- decoder with skip concatenations.
    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)
    out = ME.cat(out_s4_tr, out_s4)

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)
    out = ME.cat(out_s2_tr, out_s2)

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)
    out = ME.cat(out_s1_tr, out_s1)

    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)
    ################################################################
    # Global max over all rows of the feature matrix, then the FC head.
    # NOTE(review): torch.max over dim 0 pools across the *entire* batch,
    # so this looks like it assumes batch size 1 — confirm with callers.
    out = torch.max(out.F, 0)[0]
    out = F.relu(self.fc1(out))
    # out = self.fc1(out)
    return out.unsqueeze(0)
def forward(self, x):
    """Recursive U-block: downsample, run the optional inner module,
    upsample, concatenate with the input, and fuse."""
    down = self.conv(x)
    if self.inner_module:
        down = self.inner_module(down)
    up = self.convtr(down)
    fused = ME.cat(x, up)
    return self.cat_conv(fused)
def forward(self, x):
    """Two-path residual block: a two-conv short path and a three-conv
    long path whose concatenation is added back onto the input."""
    short = self.conv0_0(x)
    short = self.conv0_1(self.relu(short))

    long_path = self.relu(self.conv1_0(x))
    long_path = self.relu(self.conv1_1(long_path))
    long_path = self.conv1_2(long_path)

    return ME.cat(short, long_path) + x
def forward(self, x):
    """Multi-path residual block: run every path on ``x``, concatenate
    the results, project with ``linear``, and add the residual branch.
    """
    residual = self.residual(x)
    # Generator expression avoids materializing an intermediate list
    # inside tuple(...) (flake8-comprehensions C409-style fix).
    out = ME.cat(tuple(layer(x) for layer in self.paths))
    out = self.linear(out)
    out += residual
    return out
def forward(self, x):
    """Compact three-level U-Net built from residual blocks; returns the
    raw output of the final transposed conv."""
    enc1 = self.block1(x)
    enc2 = self.block2(MF.relu(enc1))
    enc3 = self.block3(MF.relu(enc2))

    dec = MF.relu(self.block3_tr(MF.relu(enc3)))
    dec = ME.cat(dec, enc2)
    dec = MF.relu(self.block2_tr(dec))
    dec = ME.cat(dec, enc1)
    return self.conv1_tr(dec)
def forward(self, x):
    """Three-down / two-up hourglass; BN + ReLU after each conv, with
    skips taken pre-activation at strides 1 and 2."""
    skip1 = self.bn1(self.conv1(x))
    act = MF.relu(skip1)
    skip2 = self.bn2(self.conv2(act))
    act = MF.relu(skip2)
    act = MF.relu(self.bn3(self.conv3(act)))
    act = MF.relu(self.bn4(self.conv4(act)))
    act = ME.cat((act, skip2))
    act = MF.relu(self.bn5(self.conv5(act)))
    act = ME.cat((act, skip1))
    return self.conv6(act)
def forward(self, x, x_skip):
    """Run the parent block, add the upsampled residual onto its output,
    and optionally concatenate the skip tensor."""
    residual, out = super().forward(x)
    out = self.upsample(residual) + out
    return ME.cat(out, x_skip) if self.skip else out
def forward(self, x, x_skip=None):
    """Decoder step: optionally fuse a skip tensor, then transposed conv,
    batch norm, activation, and the residual block."""
    merged = x if x_skip is None else ME.cat(x, x_skip)
    merged = self.conv_tr(merged)
    merged = self.bn(merged)
    merged = self.activation(merged)
    return self.block(merged)
def forward(self, x, x_skip):
    """Conv step that either continues the trunk (norm + block) or, when
    a ``final`` head is configured, applies ReLU and the head instead."""
    inp = x if x_skip is None else ME.cat(x, x_skip)
    y = self.conv(inp)
    if self.final is None:
        return self.block(self.norm(y))
    return self.final(MEF.relu(y))
def forward(self, input):
    """Pyramid pooling: pool at each scale and unpool back to the input
    resolution. The first pooler is global pooling, whose unpool needs
    the original tensor to broadcast onto; the rest unpool directly.
    The branches are concatenated and fused by a linear layer."""
    branches = []
    for idx, pool in enumerate(self.pool):
        pooled = pool(input)
        unpool = self.unpool[idx]
        restored = unpool(input, pooled) if idx == 0 else unpool(pooled)
        branches.append(restored)
    return self.linear(ME.cat(branches))
def sparse_tensor_arithmetics():
    """Demo: elementwise arithmetic between MinkowskiEngine SparseTensors.

    Shows that (a) two tensors must share a coordinate manager before
    they can be combined, and (b) in-place ops additionally require the
    same coordinate map key.
    """
    coords0, feats0 = to_sparse_coo(data_batch_0)
    coords0, feats0 = ME.utils.sparse_collate(coords=[coords0], feats=[feats0])
    coords1, feats1 = to_sparse_coo(data_batch_1)
    coords1, feats1 = ME.utils.sparse_collate(coords=[coords1], feats=[feats1])

    # sparse tensors
    A = ME.SparseTensor(coordinates=coords0, features=feats0)
    B = ME.SparseTensor(coordinates=coords1, features=feats1)

    # The following fails: A and B were built with different coordinate
    # managers, so they cannot be combined directly.
    try:
        C = A + B
    except AssertionError:
        pass

    B = ME.SparseTensor(
        coordinates=coords1,
        features=feats1,
        coordinate_manager=A.coordinate_manager  # must share the same coordinate manager
    )

    C = A + B
    C = A - B
    C = A * B
    C = A / B

    # in place operations
    # Note that it requires the same coords_key (no need to feed coords)
    D = ME.SparseTensor(
        # coords=coords, not required
        features=feats0,
        coordinate_manager=A.coordinate_manager,  # must share the same coordinate manager
        coordinate_map_key=A.coordinate_map_key  # For inplace, must share the same coords key
    )

    A += D
    A -= D
    A *= D
    A /= D

    # If you have two or more sparse tensors with the same coords_key, you can concatenate features
    E = ME.cat(A, D)
def forward(self, input):
    """Split the dense input into coordinates and features, run
    encoder -> decoder -> SPP, and collect all intermediate tensors.

    Args:
        input: (N, D+1+F) tensor; the first D+1 columns are the batch
            index plus D spatial coordinates, the rest are features.

    Returns:
        dict with 'encoderTensors', 'decoderTensors' and 'finalFeatures'.
    """
    coords = input[:, 0:self.D + 1].cpu().int()
    features = input[:, self.D + 1:].float()
    x = ME.SparseTensor(features, coords=coords)
    encoderOutput = self.encoder(x)
    decoderTensors = self.decoder(encoderOutput['finalTensor'],
                                  encoderOutput['encoderTensors'])
    sppFeatures = self.spp(decoderTensors[-1])
    # Append spatial-pyramid-pooled context to the finest decoder output.
    finalFeatures = ME.cat((decoderTensors[-1], sppFeatures))
    res = {
        'encoderTensors': encoderOutput['encoderTensors'],
        'decoderTensors': decoderTensors,
        'finalFeatures': finalFeatures
    }
    return res
def decoder(self, final, encoderTensors):
    '''
    Vanilla UResNet Decoder

    INPUTS:
        - final (SparseTensor): encoder output at the coarsest resolution.
        - encoderTensors (list of SparseTensor): output of encoder.

    RETURNS:
        - decoderTensors (list of SparseTensor): list of feature tensors
          in decoding path at each spatial resolution.
    '''
    decoderTensors = []
    x = final
    for i, layer in enumerate(self.decoding_conv):
        # -i-2 walks the encoder outputs from second-coarsest to finest
        # (the coarsest one is already supplied as `final`).
        eTensor = encoderTensors[-i - 2]
        x = layer(x)
        x = ME.cat((eTensor, x))
        x = self.decoding_block[i](x)
        decoderTensors.append(x)
    return decoderTensors
def forward(self, input, lang_feat):
    """Encoder-decoder forward with language features fused at the
    bottleneck.

    Args:
        input: input sparse tensor.
        lang_feat: per-batch language embedding matrix, indexed by the
            batch column of the bottleneck coordinates.
    """
    x1 = self.en1(input)
    x2 = self.en2(x1)
    x3 = self.en3(x2)
    cm = x3.coordinate_manager
    # assumes the bottleneck lives at tensor stride 8 — TODO confirm
    # against the encoder's stride configuration.
    coords = cm.get_coordinates(8)
    batch_coords = coords[:, 0]
    # Broadcast each batch item's language vector onto its voxels, then
    # wrap it as a sparse tensor aligned with x3's coordinates.
    lang_feat_cast = lang_feat[batch_coords.long(), :]
    lang_sparse = ME.SparseTensor(features=lang_feat_cast,
                                  coordinate_map_key=x3.coordinate_map_key,
                                  coordinate_manager=cm)
    x3 = ME.cat(x3, lang_sparse)
    x3 = self.fuse_layer(x3)
    # Decoder consumes the fused bottleneck plus encoder skips.
    d3 = self.de3(x2, x3)
    d2 = self.de2(x1, d3)
    d1 = self.de1(input, d2)
    net = self.final_layer(d1)
    return net
def forward(self, stensor_src, stensor_tgt):
    """Joint forward over a source/target pair of sparse tensors.

    Both clouds are encoded with shared weights, exchange information in
    a GNN-based overlap-attention module at the bottleneck, and are then
    decoded (again with shared weights) into per-point descriptors plus
    per-point overlap and saliency scores.

    Returns:
        (src_feats, tgt_feats, scores_overlap, scores_saliency) where the
        feature matrices are L2-normalized and the score vectors are the
        source and target scores concatenated along dim 0.
    """
    ################################
    # encode src
    src_s1 = self.conv1(stensor_src)
    src_s1 = self.norm1(src_s1)
    src_s1 = self.block1(src_s1)
    src = MEF.relu(src_s1)

    src_s2 = self.conv2(src)
    src_s2 = self.norm2(src_s2)
    src_s2 = self.block2(src_s2)
    src = MEF.relu(src_s2)

    src_s4 = self.conv3(src)
    src_s4 = self.norm3(src_s4)
    src_s4 = self.block3(src_s4)
    src = MEF.relu(src_s4)

    src_s8 = self.conv4(src)
    src_s8 = self.norm4(src_s8)
    src_s8 = self.block4(src_s8)
    src = MEF.relu(src_s8)

    ################################
    # encode tgt (same layers/weights as the src pass)
    tgt_s1 = self.conv1(stensor_tgt)
    tgt_s1 = self.norm1(tgt_s1)
    tgt_s1 = self.block1(tgt_s1)
    tgt = MEF.relu(tgt_s1)

    tgt_s2 = self.conv2(tgt)
    tgt_s2 = self.norm2(tgt_s2)
    tgt_s2 = self.block2(tgt_s2)
    tgt = MEF.relu(tgt_s2)

    tgt_s4 = self.conv3(tgt)
    tgt_s4 = self.norm3(tgt_s4)
    tgt_s4 = self.block3(tgt_s4)
    tgt = MEF.relu(tgt_s4)

    tgt_s8 = self.conv4(tgt)
    tgt_s8 = self.norm4(tgt_s8)
    tgt_s8 = self.block4(tgt_s8)
    tgt = MEF.relu(tgt_s8)

    ################################
    # overlap attention module
    # empirically, when batch_size = 1, out.C[:,1:] == out.coordinates_at(0)
    src_feats = src.F.transpose(0,1)[None,:]  #[1, C, N]
    tgt_feats = tgt.F.transpose(0,1)[None,:]  #[1, C, N]
    # Recover metric point positions from voxel coordinates.
    src_pcd, tgt_pcd = src.C[:,1:] * self.voxel_size, tgt.C[:,1:] * self.voxel_size

    # 1. project the bottleneck feature
    src_feats, tgt_feats = self.bottle(src_feats), self.bottle(tgt_feats)

    # 2. apply GNN to communicate the features and get overlap scores
    src_feats, tgt_feats = self.gnn(src_pcd.transpose(0,1)[None,:],
                                    tgt_pcd.transpose(0,1)[None,:],
                                    src_feats, tgt_feats)
    # Both RHS expressions read the pre-assignment src_feats/tgt_feats.
    src_feats, src_scores = self.proj_gnn(src_feats), self.proj_score(src_feats)[0].transpose(0,1)
    tgt_feats, tgt_scores = self.proj_gnn(tgt_feats), self.proj_score(tgt_feats)[0].transpose(0,1)

    # 3. get cross-overlap scores
    src_feats_norm = F.normalize(src_feats, p=2, dim=1)[0].transpose(0,1)
    tgt_feats_norm = F.normalize(tgt_feats, p=2, dim=1)[0].transpose(0,1)
    inner_products = torch.matmul(src_feats_norm, tgt_feats_norm.transpose(0,1))
    # Learned temperature (self.epsilon) with a small additive floor.
    temperature = torch.exp(self.epsilon) + 0.03
    src_scores_x = torch.matmul(F.softmax(inner_products / temperature, dim=1), tgt_scores)
    tgt_scores_x = torch.matmul(F.softmax(inner_products.transpose(0,1) / temperature, dim=1), src_scores)

    # 4. update sparse tensor: append own + cross scores to the features.
    src_feats = torch.cat([src_feats[0].transpose(0,1), src_scores, src_scores_x], dim=1)
    tgt_feats = torch.cat([tgt_feats[0].transpose(0,1), tgt_scores, tgt_scores_x], dim=1)
    src = ME.SparseTensor(src_feats,
                          coordinate_map_key=src.coordinate_map_key,
                          coordinate_manager=src.coordinate_manager)
    tgt = ME.SparseTensor(tgt_feats,
                          coordinate_map_key=tgt.coordinate_map_key,
                          coordinate_manager=tgt.coordinate_manager)

    ################################
    # decoder src
    src = self.conv4_tr(src)
    src = self.norm4_tr(src)
    src = self.block4_tr(src)
    src_s4_tr = MEF.relu(src)
    src = ME.cat(src_s4_tr, src_s4)

    src = self.conv3_tr(src)
    src = self.norm3_tr(src)
    src = self.block3_tr(src)
    src_s2_tr = MEF.relu(src)
    src = ME.cat(src_s2_tr, src_s2)

    src = self.conv2_tr(src)
    src = self.norm2_tr(src)
    src = self.block2_tr(src)
    src_s1_tr = MEF.relu(src)
    src = ME.cat(src_s1_tr, src_s1)

    src = self.conv1_tr(src)
    src = MEF.relu(src)
    src = self.final(src)

    ################################
    # decoder tgt (same layers/weights as the src pass)
    tgt = self.conv4_tr(tgt)
    tgt = self.norm4_tr(tgt)
    tgt = self.block4_tr(tgt)
    tgt_s4_tr = MEF.relu(tgt)
    tgt = ME.cat(tgt_s4_tr, tgt_s4)

    tgt = self.conv3_tr(tgt)
    tgt = self.norm3_tr(tgt)
    tgt = self.block3_tr(tgt)
    tgt_s2_tr = MEF.relu(tgt)
    tgt = ME.cat(tgt_s2_tr, tgt_s2)

    tgt = self.conv2_tr(tgt)
    tgt = self.norm2_tr(tgt)
    tgt = self.block2_tr(tgt)
    tgt_s1_tr = MEF.relu(tgt)
    tgt = ME.cat(tgt_s1_tr, tgt_s1)

    tgt = self.conv1_tr(tgt)
    tgt = MEF.relu(tgt)
    tgt = self.final(tgt)

    ################################
    # output features and scores
    # Last two feature columns carry the overlap and saliency logits.
    sigmoid = nn.Sigmoid()
    src_feats, src_overlap, src_saliency = src.F[:,:-2], src.F[:,-2], src.F[:,-1]
    tgt_feats, tgt_overlap, tgt_saliency = tgt.F[:,:-2], tgt.F[:,-2], tgt.F[:,-1]
    src_overlap = torch.clamp(sigmoid(src_overlap.view(-1)), min=0, max=1)
    src_saliency = torch.clamp(sigmoid(src_saliency.view(-1)), min=0, max=1)
    tgt_overlap = torch.clamp(sigmoid(tgt_overlap.view(-1)), min=0, max=1)
    tgt_saliency = torch.clamp(sigmoid(tgt_saliency.view(-1)), min=0, max=1)
    src_feats = F.normalize(src_feats, p=2, dim=1)
    tgt_feats = F.normalize(tgt_feats, p=2, dim=1)

    scores_overlap = torch.cat([src_overlap, tgt_overlap], dim=0)
    scores_saliency = torch.cat([src_saliency, tgt_saliency], dim=0)

    return src_feats, tgt_feats, scores_overlap, scores_saliency
def forward(self, x, skip):
    """Concatenate the skip tensor onto ``x`` when present, then delegate
    to the parent forward."""
    merged = x if skip is None else ME.cat(x, skip)
    return super().forward(merged)
def forward(self, x_a, x_b):
    """Project branch A, concatenate with branch B, and fuse the pair."""
    a_proj = self.conv_a(x_a)
    fused = ME.cat(a_proj, x_b)
    return self.conv_proj(fused)
def cat(*tensors):
    """Thin alias that forwards all positional tensors to ``ME.cat``."""
    return ME.cat(*tensors)
def forward(self, x, **kwargs):
    """MinkowskiNet U-Net forward (4 down, 4 up with skip concats).

    Args:
        x: input ``ME.SparseTensor``.
        **kwargs: if ``mink`` is truthy, return the sparse tensor itself;
            otherwise return its feature matrix ``.F``.
    """
    out = self.conv0p1s1(x)
    out = self.bn0(out)
    out_p1 = self.relu(out)

    out = self.conv1p1s2(out_p1)
    out = self.bn1(out)
    out = self.relu(out)
    out_b1p2 = self.block1(out)

    out = self.conv2p2s2(out_b1p2)
    out = self.bn2(out)
    out = self.relu(out)
    out_b2p4 = self.block2(out)

    out = self.conv3p4s2(out_b2p4)
    out = self.bn3(out)
    out = self.relu(out)
    out_b3p8 = self.block3(out)

    # tensor_stride=16
    out = self.conv4p8s2(out_b3p8)
    out = self.bn4(out)
    out = self.relu(out)
    out = self.block4(out)

    # tensor_stride=8
    out = self.convtr4p16s2(out)
    out = self.bntr4(out)
    out = self.relu(out)
    out = ME.cat((out, out_b3p8))
    out = self.block5(out)

    # tensor_stride=4
    out = self.convtr5p8s2(out)
    out = self.bntr5(out)
    out = self.relu(out)
    out = ME.cat((out, out_b2p4))
    out = self.block6(out)

    # tensor_stride=2
    out = self.convtr6p4s2(out)
    out = self.bntr6(out)
    out = self.relu(out)
    out = ME.cat((out, out_b1p2))
    out = self.block7(out)

    # tensor_stride=1
    out = self.convtr7p2s2(out)
    out = self.bntr7(out)
    out = self.relu(out)
    out = ME.cat((out, out_p1))
    out = self.block8(out)

    out = self.final(out)
    # kwargs.get replaces the two-step "'mink' in kwargs and kwargs['mink']"
    # test with identical semantics (absent or falsy -> features only);
    # matches the sibling forward that takes the same keyword.
    return out if kwargs.get('mink') else out.F