Ejemplo n.º 1
0
    def forward(self, x, **kwargs):
        """Run the encoder/decoder network and classify every point.

        ``self.stem`` is followed by four encoder stages.  Each decoder step
        ``upN`` is indexable: ``upN[0]`` upsamples, the result is concatenated
        with an encoder output via ``ME.cat``, and ``upN[1]`` fuses the pair
        (``up1`` joins the stage3 output, ``up4`` the stem output).

        Returns the classifier's sparse tensor when ``mink=True`` is passed,
        otherwise only its feature matrix ``.F``.
        """
        e0 = self.stem(x)     # 30504
        e1 = self.stage1(e0)  # 8039
        e2 = self.stage2(e1)  # 2029
        e3 = self.stage3(e2)  # 489
        e4 = self.stage4(e3)  # 119

        # Walk back up, joining the deepest remaining encoder output each time.
        decoded = e4
        for up, skip in ((self.up1, e3), (self.up2, e2), (self.up3, e1), (self.up4, e0)):
            decoded = up[1](ME.cat([up[0](decoded), skip]))

        logits = self.classifier(decoded)
        return logits if kwargs.get('mink') else logits.F
Ejemplo n.º 2
0
    def forward(self, x, skip_features):
        """Decode ``x``, concatenating encoder skip tensors after each stage.

        ``skip_features`` is consumed from the end.  ``[-1]`` is joined after the
        ``*4_tr`` stage, ``[-2]`` after ``*3_tr`` and ``[-3]`` after ``*2_tr``.
        ``conv1_tr``, a ReLU and ``final`` then produce the output.
        """
        up_stages = (
            (self.conv4_tr, self.norm4_tr, self.block4_tr),
            (self.conv3_tr, self.norm3_tr, self.block3_tr),
            (self.conv2_tr, self.norm2_tr, self.block2_tr),
        )
        out = x
        for depth, (conv, norm, block) in enumerate(up_stages, start=1):
            out = block(norm(conv(out)))
            out = ME.cat(out, skip_features[-depth])
        return self.final(MEF.relu(self.conv1_tr(out)))
Ejemplo n.º 3
0
    def forward(self, x, **kargs):
        """Encoder/decoder pass with skip concatenation; returns the classifier output.

        Extra keyword arguments are accepted and ignored.
        """
        skips = []
        feat = self.stem(x)
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            skips.append(feat)
            feat = stage(feat)

        # skips holds the stem..stage3 outputs; pair them deepest-first with up1..up4.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            upsample, fuse = up[0], up[1]
            feat = fuse(ME.cat([upsample(feat), skip]))

        return self.classifier(feat)
Ejemplo n.º 4
0
    def forward(self, x: ME.TensorField):
        """Multi-scale point feature extraction followed by global pooling.

        ``mlp1`` is applied to the field.  Four conv+pool levels run on its
        sparse version, and every pooled level is sliced back onto the field
        and concatenated.  ``conv5`` follows, and the global max- and
        average-pooled results are concatenated and passed to ``final``.
        Returns the feature matrix ``.F`` of that output.
        """
        field = self.mlp1(x)
        level = field.sparse()

        pooled_levels = []
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            level = self.pool(conv(level))
            pooled_levels.append(level)

        fused = ME.cat(*[lvl.slice(field) for lvl in pooled_levels])

        dense = self.conv5(fused.sparse())
        global_feats = ME.cat(self.global_max_pool(dense), self.global_avg_pool(dense))
        return self.final(global_feats).F
