def forward(self, z):
    """Run the two-stage point-voxel encoder/decoder on the point cloud ``z``.

    Args:
        z: input points (a PointTensor, per the original inline note); it is
            voxelized with ``self.pres`` / ``self.vres``.

    Returns:
        The ``F`` attribute of the final point branch, i.e. the voxel decoder
        output sampled back onto the points plus a point-wise residual.
    """

    def _decode(stage, coarse, skip):
        # stage[0] upsamples ``coarse``; the result is concatenated with the
        # encoder skip connection and refined by stage[1].
        lifted = stage[0](coarse)
        fused = torchsparse.cat([lifted, skip])
        return stage[1](fused)

    # Voxelize the raw points, run the stem, and sample back to the points.
    vox_stem = self.stem(initial_voxelize(z, self.pres, self.vres))
    pts_stem = voxel_to_point(vox_stem, z, nearest=False)

    # Voxel-branch encoder (two stages).
    enc1 = self.stage1(point_to_voxel(vox_stem, pts_stem))
    enc2 = self.stage2(enc1)

    # Point branch: sample the deepest voxel features and add a residual MLP.
    pts_mid = voxel_to_point(enc2, pts_stem)
    pts_mid.F = pts_mid.F + self.point_transforms[0](pts_stem.F)

    # Decoder; dropout is optional (skipped when ``self.dropout`` is falsy).
    dec = point_to_voxel(enc2, pts_mid)
    if self.dropout:
        dec.F = self.dropout(dec.F)
    dec = _decode(self.up1, dec, enc1)
    dec = _decode(self.up2, dec, vox_stem)

    # Final point fusion.
    pts_out = voxel_to_point(dec, pts_mid)
    pts_out.F = pts_out.F + self.point_transforms[1](pts_mid.F)
    return pts_out.F
def forward(self, x):
    """Run a 4-level sparse UNet on ``x`` and classify the output voxels.

    Args:
        x: sparse input (presumably a torchsparse SparseTensor; TODO confirm).

    Returns:
        ``self.classifier`` applied to the features ``F`` of the last
        decoder level.
    """
    # Encoder: the stem output plus one activation per stage, kept as skips.
    skips = [self.stem(x)]
    for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
        skips.append(stage(skips[-1]))

    # The deepest activation seeds the decoder; the rest are skip connections,
    # consumed from deepest (stage3 output) to shallowest (stem output).
    feat = skips.pop()
    decoders = (self.up1, self.up2, self.up3, self.up4)
    for up, skip in zip(decoders, reversed(skips)):
        upsampled = up[0](feat)
        feat = up[1](torchsparse.cat([upsampled, skip]))

    return self.classifier(feat.F)
def forward(self, x):
    """Point-voxel network built from ``downsample`` / ``upsample`` stages.

    Args:
        x: sparse input (a SparseTensor, per the original inline note). Its
            features and float coordinates seed the point branch.

    Returns:
        ``self.classifier`` applied to the final per-point features.

    Note:
        ``self.classifier.set_in_channel`` is called on every forward pass
        with the last dimension of the final point features. Presumably this
        lazily sizes the classifier; TODO confirm.
    """

    def _up(idx, coarse, skip):
        # upsample[idx].transition upsamples, the result is concatenated with
        # the skip connection, and upsample[idx].feature refines it.
        block = self.upsample[idx]
        lifted = block.transition(coarse)
        return block.feature(torchsparse.cat([lifted, skip]))

    pts = PointTensor(x.F, x.C.float())
    vox0 = self.stem(point_to_voxel(x, pts))
    pts0 = voxel_to_point(vox0, pts)

    # Encoder: four downsampling stages.
    enc1 = self.downsample[0](point_to_voxel(vox0, pts0))
    enc2 = self.downsample[1](enc1)
    enc3 = self.downsample[2](enc2)
    enc4 = self.downsample[3](enc3)

    # First point fusion (deepest encoder level).
    pts1 = voxel_to_point(enc4, pts0)
    pts1.F = pts1.F + self.point_transforms[0](pts0.F)

    dec = point_to_voxel(enc4, pts1)
    dec.F = self.dropout(dec.F)
    dec = _up(0, dec, enc3)
    dec = _up(1, dec, enc2)

    # Second point fusion (middle of the decoder).
    pts2 = voxel_to_point(dec, pts1)
    pts2.F = pts2.F + self.point_transforms[1](pts1.F)

    dec = point_to_voxel(dec, pts2)
    dec.F = self.dropout(dec.F)
    dec = _up(2, dec, enc1)
    dec = _up(3, dec, vox0)

    # Final point fusion, then classification.
    pts3 = voxel_to_point(dec, pts2)
    pts3.F = pts3.F + self.point_transforms[2](pts2.F)
    feats = pts3.F
    self.classifier.set_in_channel(feats.shape[-1])
    return self.classifier(feats)
