class ResNetDown(ME.MinkowskiNetwork):
    """
    Resnet block that looks like

    in --- strided conv ---- Block ---- sum --[... N times]
                         |              |
                         |-- 1x1 - BN --|
    """

    CONVOLUTION = ME.MinkowskiConvolution

    def __init__(
        self,
        down_conv_nn=[],
        kernel_size=2,
        dilation=1,
        dimension=3,
        stride=2,
        N=1,
        block="ResBlock",
        **kwargs,
    ):
        block = getattr(_res_blocks, block)
        ME.MinkowskiNetwork.__init__(self, dimension)
        if stride > 1:
            conv1_output = down_conv_nn[0]
        else:
            conv1_output = down_conv_nn[1]

        self.conv_in = (
            Seq()
            .append(
                self.CONVOLUTION(
                    in_channels=down_conv_nn[0],
                    out_channels=conv1_output,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    bias=False,
                    dimension=dimension,
                )
            )
            .append(ME.MinkowskiBatchNorm(conv1_output))
            .append(ME.MinkowskiReLU())
        )

        if N > 0:
            self.blocks = Seq()
            for _ in range(N):
                self.blocks.append(block(conv1_output, down_conv_nn[1], self.CONVOLUTION, dimension=dimension))
                conv1_output = down_conv_nn[1]
        else:
            self.blocks = None

    def forward(self, x):
        out = self.conv_in(x)
        if self.blocks:
            out = self.blocks(out)
        return out
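# Usage sketch for the ME-based ResNetDown above -- an illustration under
# assumptions, not part of the module. It presumes MinkowskiEngine is
# installed, that Seq and _res_blocks are importable as in this file, and it
# uses the `coordinates=` keyword of recent MinkowskiEngine releases (older
# releases spelled it `coords=`).
import torch
import MinkowskiEngine as ME

down = ResNetDown(down_conv_nn=[32, 64], kernel_size=2, stride=2, N=2, dimension=3)
coords = torch.randint(0, 100, (1000, 4)).int()  # rows are (batch_idx, x, y, z)
feats = torch.rand(1000, 32)
x = ME.SparseTensor(feats, coordinates=coords)
out = down(x)  # SparseTensor with 64-dim features at twice the input stride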
class ResNetDown(torch.nn.Module):
    """
    Resnet block that looks like

    in --- strided conv ---- Block ---- sum --[... N times]
                         |              |
                         |-- 1x1 - BN --|
    """

    CONVOLUTION = "Conv3d"

    def __init__(
        self,
        down_conv_nn=[],
        kernel_size=2,
        dilation=1,
        stride=2,
        N=1,
        block="ResBlock",
        **kwargs,
    ):
        block = getattr(_res_blocks, block)
        super().__init__()
        if stride > 1:
            conv1_output = down_conv_nn[0]
        else:
            conv1_output = down_conv_nn[1]

        conv = getattr(snn, self.CONVOLUTION)
        self.conv_in = (
            Seq()
            .append(
                conv(
                    in_channels=down_conv_nn[0],
                    out_channels=conv1_output,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                )
            )
            .append(snn.BatchNorm(conv1_output))
            .append(snn.ReLU())
        )

        if N > 0:
            self.blocks = Seq()
            for _ in range(N):
                self.blocks.append(block(conv1_output, down_conv_nn[1], conv))
                conv1_output = down_conv_nn[1]
        else:
            self.blocks = None

    def forward(self, x):
        out = self.conv_in(x)
        if self.blocks:
            out = self.blocks(out)
        return out
class MS_SparseConv3d(BaseMS_SparseConv3d):
    def __init__(self, option, model_type, dataset, modules):
        BaseMS_SparseConv3d.__init__(self, option, model_type, dataset, modules)

        # One U-Net per scale
        option_unet = option.option_unet
        num_scales = option_unet.num_scales
        self.unet = nn.ModuleList()
        for i in range(num_scales):
            module = UnetMSparseConv3d(
                option_unet.backbone,
                input_nc=option_unet.input_nc,
                grid_size=option_unet.grid_size[i],
                pointnet_nn=option_unet.pointnet_nn,
                post_mlp_nn=option_unet.post_mlp_nn,
                pre_mlp_nn=option_unet.pre_mlp_nn,
                add_pos=option_unet.add_pos,
                add_pre_x=option_unet.add_pre_x,
                aggr=option_unet.aggr,
                backend=option.backend,
            )
            self.unet.add_module(name=str(i), module=module)

        # Last MLP layer
        assert option.mlp_cls is not None
        last_mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Sequential(
                    *[
                        Linear(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bias=False),
                        FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                        LeakyReLU(0.2),
                    ]
                )
            )

    def apply_nn(self, input):
        # Run each scale-specific U-Net on a copy of the input and
        # L2-normalize the per-scale features before concatenation.
        outputs = []
        for i in range(len(self.unet)):
            out = self.unet[i](input.clone())
            out.x = out.x / (torch.norm(out.x, p=2, dim=1, keepdim=True) + 1e-20)
            outputs.append(out)
        x = torch.cat([o.x for o in outputs], 1)
        out_feat = self.FC_layer(x)
        if self.normalize_feature:
            out_feat = out_feat / (torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        return out_feat
class MinkowskiFragment(BaseMinkowski, UnwrappedUnetBasedModel):
    def __init__(self, option, model_type, dataset, modules):
        UnwrappedUnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None), getattr(option, "miner", None)
        )

        # Last layer
        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    Sequential(
                        *[
                            Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                            FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                            LeakyReLU(0.2),
                        ]
                    )
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def apply_nn(self, input):
        # Encoder: keep every intermediate output for the skip connections.
        x = input
        stack_down = []
        for i in range(len(self.down_modules) - 1):
            x = self.down_modules[i](x)
            stack_down.append(x)

        x = self.down_modules[-1](x)
        stack_down.append(None)

        # Decoder: each up module consumes the matching skip connection.
        for i in range(len(self.up_modules)):
            x = self.up_modules[i](x, stack_down.pop())

        out_feat = self.FC_layer(x.F)
        if self.normalize_feature:
            return out_feat / (torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        else:
            return out_feat
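# Minimal sketch (with hypothetical stand-in modules) of the skip-connection
# bookkeeping in apply_nn above: every encoder output except the innermost is
# pushed on a stack and popped by the matching decoder stage; the innermost
# stage receives None.
import torch

down_modules = [torch.nn.Linear(8, 8) for _ in range(3)]
up_modules = [torch.nn.Linear(8, 8) for _ in range(3)]

x = torch.rand(4, 8)
stack_down = []
for down in down_modules[:-1]:
    x = down(x)
    stack_down.append(x)
x = down_modules[-1](x)
stack_down.append(None)  # the first up module has no skip connection
for up in up_modules:
    skip = stack_down.pop()
    x = up(x if skip is None else x + skip)  # addition stands in for the real fusion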
class PointGroup(BaseModel):
    __REQUIRED_DATA__ = [
        "pos",
    ]

    __REQUIRED_LABELS__ = list(PanopticLabels._fields)

    def __init__(self, option, model_type, dataset, modules):
        super(PointGroup, self).__init__(option)
        backbone_options = option.get("backbone", {"architecture": "unet"})
        self.Backbone = Minkowski(
            backbone_options.architecture,
            input_nc=dataset.feature_dimension,
            num_layers=4,
            config=backbone_options.config,
        )
        self.BackboneHead = Seq().append(FastBatchNorm1d(self.Backbone.output_nc)).append(torch.nn.ReLU())

        self._scorer_is_encoder = option.scorer.architecture == "encoder"
        self._activate_scorer = option.scorer.activate
        self.Scorer = Minkowski(
            option.scorer.architecture, input_nc=self.Backbone.output_nc, num_layers=option.scorer.depth
        )
        self.ScorerHead = Seq().append(torch.nn.Linear(self.Scorer.output_nc, 1)).append(torch.nn.Sigmoid())

        self.Offset = Seq().append(MLP([self.Backbone.output_nc, self.Backbone.output_nc], bias=False))
        self.Offset.append(torch.nn.Linear(self.Backbone.output_nc, 3))

        self.Semantic = (
            Seq()
            .append(MLP([self.Backbone.output_nc, self.Backbone.output_nc], bias=False))
            .append(torch.nn.Linear(self.Backbone.output_nc, dataset.num_classes))
            .append(torch.nn.LogSoftmax(dim=-1))
        )
        self.loss_names = ["loss", "offset_norm_loss", "offset_dir_loss", "semantic_loss", "score_loss"]
        stuff_classes = dataset.stuff_classes
        if is_list(stuff_classes):
            stuff_classes = torch.Tensor(stuff_classes).long()
        self._stuff_classes = torch.cat([torch.tensor([IGNORE_LABEL]), stuff_classes])

    def set_input(self, data, device):
        self.raw_pos = data.pos.to(device)
        self.input = data
        all_labels = {l: data[l].to(device) for l in self.__REQUIRED_LABELS__}
        self.labels = PanopticLabels(**all_labels)

    def forward(self, epoch=-1, **kwargs):
        # Backbone
        backbone_features = self.BackboneHead(self.Backbone(self.input).x)

        # Semantic and offset heads
        semantic_logits = self.Semantic(backbone_features)
        offset_logits = self.Offset(backbone_features)

        # Grouping and scoring
        cluster_scores = None
        all_clusters = None
        cluster_type = None
        if epoch == -1 or epoch > self.opt.prepare_epoch:  # Active by default
            all_clusters, cluster_type = self._cluster(semantic_logits, offset_logits)
            if len(all_clusters):
                cluster_scores = self._compute_score(all_clusters, backbone_features, semantic_logits)

        self.output = PanopticResults(
            semantic_logits=semantic_logits,
            offset_logits=offset_logits,
            clusters=all_clusters,
            cluster_scores=cluster_scores,
            cluster_type=cluster_type,
        )

        # Sets visual data for debugging
        with torch.no_grad():
            self._dump_visuals(epoch)

        # Compute loss
        self._compute_loss()

    def _cluster(self, semantic_logits, offset_logits):
        """ Compute clusters from positions and votes """
        predicted_labels = torch.max(semantic_logits, 1)[1]
        clusters_pos = region_grow(
            self.raw_pos,
            predicted_labels,
            self.input.batch.to(self.device),
            ignore_labels=self._stuff_classes.to(self.device),
            radius=self.opt.cluster_radius_search,
        )
        clusters_votes = region_grow(
            self.raw_pos + offset_logits,
            predicted_labels,
            self.input.batch.to(self.device),
            ignore_labels=self._stuff_classes.to(self.device),
            radius=self.opt.cluster_radius_search,
            nsample=200,
        )

        all_clusters = clusters_pos + clusters_votes
        all_clusters = [c.to(self.device) for c in all_clusters]
        cluster_type = torch.zeros(len(all_clusters), dtype=torch.uint8).to(self.device)
        cluster_type[len(clusters_pos):] = 1
        return all_clusters, cluster_type

    def _compute_score(self, all_clusters, backbone_features, semantic_logits):
        """ Score the clusters """
        if self._activate_scorer:
            x = []
            coords = []
            batch = []
            for i, cluster in enumerate(all_clusters):
                x.append(backbone_features[cluster])
                coords.append(self.input.coords[cluster])
                batch.append(i * torch.ones(cluster.shape[0]))
            batch_cluster = Data(
                x=torch.cat(x).cpu(),
                coords=torch.cat(coords).cpu(),
                batch=torch.cat(batch).cpu(),
            )
            score_backbone_out = self.Scorer(batch_cluster)
            if self._scorer_is_encoder:
                cluster_feats = score_backbone_out.x
            else:
                cluster_feats = scatter(
                    score_backbone_out.x, score_backbone_out.batch.long().to(self.device), dim=0, reduce="max"
                )
            cluster_scores = self.ScorerHead(cluster_feats).squeeze(-1)
        else:
            # Use semantic certainty as cluster confidence
            with torch.no_grad():
                cluster_semantic = []
                batch = []
                for i, cluster in enumerate(all_clusters):
                    cluster_semantic.append(semantic_logits[cluster, :])
                    batch.append(i * torch.ones(cluster.shape[0]))
                cluster_semantic = torch.cat(cluster_semantic)
                batch = torch.cat(batch)
                cluster_semantic = scatter(cluster_semantic, batch.long().to(self.device), dim=0, reduce="mean")
                cluster_scores = torch.max(cluster_semantic, 1)[0]
        return cluster_scores

    def _compute_loss(self):
        # Semantic loss
        self.semantic_loss = torch.nn.functional.nll_loss(
            self.output.semantic_logits, self.labels.y, ignore_index=IGNORE_LABEL
        )
        self.loss = self.opt.loss_weights.semantic * self.semantic_loss

        # Offset loss
        self.input.instance_mask = self.input.instance_mask.to(self.device)
        self.input.vote_label = self.input.vote_label.to(self.device)
        offset_losses = offset_loss(
            self.output.offset_logits[self.input.instance_mask],
            self.input.vote_label[self.input.instance_mask],
            torch.sum(self.input.instance_mask),
        )
        for loss_name, loss in offset_losses.items():
            setattr(self, loss_name, loss)
            self.loss += self.opt.loss_weights[loss_name] * loss

        # Score loss
        if self.output.cluster_scores is not None and self._activate_scorer:
            self.score_loss = instance_iou_loss(
                self.output.clusters,
                self.output.cluster_scores,
                self.input.instance_labels.to(self.device),
                self.input.batch.to(self.device),
                min_iou_threshold=self.opt.min_iou_threshold,
                max_iou_threshold=self.opt.max_iou_threshold,
            )
            self.loss += self.score_loss * self.opt.loss_weights["score_loss"]

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        self.loss.backward()

    def _dump_visuals(self, epoch):
        if random.random() < self.opt.vizual_ratio:
            if not hasattr(self, "visual_count"):
                self.visual_count = 0
            data_visual = Data(
                pos=self.raw_pos, y=self.input.y, instance_labels=self.input.instance_labels, batch=self.input.batch
            )
            data_visual.semantic_pred = torch.max(self.output.semantic_logits, -1)[1]
            data_visual.vote = self.output.offset_logits
            nms_idx = self.output.get_instances()
            if self.output.clusters is not None:
                data_visual.clusters = [self.output.clusters[i].cpu() for i in nms_idx]
                data_visual.cluster_type = self.output.cluster_type[nms_idx]
            if not os.path.exists("viz"):
                os.mkdir("viz")
            torch.save(data_visual.to("cpu"), "viz/data_e%i_%i.pt" % (epoch, self.visual_count))
            self.visual_count += 1
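# Standalone sketch of the cluster pooling used by _compute_score above:
# per-point features are pooled into one vector per cluster via torch_scatter.
# Dummy tensors for illustration; assumes torch_scatter is installed.
import torch
from torch_scatter import scatter

feats = torch.rand(10, 4)                                  # per-point features
cluster_id = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3])  # cluster index per point
pooled = scatter(feats, cluster_id, dim=0, reduce="max")   # (4, 4): one row per cluster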
class BaseMinkowski(BaseModel):
    def __init__(self, option, model_type, dataset, modules):
        BaseModel.__init__(self, option)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None), getattr(option, "miner", None)
        )

        # Last layer
        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    Sequential(
                        *[
                            Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                            FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                            LeakyReLU(0.2),
                        ]
                    )
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def set_input(self, data, device):
        coords = torch.cat([data.batch.unsqueeze(-1).int(), data.pos.int()], -1)
        self.input = ME.SparseTensor(data.x, coords=coords).to(device)
        self.xyz = torch.stack((data.pos_x, data.pos_y, data.pos_z), 0).T.to(device)

        if hasattr(data, "pos_target"):
            coords_target = torch.cat([data.batch_target.unsqueeze(-1).int(), data.pos_target.int()], -1)
            self.input_target = ME.SparseTensor(data.x_target, coords=coords_target).to(device)
            self.xyz_target = torch.stack((data.pos_x_target, data.pos_y_target, data.pos_z_target), 0).T.to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            self.match = None

    def compute_loss_match(self):
        self.loss_reg = self.metric_loss_module(
            self.output, self.output_target, self.match[:, :2], self.xyz, self.xyz_target
        )
        self.loss = self.loss_reg

    def compute_loss_label(self):
        """Compute the loss, separating the miner from the loss; each point corresponds to one label."""
        output = torch.cat([self.output[self.match[:, 0]], self.output_target[self.match[:, 1]]], 0)
        rang = torch.arange(0, len(self.match), dtype=torch.long, device=self.match.device)
        labels = torch.cat([rang, rang], 0)
        hard_pairs = None
        if self.miner_module is not None:
            hard_pairs = self.miner_module(output, labels)
        # Loss
        self.loss_reg = self.metric_loss_module(output, labels, hard_pairs)
        self.loss = self.loss_reg

    def apply_nn(self, input):
        raise NotImplementedError("Model still not defined")

    def forward(self):
        self.output = self.apply_nn(self.input)
        if self.match is None:
            return self.output

        self.output_target = self.apply_nn(self.input_target)
        if self.mode == "match":
            self.compute_loss_match()
        elif self.mode == "label":
            self.compute_loss_label()
        else:
            raise NotImplementedError("The mode for the loss is incorrect")
        return self.output

    def backward(self):
        if hasattr(self, "loss"):
            self.loss.backward()

    def get_output(self):
        if self.match is not None:
            return self.output, self.output_target
        else:
            return self.output

    def get_ind(self):
        if self.match is not None:
            return self.match[:, 0], self.match[:, 1], self.size_match
        else:
            return None

    def get_xyz(self):
        if self.match is not None:
            return self.xyz, self.xyz_target
        else:
            return self.xyz

    def get_batch(self):
        if self.match is not None:
            batch = self.input.C[:, 0]
            batch_target = self.input_target.C[:, 0]
            return batch, batch_target
        else:
            return None
class MS_SparseConv3d_Shared(BaseMS_SparseConv3d):
    def __init__(self, option, model_type, dataset, modules):
        BaseMS_SparseConv3d.__init__(self, option, model_type, dataset, modules)

        # A single U-Net shared across all scales
        option_unet = option.option_unet
        self.grid_size = option_unet.grid_size
        self.unet = UnetMSparseConv3d(
            option_unet.backbone,
            input_nc=option_unet.input_nc,
            pointnet_nn=option_unet.pointnet_nn,
            post_mlp_nn=option_unet.post_mlp_nn,
            pre_mlp_nn=option_unet.pre_mlp_nn,
            add_pos=option_unet.add_pos,
            add_pre_x=option_unet.add_pre_x,
            aggr=option_unet.aggr,
            backend=option.backend,
        )

        # Last MLP layer
        assert option.mlp_cls is not None
        last_mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Sequential(
                    *[
                        Linear(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bias=False),
                        FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                        LeakyReLU(0.2),
                    ]
                )
            )

        # Intermediate loss
        if option.intermediate_loss is not None:
            int_loss_option = option.intermediate_loss
            self.int_metric_loss, _ = FragmentBaseModel.get_metric_loss_and_miner(
                getattr(int_loss_option, "metric_loss", None), getattr(int_loss_option, "miner", None)
            )
            self.int_weights = int_loss_option.weights
            for i in range(len(int_loss_option.weights)):
                self.loss_names += ["loss_intermediate_loss_{}".format(i)]
        else:
            self.int_metric_loss = None

    def compute_intermediate_loss(self, outputs, outputs_target):
        assert len(outputs) == len(outputs_target)
        if self.int_metric_loss is not None:
            assert len(outputs) == len(self.int_weights)
            for i, w in enumerate(self.int_weights):
                xyz = self.input.pos
                xyz_target = self.input_target.pos
                loss_i = self.int_metric_loss(outputs[i].x, outputs_target[i].x, self.match[:, :2], xyz, xyz_target)
                self.loss += w * loss_i
                setattr(self, "loss_intermediate_loss_{}".format(i), loss_i)

    def apply_nn(self, input):
        # Run the shared U-Net once per grid size and L2-normalize each
        # per-scale output before concatenation.
        outputs = []
        for i in range(len(self.grid_size)):
            self.unet.set_grid_size(self.grid_size[i])
            out = self.unet(input.clone())
            out.x = out.x / (torch.norm(out.x, p=2, dim=1, keepdim=True) + 1e-20)
            outputs.append(out)
        x = torch.cat([o.x for o in outputs], 1)
        out_feat = self.FC_layer(x)
        if self.normalize_feature:
            out_feat = out_feat / (torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        return out_feat, outputs

    def forward(self, *args, **kwargs):
        self.output, outputs = self.apply_nn(self.input)
        if self.match is None:
            return self.output

        self.output_target, outputs_target = self.apply_nn(self.input_target)
        self.compute_loss()
        self.compute_intermediate_loss(outputs, outputs_target)
        return self.output
class MS_SparseConvModel(APIModel):
    def __init__(self, option, model_type, dataset, modules):
        BaseModel.__init__(self, option)
        option_unet = option.option_unet
        self.normalize_feature = option.normalize_feature
        self.grid_size = option_unet.grid_size
        self.unet = UnetMSparseConv3d(
            option_unet.backbone,
            input_nc=option_unet.input_nc,
            pointnet_nn=option_unet.pointnet_nn,
            post_mlp_nn=option_unet.post_mlp_nn,
            pre_mlp_nn=option_unet.pre_mlp_nn,
            add_pos=option_unet.add_pos,
            add_pre_x=option_unet.add_pre_x,
            aggr=option_unet.aggr,
            backend=option.backend,
        )
        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    nn.Sequential(
                        *[
                            nn.Linear(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bias=False),
                            FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                            nn.LeakyReLU(0.2),
                        ]
                    )
                )
            if last_mlp_opt.dropout:
                self.FC_layer.append(nn.Dropout(p=last_mlp_opt.dropout))
        else:
            self.FC_layer = torch.nn.Identity()

        self.head = nn.Sequential(nn.Linear(option.output_nc, dataset.num_classes))
        self.loss_names = ["loss_seg"]

    def apply_nn(self, input):
        # Multi-scale features: one U-Net pass per grid size, L2-normalized
        # per scale, then concatenated and classified.
        outputs = []
        for i in range(len(self.grid_size)):
            self.unet.set_grid_size(self.grid_size[i])
            out = self.unet(input.clone())
            out.x = out.x / (torch.norm(out.x, p=2, dim=1, keepdim=True) + 1e-20)
            outputs.append(out)
        x = torch.cat([o.x for o in outputs], 1)
        out_feat = self.FC_layer(x)
        if self.normalize_feature:
            out_feat = out_feat / (torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        out_feat = self.head(out_feat)
        return out_feat, outputs

    def forward(self, *args, **kwargs):
        logits, _ = self.apply_nn(self.input)
        self.output = F.log_softmax(logits, dim=-1)
        if self.labels is not None:
            self.loss_seg = F.nll_loss(self.output, self.labels, ignore_index=IGNORE_LABEL)

    def backward(self):
        self.loss_seg.backward()
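# Side note on the segmentation head above: log_softmax followed by nll_loss
# is numerically the same as cross_entropy. Dummy shapes for illustration.
import torch
import torch.nn.functional as F

logits = torch.randn(5, 3)
labels = torch.tensor([0, 2, 1, 2, 0])
loss_a = F.nll_loss(F.log_softmax(logits, dim=-1), labels)
loss_b = F.cross_entropy(logits, labels)
assert torch.allclose(loss_a, loss_b)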
class BaseMinkowski(FragmentBaseModel):
    def __init__(self, option, model_type, dataset, modules):
        FragmentBaseModel.__init__(self, option)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = FragmentBaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None), getattr(option, "miner", None)
        )

        # Last layer
        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    Sequential(
                        *[
                            Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                            FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                            LeakyReLU(0.2),
                        ]
                    )
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def set_input(self, data, device):
        coords = torch.cat([data.batch.unsqueeze(-1).int(), data.pos.int()], -1)
        self.input = ME.SparseTensor(data.x, coords=coords).to(device)
        self.xyz = torch.stack((data.pos_x, data.pos_y, data.pos_z), 0).T.to(device)

        if hasattr(data, "pos_target"):
            coords_target = torch.cat([data.batch_target.unsqueeze(-1).int(), data.pos_target.int()], -1)
            self.input_target = ME.SparseTensor(data.x_target, coords=coords_target).to(device)
            self.xyz_target = torch.stack((data.pos_x_target, data.pos_y_target, data.pos_z_target), 0).T.to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            self.match = None

    def get_batch(self):
        if self.match is not None:
            batch = self.input.C[:, 0]
            batch_target = self.input_target.C[:, 0]
            return batch, batch_target
        else:
            return None, None

    def get_input(self):
        if self.match is not None:
            input = Data(pos=self.xyz, ind=self.match[:, 0], size=self.size_match)
            input_target = Data(pos=self.xyz_target, ind=self.match[:, 1], size=self.size_match)
            return input, input_target
        else:
            input = Data(pos=self.xyz)
            return input, None
class PointGroup(BaseModel):
    __REQUIRED_DATA__ = [
        "pos",
    ]

    __REQUIRED_LABELS__ = list(PanopticLabels._fields)

    def __init__(self, option, model_type, dataset, modules):
        super(PointGroup, self).__init__(option)
        self.Backbone = Minkowski("unet", input_nc=dataset.feature_dimension, num_layers=4)

        self._scorer_is_encoder = option.scorer.architecture == "encoder"
        self.Scorer = Minkowski(option.scorer.architecture, input_nc=self.Backbone.output_nc, num_layers=2)
        self.ScorerHead = Seq().append(torch.nn.Linear(self.Scorer.output_nc, 1)).append(torch.nn.Sigmoid())

        self.Offset = Seq().append(MLP([self.Backbone.output_nc, self.Backbone.output_nc], bias=False))
        self.Offset.append(torch.nn.Linear(self.Backbone.output_nc, 3))

        self.Semantic = (
            Seq()
            .append(torch.nn.Linear(self.Backbone.output_nc, dataset.num_classes))
            .append(torch.nn.LogSoftmax(dim=-1))
        )
        self.loss_names = ["loss", "offset_norm_loss", "offset_dir_loss", "semantic_loss", "score_loss"]
        self._stuff_classes = torch.cat([torch.tensor([IGNORE_LABEL]), dataset.stuff_classes])

    def set_input(self, data, device):
        self.raw_pos = data.pos.to(device)
        self.input = data
        all_labels = {l: data[l].to(device) for l in self.__REQUIRED_LABELS__}
        self.labels = PanopticLabels(**all_labels)

    def forward(self, epoch=-1, **kwargs):
        # Backbone
        backbone_features = self.Backbone(self.input).x

        # Semantic and offset heads
        semantic_logits = self.Semantic(backbone_features)
        offset_logits = self.Offset(backbone_features)

        # Grouping and scoring
        cluster_scores = None
        all_clusters = None
        cluster_type = None
        if epoch == -1 or epoch > self.opt.prepare_epoch:  # Active by default
            predicted_labels = torch.max(semantic_logits, 1)[1]
            clusters_pos = region_grow(
                self.raw_pos.cpu(),
                predicted_labels.cpu(),
                self.input.batch.cpu(),
                ignore_labels=self._stuff_classes.cpu(),
                radius=self.opt.cluster_radius_search,
            )
            clusters_votes = region_grow(
                self.raw_pos.cpu() + offset_logits.cpu(),
                predicted_labels.cpu(),
                self.input.batch.cpu(),
                ignore_labels=self._stuff_classes.cpu(),
                radius=self.opt.cluster_radius_search,
            )

            all_clusters = clusters_pos + clusters_votes
            all_clusters = [c.to(self.device) for c in all_clusters]
            cluster_type = torch.zeros(len(all_clusters), dtype=torch.uint8).to(self.device)
            cluster_type[len(clusters_pos):] = 1

            if len(all_clusters):
                x = []
                coords = []
                batch = []
                for i, cluster in enumerate(all_clusters):
                    x.append(backbone_features[cluster])
                    coords.append(self.input.coords[cluster])
                    batch.append(i * torch.ones(cluster.shape[0]))
                batch_cluster = Data(
                    x=torch.cat(x).cpu(), coords=torch.cat(coords).cpu(), batch=torch.cat(batch).cpu()
                )
                score_backbone_out = self.Scorer(batch_cluster)
                if self._scorer_is_encoder:
                    cluster_feats = score_backbone_out.x
                else:
                    # scatter_max returns (values, argmax); keep the values only
                    cluster_feats = scatter_max(score_backbone_out.x, score_backbone_out.batch.long(), dim=0)[0]
                cluster_scores = self.ScorerHead(cluster_feats).squeeze(-1)

        self.output = PanopticResults(
            semantic_logits=semantic_logits,
            offset_logits=offset_logits,
            clusters=all_clusters,
            cluster_scores=cluster_scores,
            cluster_type=cluster_type,
        )

        # Sets visual data for debugging
        self._dump_visuals(epoch)

        # Compute loss
        self._compute_loss()

    def _compute_loss(self):
        # Semantic loss
        self.semantic_loss = torch.nn.functional.nll_loss(
            self.output.semantic_logits, self.labels.y, ignore_index=IGNORE_LABEL
        )
        self.loss = self.opt.loss_weights.semantic * self.semantic_loss

        # Offset loss
        offset_losses = self._offset_loss(self.labels, self.output)

        # Score loss
        if self.output.cluster_scores is not None:
            ious = instance_iou(
                self.output.clusters, self.labels.instance_labels.to(self.device), self.input.batch.to(self.device)
            ).max(1)[0]
            lower_mask = ious < self.opt.min_iou_threshold
            higher_mask = ious > self.opt.max_iou_threshold
            middle_mask = torch.logical_and(torch.logical_not(lower_mask), torch.logical_not(higher_mask))
            assert torch.sum(lower_mask + higher_mask + middle_mask) == ious.shape[0]

            shat = torch.zeros_like(ious)
            iou_middle = ious[middle_mask]
            shat[higher_mask] = 1
            shat[middle_mask] = (iou_middle - self.opt.min_iou_threshold) / (
                self.opt.max_iou_threshold - self.opt.min_iou_threshold
            )
            self.score_loss = torch.nn.functional.binary_cross_entropy(self.output.cluster_scores, shat)
            self.loss += self.score_loss * self.opt.loss_weights["score_loss"]

        for loss_name, loss in offset_losses.items():
            setattr(self, loss_name, loss)
            self.loss += self.opt.loss_weights[loss_name] * loss

    @staticmethod
    def _offset_loss(data_labels: PanopticLabels, result: PanopticResults):
        instance_mask = data_labels.instance_mask
        pt_offsets = result.offset_logits[instance_mask, :]
        gt_offsets = data_labels.vote_label[instance_mask, :]

        pt_diff = pt_offsets - gt_offsets
        pt_dist = torch.sum(torch.abs(pt_diff), dim=-1)
        offset_norm_loss = torch.sum(pt_dist) / (torch.sum(instance_mask) + 1e-6)

        gt_offsets_norm = torch.norm(gt_offsets, p=2, dim=1)  # (N), float
        gt_offsets_ = gt_offsets / (gt_offsets_norm.unsqueeze(-1) + 1e-8)
        pt_offsets_norm = torch.norm(pt_offsets, p=2, dim=1)
        pt_offsets_ = pt_offsets / (pt_offsets_norm.unsqueeze(-1) + 1e-8)
        direction_diff = -(gt_offsets_ * pt_offsets_).sum(-1)  # (N)
        offset_dir_loss = torch.sum(direction_diff) / (torch.sum(instance_mask) + 1e-6)

        return {"offset_norm_loss": offset_norm_loss, "offset_dir_loss": offset_dir_loss}

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        self.loss.backward()

    def _dump_visuals(self, epoch):
        if random.random() < self.opt.vizual_ratio:
            if not hasattr(self, "visual_count"):
                self.visual_count = 0
            data_visual = Data(
                pos=self.raw_pos, y=self.input.y, instance_labels=self.input.instance_labels, batch=self.input.batch
            )
            data_visual.semantic_pred = torch.max(self.output.semantic_logits, -1)[1]
            data_visual.vote = self.output.offset_logits
            if self.output.clusters is not None:
                data_visual.clusters = [c.cpu() for c in self.output.clusters]
                data_visual.cluster_type = self.output.cluster_type
            if not os.path.exists("viz"):
                os.mkdir("viz")
            torch.save(data_visual.to("cpu"), "viz/data_e%i_%i.pt" % (epoch, self.visual_count))
            self.visual_count += 1
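# Worked sketch of _offset_loss above on dummy tensors: an L1 term on the
# offset magnitude plus a negative-cosine term on the offset direction, both
# averaged over in-instance points. Names and shapes here are hypothetical.
import torch

pt_offsets = torch.randn(6, 3)  # predicted offsets to the instance centroid
gt_offsets = torch.randn(6, 3)  # ground-truth offsets
n_valid = torch.tensor(6.0)     # number of in-instance points

offset_norm_loss = torch.sum(torch.abs(pt_offsets - gt_offsets).sum(-1)) / (n_valid + 1e-6)
gt_dir = gt_offsets / (gt_offsets.norm(p=2, dim=1, keepdim=True) + 1e-8)
pt_dir = pt_offsets / (pt_offsets.norm(p=2, dim=1, keepdim=True) + 1e-8)
offset_dir_loss = torch.sum(-(gt_dir * pt_dir).sum(-1)) / (n_valid + 1e-6)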
class SparseConv3D(FragmentBaseModel):
    def __init__(self, option, model_type, dataset, modules):
        FragmentBaseModel.__init__(self, option)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = FragmentBaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None), getattr(option, "miner", None)
        )

        # Unet
        self.backbone = SparseConv3d(
            "unet", dataset.feature_dimension, config=option.backbone, backend=option.get("backend", "minkowski")
        )

        # Last layer
        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    Sequential(
                        *[
                            Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                            FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                            LeakyReLU(0.2),
                        ]
                    )
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def set_input(self, data, device):
        self.input = Batch(pos=data.pos, x=data.x, batch=data.batch).to(device)
        if hasattr(data, "pos_target"):
            self.input_target = Batch(pos=data.pos_target, x=data.x_target, batch=data.batch_target).to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            self.match = None

    def get_batch(self):
        if self.match is not None:
            batch = self.input.batch
            batch_target = self.input_target.batch
            return batch, batch_target
        else:
            return None, None

    def get_input(self):
        if self.match is not None:
            inp = Data(pos=self.input.pos, ind=self.match[:, 0], size=self.size_match)
            inp_target = Data(pos=self.input_target.pos, ind=self.match[:, 1], size=self.size_match)
            return inp, inp_target
        else:
            return self.input

    def apply_nn(self, input):
        out_feat = self.backbone(input).x
        out_feat = self.FC_layer(out_feat)
        if self.normalize_feature:
            return out_feat / (torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        else:
            return out_feat