Ejemplo n.º 5
0
    def forward(self, x):
        """Run the sparse U-Net with separate feature and attention heads.

        Returns:
            ``dict(features=..., attention=...)``.  ``features`` is the output
            of ``self.final``, row-wise L2-normalised when
            ``self.normalize_feature`` is set.  ``attention`` is the output of
            ``self.final_att`` with ``self.final_att_act`` applied to its
            features.  Both heads read the same concatenated decoder tensor.
        """
        out_s1 = self.conv1(x)
        out_s1 = self.norm1(out_s1)
        out_s1 = self.block1(out_s1)
        out = MEF.relu(out_s1)

        out_s2 = self.conv2(out)
        out_s2 = self.norm2(out_s2)
        out_s2 = self.block2(out_s2)
        out = MEF.relu(out_s2)

        out_s4 = self.conv3(out)
        out_s4 = self.norm3(out_s4)
        out_s4 = self.block3(out_s4)
        out = MEF.relu(out_s4)

        out_s8 = self.conv4(out)
        out_s8 = self.norm4(out_s8)
        out_s8 = self.block4(out_s8)
        out = MEF.relu(out_s8)

        out = self.conv4_tr(out)
        out = self.norm4_tr(out)
        out = self.block4_tr(out)
        out_s4_tr = MEF.relu(out)

        out = ME.cat(out_s4_tr, out_s4)

        out = self.conv3_tr(out)
        out = self.norm3_tr(out)
        out = self.block3_tr(out)
        out_s2_tr = MEF.relu(out)

        out = ME.cat(out_s2_tr, out_s2)

        out = self.conv2_tr(out)
        out = self.norm2_tr(out)
        out = self.block2_tr(out)
        out_s1_tr = MEF.relu(out)

        # Both heads below consume this same concatenated tensor.
        out = ME.cat(out_s1_tr, out_s1)

        out_feat = self.conv1_tr(out)
        out_feat = MEF.relu(out_feat)
        out_feat = self.final(out_feat)

        out_att = self.conv1_tr_att(out)
        out_att = MEF.relu(out_att)
        out_att = self.final_att(out_att)
        out_att = ME.SparseTensor(self.final_att_act(out_att.F),
                                  coords_key=out_att.coords_key,
                                  coords_manager=out_att.coords_man)

        if self.normalize_feature:
            # Clamp the norm so an all-zero feature row yields zeros instead of
            # NaN (0 / 0).  Rows whose norm exceeds the floor are divided
            # exactly as before.
            row_norm = torch.norm(out_feat.F, p=2, dim=1, keepdim=True).clamp(min=1e-8)
            out_feat = ME.SparseTensor(
                out_feat.F / row_norm,
                coords_key=out_feat.coords_key,
                coords_manager=out_feat.coords_man)

        return dict(features=out_feat, attention=out_att)
Ejemplo n.º 6
0
    def forward(self, x):
        """Residual sparse U-Net forward pass; returns ``self.final`` of the decoder output.

        Each encoder stage is conv -> BN -> ReLU -> block.  Each decoder stage is
        transposed conv -> BN -> ReLU, then concatenation with the encoder output
        of the matching level, then a block.
        """
        feat = self.relu(self.bn0(self.conv0p1s1(x)))
        skips = [feat]

        encoder = (
            (self.conv1p1s2, self.bn1, self.block1),
            (self.conv2p2s2, self.bn2, self.block2),
            (self.conv3p4s2, self.bn3, self.block3),
            (self.conv4p8s2, self.bn4, self.block4),  # tensor_stride=16
        )
        for conv, bn, block in encoder:
            feat = block(self.relu(bn(conv(feat))))
            skips.append(feat)
        # The deepest (tensor_stride=16) output is the decoder input, not a skip.
        skips.pop()

        decoder = (
            (self.convtr4p16s2, self.bntr4, self.block5),  # tensor_stride=8
            (self.convtr5p8s2, self.bntr5, self.block6),   # tensor_stride=4
            (self.convtr6p4s2, self.bntr6, self.block7),   # tensor_stride=2
            (self.convtr7p2s2, self.bntr7, self.block8),   # tensor_stride=1
        )
        for (convtr, bn, block), skip in zip(decoder, reversed(skips)):
            feat = block(ME.cat(self.relu(bn(convtr(feat))), skip))

        return self.final(feat)
