def forward(self, x):
        """Apply two conv -> batch-norm -> ReLU stages, then pool and project.

        Args:
            x: Input passed to ``self.conv1``.

        Returns:
            The output of ``self.linear`` applied to the pooled activations.
        """
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(bn(conv(x)))
        return self.linear(self.pooling(x))
Exemple #2
0
    def forward(self, x):
        """Run the sparse U-Net and return per-point features and attention.

        Each encoder stage is conv -> norm -> block -> ReLU. The pre-ReLU
        stage outputs (``out_s1`` .. ``out_s8``) are kept as skip connections
        and concatenated with ``ME.cat`` onto the matching decoder stage. Two
        heads read the same final concatenation:
        ``conv1_tr``/``final`` produce the features, and
        ``conv1_tr_att``/``final_att``/``final_att_act`` produce the attention.

        Args:
            x: Input sparse tensor for ``self.conv1``.

        Returns:
            dict with ``"features"`` and ``"attention"`` entries. When
            ``self.normalize_feature`` is set, features are L2-normalised per
            row.
        """
        out_s1 = self.conv1(x)
        out_s1 = self.norm1(out_s1)
        out_s1 = self.block1(out_s1)
        out = MEF.relu(out_s1)

        out_s2 = self.conv2(out)
        out_s2 = self.norm2(out_s2)
        out_s2 = self.block2(out_s2)
        out = MEF.relu(out_s2)

        out_s4 = self.conv3(out)
        out_s4 = self.norm3(out_s4)
        out_s4 = self.block3(out_s4)
        out = MEF.relu(out_s4)

        out_s8 = self.conv4(out)
        out_s8 = self.norm4(out_s8)
        out_s8 = self.block4(out_s8)
        out = MEF.relu(out_s8)

        out = self.conv4_tr(out)
        out = self.norm4_tr(out)
        out = self.block4_tr(out)
        out_s4_tr = MEF.relu(out)

        out = ME.cat(out_s4_tr, out_s4)

        out = self.conv3_tr(out)
        out = self.norm3_tr(out)
        out = self.block3_tr(out)
        out_s2_tr = MEF.relu(out)

        out = ME.cat(out_s2_tr, out_s2)

        out = self.conv2_tr(out)
        out = self.norm2_tr(out)
        out = self.block2_tr(out)
        out_s1_tr = MEF.relu(out)

        # Both heads consume this same concatenation.
        out = ME.cat(out_s1_tr, out_s1)
        out_feat = self.conv1_tr(out)
        out_feat = MEF.relu(out_feat)
        out_feat = self.final(out_feat)

        out_att = self.conv1_tr_att(out)
        out_att = MEF.relu(out_att)
        out_att = self.final_att(out_att)
        # Apply the attention activation to the raw feature matrix and rewrap
        # it on the same coordinates.
        out_att = ME.SparseTensor(self.final_att_act(out_att.F),
                                  coords_key=out_att.coords_key,
                                  coords_manager=out_att.coords_man)

        if self.normalize_feature:
            # The epsilon keeps an all-zero feature row from turning into NaN
            # (0 / 0) during normalisation.
            norms = torch.norm(out_feat.F, p=2, dim=1, keepdim=True)
            out_feat = ME.SparseTensor(
                out_feat.F / (norms + 1e-8),
                coords_key=out_feat.coords_key,
                coords_manager=out_feat.coords_man)

        return dict(features=out_feat, attention=out_att)
Exemple #3
0
 def forward(self, x, x_skip):
     """Optionally concatenate a skip tensor, convolve, then finish the stage.

     Args:
         x: Input sparse tensor.
         x_skip: Tensor concatenated onto ``x`` with ``ME.cat`` before the
             convolution, or ``None`` to skip concatenation.

     Returns:
         If ``self.final`` is ``None``: ReLU(block(norm(conv(...)))).
         Otherwise: ``self.final`` applied to ReLU(conv(...)).
     """
     merged = x if x_skip is None else ME.cat(x, x_skip)
     features = self.conv(merged)
     if self.final is not None:
         return self.final(MEF.relu(features))
     return MEF.relu(self.block(self.norm(features)))
    def forward(self, flow):
        """Refine ``flow`` by adding a residual predicted from it.

        Four conv + ReLU layers feed ``self.final``. Its output is added back
        onto the input ``flow``.
        """
        hidden = flow
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = MEF.relu(layer(hidden))
        return flow + self.final(hidden)