def embedder(self, y3, x0, z2):
    """Embedding head: last decoder level, point fusion, then linear output.

    Args:
        y3: voxel features from the preceding decoder level.
        x0: stem output, used as the skip connection.
        z2: point tensor from the previous point fusion.

    Returns:
        ``self.e_lin`` applied to the fused per-point features.
    """
    head = self.e_up4
    upsampled = head[0](y3)
    refined = head[1](torchsparse.cat([upsampled, x0]))

    points = voxel_to_point(refined, z2)
    points.F = points.F + self.e_point_transform(z2.F)
    return self.e_lin(points.F)
def forward(self, x):
    """Run the point-voxel backbone on ``x`` and apply the task head(s).

    The input's features and float coordinates are wrapped in a PointTensor
    and voxelized with ``hparams.model.pres`` / ``vres``. They then pass
    through a stem, four encoder stages and three decoder stages, with two
    point-branch residual fusions.

    Returns:
        ``self.classifier(...)`` when ``hparams.task == 'semantic'``,
        ``self.embedder(...)`` for ``'instance'``, and the tuple
        ``(classifier_out, embedder_out)`` for ``'panoptic'``.

    Raises:
        RuntimeError: for any other ``hparams.task``. This is raised only
            after the backbone has run.
    """

    def _up(stage, coarse, skip):
        # stage[0] upsamples; stage[1] refines after concatenating the skip.
        lifted = stage[0](coarse)
        return stage[1](torchsparse.cat([lifted, skip]))

    pts = PointTensor(x.F, x.C.float())
    cfg = self.hparams.model
    vox0 = self.stem(initial_voxelize(pts, cfg.pres, cfg.vres))
    pts0 = voxel_to_point(vox0, pts, nearest=False)

    # Encoder: keep each stage output for the skip connections.
    feat = point_to_voxel(vox0, pts0)
    encoded = []
    for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
        feat = stage(feat)
        encoded.append(feat)
    enc1, enc2, enc3, enc4 = encoded

    pts1 = voxel_to_point(enc4, pts0)
    pts1.F = pts1.F + self.point_transforms[0](pts0.F)

    dec = point_to_voxel(enc4, pts1)
    dec.F = self.dropout(dec.F)
    dec = _up(self.up1, dec, enc3)
    dec = _up(self.up2, dec, enc2)

    pts2 = voxel_to_point(dec, pts1)
    pts2.F = pts2.F + self.point_transforms[1](pts1.F)

    dec = point_to_voxel(dec, pts2)
    dec.F = self.dropout(dec.F)
    dec = _up(self.up3, dec, enc1)

    task = self.hparams.task
    if task == 'semantic':
        return self.classifier(dec, vox0, pts2)
    if task == 'instance':
        return self.embedder(dec, vox0, pts2)
    if task == 'panoptic':
        return (self.classifier(dec, vox0, pts2),
                self.embedder(dec, vox0, pts2))
    raise RuntimeError("invalid task!")
def forward(self, x):
    """Run the point-voxel backbone on ``x`` and return a dict of predictions.

    Depending on ``self.hparams.task``:
      * ``'semantic'``: ``pred_semantic_scores`` (classifier output) and
        ``pred_semantic_labels`` (its argmax over dim 1);
      * ``'instance'``: ``pred_offsets`` (embedder output);
      * ``'panoptic'``: all three keys.

    Raises:
        RuntimeError: if ``hparams.task`` is none of the above. It is checked
            up front, before any network computation. Previously an unknown
            task silently returned ``{}`` after a full forward pass.
    """
    task = self.hparams.task
    wants_semantic = task in ('semantic', 'panoptic')
    wants_instance = task in ('instance', 'panoptic')
    if not (wants_semantic or wants_instance):
        raise RuntimeError(
            "invalid task! expected 'semantic', 'instance' or 'panoptic', "
            f"got {task!r}")

    def _up(stage, coarse, skip):
        # stage[0] upsamples; stage[1] refines after concatenating the skip.
        lifted = stage[0](coarse)
        return stage[1](torchsparse.cat([lifted, skip]))

    # x: SparseTensor, wrapped as a PointTensor for the point branch.
    pts = PointTensor(x.F, x.C.float())
    cfg = self.hparams.model
    vox0 = self.stem(initial_voxelize(pts, cfg.pres, cfg.vres))
    pts0 = voxel_to_point(vox0, pts, nearest=False)

    # Encoder.
    enc1 = self.stage1(point_to_voxel(vox0, pts0))
    enc2 = self.stage2(enc1)
    enc3 = self.stage3(enc2)
    enc4 = self.stage4(enc3)

    # First point fusion at the deepest level.
    pts1 = voxel_to_point(enc4, pts0)
    pts1.F = pts1.F + self.point_transforms[0](pts0.F)

    dec = point_to_voxel(enc4, pts1)
    dec.F = self.dropout(dec.F)
    dec = _up(self.up1, dec, enc3)
    dec = _up(self.up2, dec, enc2)

    # Second point fusion mid-decoder.
    pts2 = voxel_to_point(dec, pts1)
    pts2.F = pts2.F + self.point_transforms[1](pts1.F)

    dec = point_to_voxel(dec, pts2)
    dec.F = self.dropout(dec.F)
    dec = _up(self.up3, dec, enc1)

    out = {}
    if wants_semantic:
        scores = self.classifier(dec, vox0, pts2)
        out['pred_semantic_scores'] = scores
        out['pred_semantic_labels'] = scores.argmax(dim=1)
    if wants_instance:
        out['pred_offsets'] = self.embedder(dec, vox0, pts2)
    return out