Ejemplo n.º 7
0
  def forward(self, x):
    """Run the sparse U-Net and return its output sparse tensor.

    When ``self.normalize_feature`` is set, each feature row is divided by its
    L2 norm and the result is rebuilt on the output's coordinates.  The
    constructor keywords are chosen to match the installed MinkowskiEngine API.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)

    out = ME.cat(out_s4_tr, out_s4)

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)

    out = ME.cat(out_s2_tr, out_s2)

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)

    out = ME.cat(out_s1_tr, out_s1)
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if not self.normalize_feature:
      return out

    # Clamp the norm so an all-zero feature row yields zeros instead of NaN
    # (0 / 0).  Rows whose norm exceeds the floor are divided exactly as before.
    row_norm = torch.norm(out.F, p=2, dim=1, keepdim=True).clamp(min=1e-8)
    feats = out.F / row_norm

    # Pick the constructor API by feature detection.  The previous check on
    # the minor version string fell through and silently returned None for
    # any version whose minor number was neither "4" nor "5".
    if hasattr(out, "coordinate_map_key"):  # MinkowskiEngine 0.5-style API
      return ME.SparseTensor(
          feats,
          coordinate_map_key=out.coordinate_map_key,
          coordinate_manager=out.coordinate_manager)
    return ME.SparseTensor(  # MinkowskiEngine 0.4-style API
        feats,
        coords_key=out.coords_key,
        coords_manager=out.coords_man)
Ejemplo n.º 8
0
  def forward(self, x):
    """Run the sparse U-Net, then a ``fc1 -> bn -> ReLU -> fc2`` head.

    When ``self.normalize_feature`` is set, the U-Net output features are
    L2-normalised row-wise before the head.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out_s1 = self.block1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out_s2 = self.block2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out_s4 = self.block3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out_s8 = self.block4(out_s8)
    out = MEF.relu(out_s8)

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out = self.block4_tr(out)
    out_s4_tr = MEF.relu(out)

    out = ME.cat((out_s4_tr, out_s4))

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out = self.block3_tr(out)
    out_s2_tr = MEF.relu(out)

    out = ME.cat((out_s2_tr, out_s2))

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out = self.block2_tr(out)
    out_s1_tr = MEF.relu(out)

    out = ME.cat((out_s1_tr, out_s1))
    out = self.conv1_tr(out)
    out = MEF.relu(out)
    out = self.final(out)

    if self.normalize_feature:
      # Clamp the norm so an all-zero feature row yields zeros instead of NaN
      # (0 / 0).  Rows whose norm exceeds the floor are divided exactly as before.
      row_norm = torch.norm(out.F, p=2, dim=1, keepdim=True).clamp(min=1e-8)
      out = ME.SparseTensor(
          out.F / row_norm,
          coords_key=out.coords_key,
          coords_manager=out.coords_man)

    out = self.fc1(out)
    out = F.relu(self.bn(out))
    out = self.fc2(out)

    return out
Ejemplo n.º 9
0
  def forward(self, x):
    """Sparse U-Net with explicit pooling before encoder stages 2-4.

    Returns the output of ``self.final``.  When ``self.normalize_feature`` is
    set, each feature row is divided by (its L2 norm + 1e-8).
    """
    s1 = self.block1(self.norm1(self.conv1(x)))
    skips = [s1]
    h = MEF.relu(s1)

    down = (
        (self.pool2, self.conv2, self.norm2, self.block2),
        (self.pool3, self.conv3, self.norm3, self.block3),
        (self.pool4, self.conv4, self.norm4, self.block4),
    )
    for pool, conv, norm, block in down:
      s = block(norm(conv(pool(h))))
      skips.append(s)
      h = MEF.relu(s)

    up = (
        (self.conv4_tr, self.norm4_tr, self.block4_tr),
        (self.conv3_tr, self.norm3_tr, self.block3_tr),
        (self.conv2_tr, self.norm2_tr, self.block2_tr),
    )
    # Join each decoder stage with a pre-ReLU encoder output, deepest first.
    # The bottleneck output is not reused as a skip.
    for (conv, norm, block), skip in zip(up, reversed(skips[:-1])):
      h = ME.cat(MEF.relu(block(norm(conv(h)))), skip)

    h = self.final(MEF.relu(self.conv1_tr(h)))

    if not self.normalize_feature:
      return h
    denom = torch.norm(h.F, p=2, dim=1, keepdim=True) + 1e-8
    return ME.SparseTensor(
        h.F / denom,
        coordinate_map_key=h.coordinate_map_key,
        coordinate_manager=h.coordinate_manager)
