def forward(self, x: ME.TensorField):
    """Run the multi-scale network on a tensor field and return raw features.

    The field is passed through ``mlp1`` and converted to a sparse tensor.
    Four convolution stages (``conv1`` .. ``conv4``) follow, each one
    immediately followed by ``pool``; the pooled output of a stage is the
    input to the next. Every pooled output is sliced back onto the field,
    the four slices are concatenated, and the sparse version of the result
    goes through ``conv5``. Global max and global average pooling of that
    output are concatenated and passed to ``final``.

    Args:
        x: Input tensor field.

    Returns:
        The ``.F`` attribute of the output of ``final``.
    """
    field = self.mlp1(x)

    # Each stage is conv -> pool; the pooled tensor feeds the next stage.
    current = field.sparse()
    pooled_stages = []
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        current = self.pool(conv(current))
        pooled_stages.append(current)

    # Slice every scale back onto the field only after all stages have
    # run, then join the per-scale results.
    per_scale = [stage.slice(field) for stage in pooled_stages]
    merged = ME.cat(*per_scale)

    fused = self.conv5(merged.sparse())
    max_pooled = self.global_max_pool(fused)
    avg_pooled = self.global_avg_pool(fused)
    return self.final(ME.cat(max_pooled, avg_pooled)).F
def test_field(self):
    """Gradient-check global max pooling applied to a ``TensorField``.

    The check runs once on a field built without a device argument and,
    when CUDA is available, once more on a field built with
    ``device="cuda"``.
    """
    in_channels = 2
    coords, feats, _ = data_loader(in_channels)
    feats = feats.double()
    feats.requires_grad_()

    def gradcheck_inputs(field, pooled):
        # Positional arguments handed to the pooling function by gradcheck.
        return (
            field.F,
            pool.pooling_mode,
            field.coordinate_field_map_key,
            pooled.coordinate_map_key,
            field._manager,
        )

    cpu_field = TensorField(feats, coords)
    pool = MinkowskiGlobalMaxPooling()
    cpu_pooled = pool(cpu_field)
    print(cpu_pooled)

    # Check backward
    fn = MinkowskiGlobalPoolingFunction()
    self.assertTrue(gradcheck(fn, gradcheck_inputs(cpu_field, cpu_pooled)))

    if torch.cuda.is_available():
        cuda_field = TensorField(feats, coords, device="cuda")
        cuda_pooled = pool(cuda_field)
        print(cuda_pooled)

        # Check backward
        self.assertTrue(
            gradcheck(fn, gradcheck_inputs(cuda_field, cuda_pooled))
        )
def forward(self, in_field: ME.TensorField):
    """Encode a tensor field and return the output of the ``fc`` head.

    The field is converted to a sparse tensor and passed through
    ``input_mlp``, then through ``PTBlock0`` followed by four
    ``TDLayer``/``PTBlock`` stages (the last stage is ``TDLayer4`` only).
    The result is globally average-pooled and fed to ``fc``.

    The previous version took timestamps with ``time.perf_counter()`` at
    the start and ``time.time()`` at the end; the two clocks are not
    comparable and neither value was used, so the timing code has been
    removed. To profile this method, read ``time.perf_counter()`` at both
    ends.

    Args:
        in_field: Input tensor field.

    Returns:
        The output of ``self.fc`` applied to the pooled features.
    """
    x = in_field.sparse()
    x = self.input_mlp(x)
    print('total {} voxels'.format(x.shape[0]))

    # Each PTBlock returns (features, attention); the attention output is
    # not used by this method, so it is discarded.
    x, _ = self.PTBlock0(x)

    x = self.TDLayer1(x)
    x, _ = self.PTBlock1(x)

    x = self.TDLayer2(x)
    x, _ = self.PTBlock2(x)

    x = self.TDLayer3(x)
    x, _ = self.PTBlock3(x)

    x = self.TDLayer4(x)
    # NOTE: PTBlock4, middle_linear / PTBlock_middle and the TULayer5-8 /
    # PTBlock5-8 stages are intentionally not applied here. According to
    # the original author's note, PTBlock4 cannot find any neighbor with
    # r=10 at this point. Slicing the result back onto ``in_field`` is
    # likewise disabled.

    x = self.global_avg_pool(x)
    return self.fc(x)