def forward(self, x):
    """Full 4-down / 4-up point-voxel network followed by a point classifier.

    The input is voxelized with a hard-coded point resolution of ``1.0`` and
    voxel resolution ``self.vres``.

    Side effects:
        Stores the classifier output on ``self.output``.

    Returns:
        The classifier output (the same object as ``self.output``).
    """

    def _up(stage, coarse, skip):
        # stage[0] upsamples; stage[1] refines after concatenating the skip.
        lifted = stage[0](coarse)
        return stage[1](torchsparse.cat([lifted, skip]))

    # x: SparseTensor, wrapped as a PointTensor for the point branch.
    pts = PointTensor(x.F, x.C.float())
    vox0 = self.stem(initial_voxelize(pts, 1.0, self.vres))
    pts0 = voxel_to_point(vox0, pts, nearest=False)

    feat = point_to_voxel(vox0, pts0)
    encoded = []
    for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
        feat = stage(feat)
        encoded.append(feat)
    enc1, enc2, enc3, enc4 = encoded

    pts1 = voxel_to_point(enc4, pts0)
    pts1.F = pts1.F + self.point_transforms[0](pts0.F)

    dec = point_to_voxel(enc4, pts1)
    dec.F = self.dropout(dec.F)
    dec = _up(self.up1, dec, enc3)
    dec = _up(self.up2, dec, enc2)

    pts2 = voxel_to_point(dec, pts1)
    pts2.F = pts2.F + self.point_transforms[1](pts1.F)

    dec = point_to_voxel(dec, pts2)
    dec.F = self.dropout(dec.F)
    dec = _up(self.up3, dec, enc1)
    dec = _up(self.up4, dec, vox0)

    pts3 = voxel_to_point(dec, pts2)
    pts3.F = pts3.F + self.point_transforms[2](pts2.F)

    self.output = self.classifier(pts3.F)
    return self.output
def embedder(self, y3, x0, z2):
    """Embedding head with optional tanh squashing.

    Args:
        y3: voxel features from the preceding decoder level.
        x0: stem output, used as the skip connection.
        z2: point tensor from the previous point fusion.

    Returns:
        ``self.e_lin`` of the fused per-point features. If
        ``hparams.model`` contains a non-None ``tanh_scale``, the result is
        ``tanh(out) * tanh_scale``.
    """
    head = self.e_up4
    upsampled = head[0](y3)
    refined = head[1](torchsparse.cat([upsampled, x0]))

    points = voxel_to_point(refined, z2)
    points.F = points.F + self.e_point_transform(z2.F)
    embedding = self.e_lin(points.F)

    model_cfg = self.hparams.model
    if 'tanh_scale' not in model_cfg or model_cfg.tanh_scale is None:
        return embedding
    return torch.tanh(embedding) * model_cfg.tanh_scale
def cat(*args, dim=1):
    """Concatenate the positional arguments along ``dim`` using ``TS.cat``.

    The inputs are forwarded as a single tuple and ``dim`` is passed
    positionally.

    NOTE(review): ``TS`` is a module alias defined elsewhere (presumably
    ``torchsparse`` or ``torch``; TODO confirm). If it is torchsparse, check
    that its ``cat`` accepts a ``dim`` argument.
    """
    inputs = args
    return TS.cat(inputs, dim)
def cat(*args):
    """Concatenate all positional arguments with ``TS.cat``.

    The arguments are passed on as one tuple. ``TS`` is a module alias
    defined elsewhere in the file (presumably torchsparse; TODO confirm).
    """
    inputs = args
    return TS.cat(inputs)