Ejemplo n.º 10
0
  def forward(self, x):
    """Five-level sparse U-Net (conv -> norm -> ReLU per level, no residual blocks).

    Returns the output of ``conv1_tr``.  When ``self.normalize_feature`` is
    set, each feature row is L2-normalised.
    """
    out_s1 = self.conv1(x)
    out_s1 = self.norm1(out_s1)
    out = MEF.relu(out_s1)

    out_s2 = self.conv2(out)
    out_s2 = self.norm2(out_s2)
    out = MEF.relu(out_s2)

    out_s4 = self.conv3(out)
    out_s4 = self.norm3(out_s4)
    out = MEF.relu(out_s4)

    out_s8 = self.conv4(out)
    out_s8 = self.norm4(out_s8)
    out = MEF.relu(out_s8)

    out_s16 = self.conv5(out)
    out_s16 = self.norm5(out_s16)
    out = MEF.relu(out_s16)

    out = self.conv5_tr(out)
    out = self.norm5_tr(out)
    out_s8_tr = MEF.relu(out)

    out = ME.cat((out_s8_tr, out_s8))

    out = self.conv4_tr(out)
    out = self.norm4_tr(out)
    out_s4_tr = MEF.relu(out)

    out = ME.cat((out_s4_tr, out_s4))

    out = self.conv3_tr(out)
    out = self.norm3_tr(out)
    out_s2_tr = MEF.relu(out)

    out = ME.cat((out_s2_tr, out_s2))

    out = self.conv2_tr(out)
    out = self.norm2_tr(out)
    out_s1_tr = MEF.relu(out)

    out = ME.cat((out_s1_tr, out_s1))
    out = self.conv1_tr(out)

    if self.normalize_feature:
      # Clamp the norm so an all-zero feature row yields zeros instead of NaN
      # (0 / 0).  Rows whose norm exceeds the floor are divided exactly as before.
      row_norm = torch.norm(out.F, p=2, dim=1, keepdim=True).clamp(min=1e-8)
      return ME.SparseTensor(
          out.F / row_norm,
          coords_key=out.coords_key,
          coords_manager=out.coords_man)
    else:
      return out
Ejemplo n.º 11
0
  def forward(self, x):  # Receptive field size
    """Sparse U-Net where every stage is conv -> norm -> ReLU -> block.

    Trailing comments track the receptive field size after the conv and block
    of each stage.  Returns the output of ``self.final``.  When
    ``self.normalize_feature`` is set, each feature row is divided by
    (its L2 norm + 1e-8).
    """
    def stage(conv, norm, block, inp):
      # Shared conv -> norm -> ReLU -> block pattern used at every level.
      return block(MEF.relu(norm(conv(inp))))

    out_s1 = stage(self.conv1, self.norm1, self.block1, x)  # conv: 7
    out_s2 = stage(self.conv2, self.norm2, self.block2, out_s1)  # conv: 11, block: 19
    out_s4 = stage(self.conv3, self.norm3, self.block3, out_s2)  # conv: 27, block: 43
    out_s8 = stage(self.conv4, self.norm4, self.block4, out_s4)  # conv: 59, block: 91

    up = stage(self.conv4_tr, self.norm4_tr, self.block4_tr, out_s8)  # conv: 99, block: 115
    up = stage(self.conv3_tr, self.norm3_tr, self.block3_tr,
               ME.cat(up, out_s4))  # conv: 119, block: 127
    up = stage(self.conv2_tr, self.norm2_tr, self.block2_tr,
               ME.cat(up, out_s2))  # conv: 129, block: 133

    out = self.final(MEF.relu(self.conv1_tr(ME.cat(up, out_s1))))

    if not self.normalize_feature:
      return out
    scale = torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8
    return ME.SparseTensor(
        out.F / scale,
        coords_key=out.coords_key,
        coords_manager=out.coords_man)
Ejemplo n.º 12
0
  def forward(self, x):
    """Run the sparse U-Net, max-pool its features over all rows, then apply ``fc1``.

    Returns the ReLU of ``fc1`` applied to the pooled vector, with a leading
    dimension of size 1 added.
    """
    skips = []
    h = x
    encoder = (
        (self.conv1, self.norm1, self.block1),
        (self.conv2, self.norm2, self.block2),
        (self.conv3, self.norm3, self.block3),
        (self.conv4, self.norm4, self.block4),
    )
    for conv, norm, block in encoder:
      s = block(norm(conv(h)))
      skips.append(s)
      h = MEF.relu(s)

    decoder = (
        (self.conv4_tr, self.norm4_tr, self.block4_tr),
        (self.conv3_tr, self.norm3_tr, self.block3_tr),
        (self.conv2_tr, self.norm2_tr, self.block2_tr),
    )
    for (conv, norm, block), skip in zip(decoder, reversed(skips[:-1])):
      h = ME.cat(MEF.relu(block(norm(conv(h)))), skip)

    h = self.final(MEF.relu(self.conv1_tr(h)))

    # NOTE(review): this max runs over every row of h.F, i.e. across all batch
    # items at once.  It is presumably meant for batch size 1; confirm.
    pooled = h.F.max(dim=0).values
    return F.relu(self.fc1(pooled)).unsqueeze(0)