def get_nonlinearity_fn(nonlinearity_type, input, *args, **kwargs):
  """Apply the MinkowskiEngine nonlinearity named by ``nonlinearity_type``.

  Args:
    nonlinearity_type: Key into ``str_to_nonlinearity_dict`` selecting the
      activation.
    input: First positional argument forwarded to the selected ``MEF``
      function.
    *args, **kwargs: Forwarded unchanged to the selected ``MEF`` function.

  Returns:
    Whatever the selected ``MEF`` function returns.

  Raises:
    KeyError: if ``nonlinearity_type`` is not a key of
      ``str_to_nonlinearity_dict``.
    ValueError: if the key maps to a type this function does not handle.
  """
  nonlinearity_type = str_to_nonlinearity_dict[nonlinearity_type]
  if nonlinearity_type == NonlinearityType.ReLU:
    return MEF.relu(input, *args, **kwargs)
  # This branch used to re-test ``ReLU``, which made leaky_relu unreachable.
  # NOTE(review): assumes the enum member is named ``LeakyReLU`` (matching
  # ReLU/PReLU/CELU/SELU naming); confirm against NonlinearityType.
  elif nonlinearity_type == NonlinearityType.LeakyReLU:
    return MEF.leaky_relu(input, *args, **kwargs)
  elif nonlinearity_type == NonlinearityType.PReLU:
    return MEF.prelu(input, *args, **kwargs)
  elif nonlinearity_type == NonlinearityType.CELU:
    return MEF.celu(input, *args, **kwargs)
  elif nonlinearity_type == NonlinearityType.SELU:
    return MEF.selu(input, *args, **kwargs)
  else:
    raise ValueError(f'Nonlinearity type: {nonlinearity_type} not supported')
Exemple #6
0
  def forward(self, x):
    """Encode ``x`` through five strided stages and decode back with skips.

    Each encoder stage is conv -> norm -> ReLU. The pre-ReLU outputs of the
    first four stages (``out_s1`` .. ``out_s8``) are concatenated with
    ``ME.cat`` onto the matching decoder stage. The fifth stage
    (``out_s16``) is the bottleneck and is not used as a skip.

    Args:
      x: Input sparse tensor for ``self.conv1``.

    Returns:
      Output of ``self.conv1_tr``, L2-normalised per row when
      ``self.normalize_feature`` is set.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out = MEF.relu(out_s8)

    out_s16 = self.conv5(out)
    out_s16 = self.norm5(out_s16)
    out = MEF.relu(out_s16)

    out = self.conv5_tr(out)
    out = self.norm5_tr(out)
    out_s8_tr = MEF.relu(out)

    out = ME.cat((out_s8_tr, out_s8))

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out_s4_tr = MEF.relu(out)

    out = ME.cat((out_s4_tr, out_s4))

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out_s2_tr = MEF.relu(out)

    out = ME.cat((out_s2_tr, out_s2))

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out_s1_tr = MEF.relu(out)

    out = ME.cat((out_s1_tr, out_s1))
    out = self.conv1_tr(out)

    if self.normalize_feature:
      # The epsilon keeps an all-zero feature row from becoming NaN (0 / 0).
      return ME.SparseTensor(
          out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8),
          coords_key=out.coords_key,
          coords_manager=out.coords_man)
    else:
      return out
  def forward(self, x):
    """Run the residual sparse U-Net, optionally L2-normalising the output.

    Each encoder stage is conv -> norm -> block -> ReLU. The pre-ReLU stage
    outputs are concatenated with ``ME.cat`` onto the matching decoder stage.
    The output sparse tensor is rebuilt with the constructor keywords of the
    installed MinkowskiEngine version.

    Args:
      x: Input sparse tensor for ``self.conv1``.

    Returns:
      Output of ``self.final``, L2-normalised per row when
      ``self.normalize_feature`` is set.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)

    out = ME.cat(out_s4_tr, out_s4)

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)

    out = ME.cat(out_s2_tr, out_s2)

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)

    out = ME.cat(out_s1_tr, out_s1)
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if not self.normalize_feature:
      return out

    # The epsilon keeps an all-zero feature row from becoming NaN (0 / 0).
    normalized = out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8)
    if ME.__version__.split(".")[1] == "4":
      # MinkowskiEngine 0.4.x constructor keywords.
      return ME.SparseTensor(
          normalized,
          coords_key=out.coords_key,
          coords_manager=out.coords_man)
    # 0.5.x constructor keywords. Previously, any minor version other than
    # "4" or "5" fell off the end of the function and returned None.
    return ME.SparseTensor(
        normalized,
        coordinate_map_key=out.coordinate_map_key,
        coordinate_manager=out.coordinate_manager)
    def forward(self, x, skip_features):
        """Decode ``x`` back up, concatenating encoder skips from deepest to shallowest.

        Args:
            x: Bottleneck sparse tensor.
            skip_features: Sequence of encoder outputs. The last three entries
                are used, ``[-1]`` first (after the deepest transposed stage)
                and ``[-3]`` last.

        Returns:
            Output of ``self.final`` applied to ReLU(``conv1_tr``(...)).
        """
        up = self.block4_tr(self.norm4_tr(self.conv4_tr(x)))
        up = self.block3_tr(self.norm3_tr(self.conv3_tr(ME.cat(up, skip_features[-1]))))
        up = self.block2_tr(self.norm2_tr(self.conv2_tr(ME.cat(up, skip_features[-2]))))
        up = ME.cat(up, skip_features[-3])
        return self.final(MEF.relu(self.conv1_tr(up)))