Ejemplo n.º 13
0
 def forward(self, x):
   """Conv, optional inner module and transposed conv; then cat with the input and ``cat_conv``."""
   inner = self.conv(x)
   if self.inner_module:
     inner = self.inner_module(inner)
   merged = ME.cat(x, self.convtr(inner))
   return self.cat_conv(merged)
Ejemplo n.º 14
0
    def forward(self, x):
        """Two conv branches on ``x``, concatenated and added back to ``x``."""
        branch0 = self.conv0_0(x)
        branch0 = self.conv0_1(self.relu(branch0))

        branch1 = self.relu(self.conv1_0(x))
        branch1 = self.relu(self.conv1_1(branch1))
        branch1 = self.conv1_2(branch1)

        return ME.cat(branch0, branch1) + x
Ejemplo n.º 15
0
 def forward(self, x):
     """Run every path on ``x``, concatenate, project with ``linear``, add the residual."""
     shortcut = self.residual(x)
     branches = tuple(path(x) for path in self.paths)
     merged = self.linear(ME.cat(branches))
     merged += shortcut
     return merged
Ejemplo n.º 16
0
    def forward(self, x):
        """Three-level block U-Net; returns the output of ``conv1_tr``."""
        s1 = self.block1(x)
        s2 = self.block2(MF.relu(s1))
        bottom = MF.relu(self.block3(MF.relu(s2)))

        up = ME.cat(MF.relu(self.block3_tr(bottom)), s2)
        up = ME.cat(MF.relu(self.block2_tr(up)), s1)
        return self.conv1_tr(up)
Ejemplo n.º 17
0
    def forward(self, x):
        """Three-level conv/BN U-Net with pre-ReLU skips; returns ``conv6`` output."""
        skip1 = self.bn1(self.conv1(x))
        skip2 = self.bn2(self.conv2(MF.relu(skip1)))
        bottom = MF.relu(self.bn3(self.conv3(MF.relu(skip2))))

        up = ME.cat((MF.relu(self.bn4(self.conv4(bottom))), skip2))
        up = ME.cat((MF.relu(self.bn5(self.conv5(up))), skip1))
        return self.conv6(up)
Ejemplo n.º 18
0
    def forward(self, x, x_skip):
        """Parent forward, upsampled residual added, optionally concatenated with ``x_skip``."""
        residual, features = super().forward(x)
        merged = self.upsample(residual) + features
        return ME.cat(merged, x_skip) if self.skip else merged
Ejemplo n.º 19
0
 def forward(self, x, x_skip=None):
     """Optionally concat ``x_skip``, then transposed conv -> bn -> activation -> block."""
     out = x if x_skip is None else ME.cat(x, x_skip)
     for layer in (self.conv_tr, self.bn, self.activation, self.block):
         out = layer(out)
     return out
Ejemplo n.º 20
0
 def forward(self, x, x_skip):
     """Optionally concat ``x_skip`` and conv.

     Without ``final``: norm -> block.  With ``final``: ReLU -> final.
     """
     inp = x if x_skip is None else ME.cat(x, x_skip)
     feats = self.conv(inp)
     if self.final is not None:
         return self.final(MEF.relu(feats))
     return self.block(self.norm(feats))
Ejemplo n.º 21
0
    def forward(self, input):
        """Pool/unpool ``input`` through every branch, concatenate and apply ``linear``."""
        branches = []
        for idx, pool in enumerate(self.pool):
            pooled = pool(input)
            unpool = self.unpool[idx]
            # Branch 0 is global pooling; its unpool also receives the original input.
            branches.append(unpool(input, pooled) if idx == 0 else unpool(pooled))
        return self.linear(ME.cat(branches))
Ejemplo n.º 22
0
def sparse_tensor_arithmetics():
    """Demonstrate arithmetic, in-place operators and concatenation on sparse tensors.

    Two batches are converted to coordinates/features, collated and wrapped in
    sparse tensors.  Adding tensors built on separate coordinate managers is
    expected to raise ``AssertionError``.  Once they share a
    ``coordinate_manager``, ``+ - * /`` are exercised.  In-place operators use
    a tensor that also shares ``coordinate_map_key``, and finally two such
    tensors are concatenated with ``ME.cat``.
    """
    def collate(batch):
        # Convert one batch to (coords, feats) and collate it as a batch of one.
        coords, feats = to_sparse_coo(batch)
        return ME.utils.sparse_collate(coords=[coords], feats=[feats])

    coords_a, feats_a = collate(data_batch_0)
    coords_b, feats_b = collate(data_batch_1)

    A = ME.SparseTensor(coordinates=coords_a, features=feats_a)
    B = ME.SparseTensor(coordinates=coords_b, features=feats_b)

    # B has its own coordinate manager, so combining it with A fails.
    try:
        result = A + B
    except AssertionError:
        pass

    # Rebuild B on A's coordinate manager; the binary operators now work.
    shared_manager = A.coordinate_manager
    B = ME.SparseTensor(
        coordinates=coords_b,
        features=feats_b,
        coordinate_manager=shared_manager,
    )
    result = A + B
    result = A - B
    result = A * B
    result = A / B

    # In-place operators need the same coordinate map key as well; with the
    # key and manager supplied, no coordinates have to be passed.
    D = ME.SparseTensor(
        features=feats_a,
        coordinate_manager=A.coordinate_manager,
        coordinate_map_key=A.coordinate_map_key,
    )
    A += D
    A -= D
    A *= D
    A /= D

    # Tensors sharing a coordinate map key can have their features concatenated.
    merged = ME.cat(A, D)
Ejemplo n.º 23
0
    def forward(self, input):
        """Encode, decode and append SPP features to the last decoder tensor.

        The first ``self.D + 1`` columns of ``input`` become integer CPU
        coordinates; the remaining columns become float features.  Returns a dict
        with the encoder tensors, the decoder tensors and the final features.
        """
        coords = input[:, :self.D + 1].cpu().int()
        features = input[:, self.D + 1:].float()

        encoded = self.encoder(ME.SparseTensor(features, coords=coords))
        encoder_tensors = encoded['encoderTensors']
        decoder_tensors = self.decoder(encoded['finalTensor'], encoder_tensors)

        last = decoder_tensors[-1]
        final_features = ME.cat((last, self.spp(last)))

        return {
            'encoderTensors': encoder_tensors,
            'decoderTensors': decoder_tensors,
            'finalFeatures': final_features,
        }
Ejemplo n.º 24
0
 def decoder(self, final, encoderTensors):
     '''
     Vanilla UResNet decoder.

     Starting from ``final``, step ``i`` upsamples with ``self.decoding_conv[i]``.
     It then concatenates ``encoderTensors[-i - 2]`` in front of the result and
     refines the pair with ``self.decoding_block[i]``.

     INPUTS:
         - final: tensor the decoding starts from.
         - encoderTensors (list of SparseTensor): output of encoder.
     RETURNS:
         - decoderTensors (list of SparseTensor): the output of every decoding
           step, in order.
     '''
     decoded = []
     feats = final
     for level, up_conv in enumerate(self.decoding_conv):
         skip = encoderTensors[-(level + 2)]
         feats = self.decoding_block[level](ME.cat((skip, up_conv(feats))))
         decoded.append(feats)
     return decoded
Ejemplo n.º 25
0
    def forward(self, input, lang_feat):
        """Fuse per-sample language features into the deepest encoder level, then decode.

        Each row of the third encoder output ``x3`` is paired with the row of
        ``lang_feat`` selected by that row's batch index.  The pair is
        concatenated and fused, then decoded with skips from ``x2``, ``x1`` and
        ``input``.
        """
        x1 = self.en1(input)
        x2 = self.en2(x1)
        x3 = self.en3(x2)

        # Read batch indices from x3's own coordinates so they are row-aligned
        # with x3.F.  The previous lookup via cm.get_coordinates(8) hard-coded
        # the encoder's total tensor stride and did not tie the rows to x3's
        # coordinate map.
        batch_idx = x3.C[:, 0].long()
        lang_sparse = ME.SparseTensor(features=lang_feat[batch_idx, :],
                                      coordinate_map_key=x3.coordinate_map_key,
                                      coordinate_manager=x3.coordinate_manager)
        x3 = ME.cat(x3, lang_sparse)
        x3 = self.fuse_layer(x3)

        d3 = self.de3(x2, x3)
        d2 = self.de2(x1, d3)
        d1 = self.de1(input, d2)
        net = self.final_layer(d1)
        return net