Exemple #9
0
    def forward(self, x):
        """Apply conv -> norm -> block, then a ReLU."""
        return MEF.relu(self.block(self.norm(self.conv(x))))
Exemple #10
0
    def forward(self, x):
        """Three-level encoder/decoder with pre-ReLU skip concatenation.

        ``block1``/``block2`` outputs (before their ReLU) are concatenated
        onto the decoder after ``block3_tr`` and ``block2_tr`` respectively.
        ``conv1_tr`` produces the result.
        """
        skip1 = self.block1(x)
        skip2 = self.block2(MF.relu(skip1))
        bottom = MF.relu(self.block3(MF.relu(skip2)))

        up = ME.cat(MF.relu(self.block3_tr(bottom)), skip2)
        up = ME.cat(MF.relu(self.block2_tr(up)), skip1)
        return self.conv1_tr(up)
Exemple #11
0
  def forward(self, x):
    """Run the residual sparse U-Net followed by an fc1 -> bn -> ReLU -> fc2 head.

    Each encoder stage is conv -> norm -> block -> ReLU. The pre-ReLU stage
    outputs are concatenated with ``ME.cat`` onto the matching decoder stage.
    The U-Net output is optionally L2-normalised per row before the head.

    Args:
      x: Input sparse tensor for ``self.conv1``.

    Returns:
      Output of ``self.fc2``.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)

    out = ME.cat((out_s4_tr, out_s4))

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)

    out = ME.cat((out_s2_tr, out_s2))

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)

    out = ME.cat((out_s1_tr, out_s1))
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if self.normalize_feature:
      # The epsilon keeps an all-zero feature row from becoming NaN (0 / 0).
      out = ME.SparseTensor(
          out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8),
          coords_key=out.coords_key,
          coords_manager=out.coords_man)

    out = self.fc1(out)
    # NOTE(review): if ``F`` is torch.nn.functional, this needs ``self.bn`` to
    # return a dense tensor. The fc1/bn module types are not visible here, so
    # confirm this.
    out = F.relu(self.bn(out))
    out = self.fc2(out)

    return out
Exemple #12
0
    def forward(self, x):
        """Three-level conv/bn encoder-decoder with pre-ReLU skip concatenation.

        The ``bn1`` and ``bn2`` outputs (before their ReLU) are concatenated
        onto the decoder after ``bn4`` and ``bn5``. ``conv6`` produces the
        result.
        """
        skip1 = self.bn1(self.conv1(x))
        skip2 = self.bn2(self.conv2(MF.relu(skip1)))
        deepest = MF.relu(self.bn3(self.conv3(MF.relu(skip2))))

        up = ME.cat((MF.relu(self.bn4(self.conv4(deepest))), skip2))
        up = ME.cat((MF.relu(self.bn5(self.conv5(up))), skip1))
        return self.conv6(up)
  def forward(self, x):
    """Residual block: two conv/norm stages plus a shortcut, ReLU after the sum.

    The shortcut is ``x`` itself, or ``self.downsample(x)`` when a
    downsample module is configured.
    """
    out = MEF.relu(self.norm1(self.conv1(x)))
    out = self.norm2(self.conv2(out))

    shortcut = x if self.downsample is None else self.downsample(x)
    out += shortcut
    return MEF.relu(out)
Exemple #14
0
  def forward(self, x):
    """Run a pooled residual sparse U-Net, optionally L2-normalising the output.

    Stage 1 is conv -> norm -> block. Stages 2-4 first apply ReLU and a
    pooling layer, then conv -> norm -> block. Each stage's output (before
    the next ReLU) is a skip that is concatenated onto the matching decoder
    stage.

    Args:
      x: Input sparse tensor for ``self.conv1``.

    Returns:
      Output of ``self.final``, row-wise L2-normalised (with a 1e-8 epsilon)
      when ``self.normalize_feature`` is set.
    """
    skip1 = self.block1(self.norm1(self.conv1(x)))
    skip2 = self.block2(self.norm2(self.conv2(self.pool2(MEF.relu(skip1)))))
    skip4 = self.block3(self.norm3(self.conv3(self.pool3(MEF.relu(skip2)))))
    skip8 = self.block4(self.norm4(self.conv4(self.pool4(MEF.relu(skip4)))))

    up = MEF.relu(self.block4_tr(self.norm4_tr(self.conv4_tr(MEF.relu(skip8)))))
    up = MEF.relu(self.block3_tr(self.norm3_tr(self.conv3_tr(ME.cat(up, skip4)))))
    up = MEF.relu(self.block2_tr(self.norm2_tr(self.conv2_tr(ME.cat(up, skip2)))))
    up = self.final(MEF.relu(self.conv1_tr(ME.cat(up, skip1))))

    if not self.normalize_feature:
      return up
    feats = up.F
    return ME.SparseTensor(
        feats / (torch.norm(feats, p=2, dim=1, keepdim=True) + 1e-8),
        coordinate_map_key=up.coordinate_map_key,
        coordinate_manager=up.coordinate_manager)
    def forward(self, x):
        """Segmentation head: seg_head_1 -> norm_1 -> ReLU -> seg_head_2."""
        hidden = MEF.relu(self.norm_1(self.seg_head_1(x)))
        return self.seg_head_2(hidden)
Exemple #16
0
  def forward(self, x):  # Receptive field size
    """Run the sparse U-Net with ReLU applied before each residual block.

    Every stage is conv -> norm -> ReLU -> block. Encoder stage outputs are
    concatenated onto the matching decoder stage. The trailing numbers
    record the receptive-field size after each stage, as annotated by the
    original author.

    Args:
      x: Input sparse tensor for ``self.conv1``.

    Returns:
      Output of ``self.final``, row-wise L2-normalised (with a 1e-8 epsilon)
      when ``self.normalize_feature`` is set.
    """
    skip1 = self.block1(MEF.relu(self.norm1(self.conv1(x))))  # 7
    skip2 = self.block2(MEF.relu(self.norm2(self.conv2(skip1))))  # 11 -> 19
    skip4 = self.block3(MEF.relu(self.norm3(self.conv3(skip2))))  # 27 -> 43
    skip8 = self.block4(MEF.relu(self.norm4(self.conv4(skip4))))  # 59 -> 91

    up = self.block4_tr(MEF.relu(self.norm4_tr(self.conv4_tr(skip8))))  # 99 -> 115
    up = self.block3_tr(
        MEF.relu(self.norm3_tr(self.conv3_tr(ME.cat(up, skip4)))))  # 119 -> 127
    up = self.block2_tr(
        MEF.relu(self.norm2_tr(self.conv2_tr(ME.cat(up, skip2)))))  # 129 -> 133
    up = self.final(MEF.relu(self.conv1_tr(ME.cat(up, skip1))))

    if not self.normalize_feature:
      return up
    feats = up.F
    return ME.SparseTensor(
        feats / (torch.norm(feats, p=2, dim=1, keepdim=True) + 1e-8),
        coords_key=up.coords_key,
        coords_manager=up.coords_man)
  def forward(self, x):
    """Run the residual sparse U-Net, max-pool all points, and apply fc1 + ReLU.

    Each encoder stage is conv -> norm -> block -> ReLU. The pre-ReLU stage
    outputs are concatenated onto the matching decoder stage. The decoder
    output's feature matrix is max-reduced over dim 0, passed through
    ``fc1`` and ReLU, and given a leading dimension of size 1.

    NOTE(review): the max over dim 0 of ``.F`` pools every point in the
    sparse tensor together. If a batch holds several samples, they are
    merged into one descriptor. Confirm callers only pass one sample.
    """
    skip1 = self.block1(self.norm1(self.conv1(x)))
    skip2 = self.block2(self.norm2(self.conv2(MEF.relu(skip1))))
    skip4 = self.block3(self.norm3(self.conv3(MEF.relu(skip2))))
    skip8 = self.block4(self.norm4(self.conv4(MEF.relu(skip4))))

    up = MEF.relu(self.block4_tr(self.norm4_tr(self.conv4_tr(MEF.relu(skip8)))))
    up = MEF.relu(self.block3_tr(self.norm3_tr(self.conv3_tr(ME.cat(up, skip4)))))
    up = MEF.relu(self.block2_tr(self.norm2_tr(self.conv2_tr(ME.cat(up, skip2)))))
    up = self.final(MEF.relu(self.conv1_tr(ME.cat(up, skip1))))

    # Global max pooling over points; torch.max(t, dim) returns (values, indices).
    pooled, _ = torch.max(up.F, 0)
    return F.relu(self.fc1(pooled)).unsqueeze(0)
Exemple #18
0
    def forward(self, x):
        """Encode ``x`` into local/global feature maps plus optional extras.

        Returns:
            dict that always holds ``"x"``, ``"lf"`` (local feature map) and
            ``"gf"`` (global feature map). It adds ``"z"``, ``"mu"`` and
            ``"logvar"`` when ``config.encoder_conditional_type`` is
            ``"gaussian"``. It adds only ``"z"`` when that setting is
            ``"deterministic"``. It adds ``"af"`` (L2-normalised attention
            features) when ``config.use_attention`` is set.
        """
        lf = self.lf(x)
        gf = self.gf(lf)
        outputs = {"x": x, "lf": lf, "gf": gf}

        def as_latent(features, pooled, device):
            # Rewrap pooled features on the pooled coordinates, with a stride
            # of 2 ** n_conv_layers along every spatial dimension.
            return ME.SparseTensor(
                features=features,
                coordinates=pooled.C,
                tensor_stride=torch.tensor(
                    [2**self.config.n_conv_layers] * self.spatial_dims,
                    device=device),
                coordinate_manager=pooled.coordinate_manager,
            )

        mode = self.config.encoder_conditional_type
        if mode == "gaussian":
            ef = self.global_pool(self.ef(gf))
            # The first half of the channels is mu, the second half logvar.
            mu, logvar = torch.chunk(ef.F, 2, dim=1)
            zf = self.reparameterize(mu, logvar)
            outputs["z"] = as_latent(zf, ef, zf.device)
            outputs["mu"] = mu
            outputs["logvar"] = logvar
        elif mode == "deterministic":
            ef = self.global_pool(self.ef(gf))
            outputs["z"] = as_latent(ef.F, ef, ef.device)

        if self.config.use_attention:
            outputs["af"] = F.normalize(self.af(gf), p=2)

        return outputs
Exemple #19
0
	def forward(self, stensor_src, stensor_tgt):
		"""Predict descriptors plus overlap and saliency scores for a src/tgt pair.

		Both inputs run through the same encoder modules. At the bottleneck,
		features are projected by ``self.bottle`` and exchanged through
		``self.gnn``. They are then extended with per-point scores
		(``proj_score``) and with cross-attended scores from the other cloud.
		The same decoder modules then upsample each cloud independently.

		Returns:
			(src_feats, tgt_feats, scores_overlap, scores_saliency):
			row-wise L2-normalised descriptors (all decoder channels except the
			last two) for each cloud, and the sigmoid of the second-to-last and
			last channels, concatenated src then tgt.
		"""
		def encode(stensor):
			# Shared encoder. Returns the ReLU'd bottleneck and the pre-ReLU skips.
			s1 = self.block1(self.norm1(self.conv1(stensor)))
			s2 = self.block2(self.norm2(self.conv2(MEF.relu(s1))))
			s4 = self.block3(self.norm3(self.conv3(MEF.relu(s2))))
			s8 = self.block4(self.norm4(self.conv4(MEF.relu(s4))))
			return MEF.relu(s8), (s1, s2, s4)

		def decode(stensor, skips):
			# Shared decoder. Concatenates the matching encoder skip at each scale.
			s1, s2, s4 = skips
			up = MEF.relu(self.block4_tr(self.norm4_tr(self.conv4_tr(stensor))))
			up = MEF.relu(self.block3_tr(self.norm3_tr(self.conv3_tr(ME.cat(up, s4)))))
			up = MEF.relu(self.block2_tr(self.norm2_tr(self.conv2_tr(ME.cat(up, s2)))))
			return self.final(MEF.relu(self.conv1_tr(ME.cat(up, s1))))

		src, src_skips = encode(stensor_src)
		tgt, tgt_skips = encode(stensor_tgt)

		# Overlap attention on dense [1, C, N] views of the bottleneck.
		# Empirically, when batch_size = 1, out.C[:, 1:] == out.coordinates_at(0).
		src_dense = src.F.transpose(0, 1)[None, :]
		tgt_dense = tgt.F.transpose(0, 1)[None, :]
		src_pcd = src.C[:, 1:] * self.voxel_size
		tgt_pcd = tgt.C[:, 1:] * self.voxel_size

		src_dense = self.bottle(src_dense)
		tgt_dense = self.bottle(tgt_dense)
		src_dense, tgt_dense = self.gnn(
			src_pcd.transpose(0, 1)[None, :],
			tgt_pcd.transpose(0, 1)[None, :],
			src_dense,
			tgt_dense)

		# proj_score reads the GNN output, not the proj_gnn output.
		src_proj = self.proj_gnn(src_dense)
		src_scores = self.proj_score(src_dense)[0].transpose(0, 1)
		tgt_proj = self.proj_gnn(tgt_dense)
		tgt_scores = self.proj_score(tgt_dense)[0].transpose(0, 1)

		# Cross-overlap: each point takes a softmax-weighted (temperature
		# exp(epsilon) + 0.03) average of the other cloud's scores, weighted by
		# cosine similarity.
		src_unit = F.normalize(src_proj, p=2, dim=1)[0].transpose(0, 1)
		tgt_unit = F.normalize(tgt_proj, p=2, dim=1)[0].transpose(0, 1)
		similarity = torch.matmul(src_unit, tgt_unit.transpose(0, 1))
		temperature = torch.exp(self.epsilon) + 0.03
		src_scores_x = torch.matmul(F.softmax(similarity / temperature, dim=1), tgt_scores)
		tgt_scores_x = torch.matmul(F.softmax(similarity.transpose(0, 1) / temperature, dim=1), src_scores)

		src_merged = torch.cat([src_proj[0].transpose(0, 1), src_scores, src_scores_x], dim=1)
		tgt_merged = torch.cat([tgt_proj[0].transpose(0, 1), tgt_scores, tgt_scores_x], dim=1)
		src = ME.SparseTensor(
			src_merged,
			coordinate_map_key=src.coordinate_map_key,
			coordinate_manager=src.coordinate_manager)
		tgt = ME.SparseTensor(
			tgt_merged,
			coordinate_map_key=tgt.coordinate_map_key,
			coordinate_manager=tgt.coordinate_manager)

		src = decode(src, src_skips)
		tgt = decode(tgt, tgt_skips)

		sigmoid = nn.Sigmoid()

		def split_head(stensor):
			# The second-to-last channel is overlap and the last is saliency;
			# the remaining channels are the descriptor.
			feats = stensor.F
			overlap = torch.clamp(sigmoid(feats[:, -2].view(-1)), min=0, max=1)
			saliency = torch.clamp(sigmoid(feats[:, -1].view(-1)), min=0, max=1)
			return feats[:, :-2], overlap, saliency

		src_desc, src_overlap, src_saliency = split_head(src)
		tgt_desc, tgt_overlap, tgt_saliency = split_head(tgt)

		src_desc = F.normalize(src_desc, p=2, dim=1)
		tgt_desc = F.normalize(tgt_desc, p=2, dim=1)

		scores_overlap = torch.cat([src_overlap, tgt_overlap], dim=0)
		scores_saliency = torch.cat([src_saliency, tgt_saliency], dim=0)

		return src_desc, tgt_desc, scores_overlap, scores_saliency