Ejemplo n.º 26
0
	def forward(self, stensor_src, stensor_tgt):
		"""Encode both sparse tensors, exchange bottleneck features, decode both.

		Source and target share the same encoder/decoder modules.  Returns
		``(src_feats, tgt_feats, scores_overlap, scores_saliency)``:
		row-wise L2-normalised descriptors for each input, plus the sigmoid
		overlap and saliency scores (source then target, concatenated on dim 0).
		"""
		src, src_s1, src_s2, src_s4 = self._encode_branch(stensor_src)
		tgt, tgt_s1, tgt_s2, tgt_s4 = self._encode_branch(stensor_tgt)

		################################
		# overlap attention module
		# empirically, when batch_size = 1, out.C[:,1:] == out.coordinates_at(0)
		# NOTE(review): all rows of src/tgt are treated as a single point set
		# below, which presumably assumes batch size 1; confirm.
		src_feats = src.F.transpose(0, 1)[None, :]  # [1, C, N]
		tgt_feats = tgt.F.transpose(0, 1)[None, :]  # [1, C, N]
		src_pcd, tgt_pcd = src.C[:, 1:] * self.voxel_size, tgt.C[:, 1:] * self.voxel_size

		# 1. project the bottleneck feature
		src_feats, tgt_feats = self.bottle(src_feats), self.bottle(tgt_feats)

		# 2. apply GNN to communicate the features and get overlap scores
		src_feats, tgt_feats = self.gnn(src_pcd.transpose(0, 1)[None, :], tgt_pcd.transpose(0, 1)[None, :], src_feats, tgt_feats)

		src_feats, src_scores = self.proj_gnn(src_feats), self.proj_score(src_feats)[0].transpose(0, 1)
		tgt_feats, tgt_scores = self.proj_gnn(tgt_feats), self.proj_score(tgt_feats)[0].transpose(0, 1)

		# 3. get cross-overlap scores
		src_feats_norm = F.normalize(src_feats, p=2, dim=1)[0].transpose(0, 1)
		tgt_feats_norm = F.normalize(tgt_feats, p=2, dim=1)[0].transpose(0, 1)
		inner_products = torch.matmul(src_feats_norm, tgt_feats_norm.transpose(0, 1))
		temperature = torch.exp(self.epsilon) + 0.03
		src_scores_x = torch.matmul(F.softmax(inner_products / temperature, dim=1), tgt_scores)
		tgt_scores_x = torch.matmul(F.softmax(inner_products.transpose(0, 1) / temperature, dim=1), src_scores)

		# 4. update sparse tensor
		src_feats = torch.cat([src_feats[0].transpose(0, 1), src_scores, src_scores_x], dim=1)
		tgt_feats = torch.cat([tgt_feats[0].transpose(0, 1), tgt_scores, tgt_scores_x], dim=1)
		src = ME.SparseTensor(src_feats,
			coordinate_map_key=src.coordinate_map_key,
			coordinate_manager=src.coordinate_manager)

		tgt = ME.SparseTensor(tgt_feats,
			coordinate_map_key=tgt.coordinate_map_key,
			coordinate_manager=tgt.coordinate_manager)

		################################
		# decode (shared weights)
		src = self._decode_branch(src, src_s1, src_s2, src_s4)
		tgt = self._decode_branch(tgt, tgt_s1, tgt_s2, tgt_s4)

		################################
		# output features and scores: last two channels are overlap / saliency logits
		src_feats, src_overlap, src_saliency = src.F[:, :-2], src.F[:, -2], src.F[:, -1]
		tgt_feats, tgt_overlap, tgt_saliency = tgt.F[:, :-2], tgt.F[:, -2], tgt.F[:, -1]

		# torch.sigmoid is what nn.Sigmoid() calls; no need to build a module per call.
		src_overlap = torch.clamp(torch.sigmoid(src_overlap.view(-1)), min=0, max=1)
		src_saliency = torch.clamp(torch.sigmoid(src_saliency.view(-1)), min=0, max=1)
		tgt_overlap = torch.clamp(torch.sigmoid(tgt_overlap.view(-1)), min=0, max=1)
		tgt_saliency = torch.clamp(torch.sigmoid(tgt_saliency.view(-1)), min=0, max=1)

		src_feats = F.normalize(src_feats, p=2, dim=1)
		tgt_feats = F.normalize(tgt_feats, p=2, dim=1)

		scores_overlap = torch.cat([src_overlap, tgt_overlap], dim=0)
		scores_saliency = torch.cat([src_saliency, tgt_saliency], dim=0)

		return src_feats, tgt_feats, scores_overlap, scores_saliency

	def _encode_branch(self, stensor):
		"""Run the shared encoder; return (ReLU'd bottleneck, s1, s2, s4 pre-ReLU skips)."""
		s1 = self.block1(self.norm1(self.conv1(stensor)))
		out = MEF.relu(s1)

		s2 = self.block2(self.norm2(self.conv2(out)))
		out = MEF.relu(s2)

		s4 = self.block3(self.norm3(self.conv3(out)))
		out = MEF.relu(s4)

		s8 = self.block4(self.norm4(self.conv4(out)))
		return MEF.relu(s8), s1, s2, s4

	def _decode_branch(self, out, s1, s2, s4):
		"""Run the shared decoder from ``out``, concatenating skips s4, s2, s1 in turn."""
		out = MEF.relu(self.block4_tr(self.norm4_tr(self.conv4_tr(out))))
		out = ME.cat(out, s4)

		out = MEF.relu(self.block3_tr(self.norm3_tr(self.conv3_tr(out))))
		out = ME.cat(out, s2)

		out = MEF.relu(self.block2_tr(self.norm2_tr(self.conv2_tr(out))))
		out = ME.cat(out, s1)

		out = MEF.relu(self.conv1_tr(out))
		return self.final(out)
Ejemplo n.º 27
0
 def forward(self, x, skip):
     """Concatenate ``skip`` onto ``x`` when given, then defer to the parent forward."""
     merged = x if skip is None else ME.cat(x, skip)
     return super().forward(merged)
 def forward(self, x_a, x_b):
     """Transform ``x_a`` with ``conv_a``, concatenate ``x_b`` and project with ``conv_proj``."""
     joined = ME.cat(self.conv_a(x_a), x_b)
     projected = self.conv_proj(joined)
     return projected
Ejemplo n.º 29
0
def cat(*args):
    """Pass every positional argument straight through to ``ME.cat`` and return its result."""
    concatenate = ME.cat
    return concatenate(*args)
Ejemplo n.º 30
0
    def forward(self, x, **kwargs):
        """Residual sparse U-Net forward pass.

        Returns the sparse output of ``self.final`` when ``mink=True`` is
        passed, otherwise only its feature matrix ``.F``.
        """
        relu = self.relu

        p1 = relu(self.bn0(self.conv0p1s1(x)))
        b1p2 = self.block1(relu(self.bn1(self.conv1p1s2(p1))))
        b2p4 = self.block2(relu(self.bn2(self.conv2p2s2(b1p2))))
        b3p8 = self.block3(relu(self.bn3(self.conv3p4s2(b2p4))))
        # tensor_stride=16
        feat = self.block4(relu(self.bn4(self.conv4p8s2(b3p8))))

        # tensor_stride=8
        feat = self.block5(ME.cat((relu(self.bntr4(self.convtr4p16s2(feat))), b3p8)))
        # tensor_stride=4
        feat = self.block6(ME.cat((relu(self.bntr5(self.convtr5p8s2(feat))), b2p4)))
        # tensor_stride=2
        feat = self.block7(ME.cat((relu(self.bntr6(self.convtr6p4s2(feat))), b1p2)))
        # tensor_stride=1
        feat = self.block8(ME.cat((relu(self.bntr7(self.convtr7p2s2(feat))), p1)))

        result = self.final(feat)
        return result if kwargs.get('mink') else result.F