Example no. 1
class ResNetDown(ME.MinkowskiNetwork):
    """
    ResNet block that looks like

    in --- strided conv ---- Block ---- sum --[... N times]
                         |              |
                         |-- 1x1 - BN --|
    """

    CONVOLUTION = ME.MinkowskiConvolution

    def __init__(self,
                 down_conv_nn=[],
                 kernel_size=2,
                 dilation=1,
                 dimension=3,
                 stride=2,
                 N=1,
                 block="ResBlock",
                 **kwargs):
        block = getattr(_res_blocks, block)
        ME.MinkowskiNetwork.__init__(self, dimension)
        if stride > 1:
            conv1_output = down_conv_nn[0]
        else:
            conv1_output = down_conv_nn[1]

        self.conv_in = (Seq().append(
            self.CONVOLUTION(
                in_channels=down_conv_nn[0],
                out_channels=conv1_output,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                bias=False,
                dimension=dimension,
            )).append(ME.MinkowskiBatchNorm(conv1_output)).append(
                ME.MinkowskiReLU()))

        if N > 0:
            self.blocks = Seq()
            for _ in range(N):
                self.blocks.append(
                    block(conv1_output,
                          down_conv_nn[1],
                          self.CONVOLUTION,
                          dimension=dimension))
                conv1_output = down_conv_nn[1]
        else:
            self.blocks = None

    def forward(self, x):
        out = self.conv_in(x)
        if self.blocks:
            out = self.blocks(out)
        return out
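
The docstring diagram is easier to see in dense PyTorch. Below is a minimal, self-contained sketch of the same pattern with plain nn.Conv3d and hypothetical channel sizes; it is an illustrative translation, not the sparse module above.

# Dense-PyTorch sketch of the ResNetDown pattern (illustrative only; the
# real module operates on Minkowski sparse tensors).
import torch
import torch.nn as nn

class DenseResBlock(nn.Module):
    """Main path: conv-BN-ReLU-conv-BN; shortcut: 1x1 conv + BN; outputs summed."""
    def __init__(self, in_c, out_c):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv3d(in_c, out_c, 3, padding=1, bias=False), nn.BatchNorm3d(out_c),
            nn.ReLU(),
            nn.Conv3d(out_c, out_c, 3, padding=1, bias=False), nn.BatchNorm3d(out_c),
        )
        self.shortcut = nn.Sequential(
            nn.Conv3d(in_c, out_c, 1, bias=False), nn.BatchNorm3d(out_c))

    def forward(self, x):
        return torch.relu(self.main(x) + self.shortcut(x))

down = nn.Sequential(
    nn.Conv3d(16, 32, kernel_size=2, stride=2, bias=False),  # strided conv
    nn.BatchNorm3d(32), nn.ReLU(),
    DenseResBlock(32, 32),  # "Block", repeated N times in the real module
)
print(down(torch.randn(1, 16, 8, 8, 8)).shape)  # torch.Size([1, 32, 4, 4, 4])
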
Example no. 2
class ResNetDown(torch.nn.Module):
    """
    ResNet block that looks like

    in --- strided conv ---- Block ---- sum --[... N times]
                         |              |
                         |-- 1x1 - BN --|
    """

    CONVOLUTION = "Conv3d"

    def __init__(
        self,
        down_conv_nn=[],
        kernel_size=2,
        dilation=1,
        stride=2,
        N=1,
        block="ResBlock",
        **kwargs,
    ):
        block = getattr(_res_blocks, block)
        super().__init__()
        if stride > 1:
            conv1_output = down_conv_nn[0]
        else:
            conv1_output = down_conv_nn[1]

        conv = getattr(snn, self.CONVOLUTION)
        self.conv_in = (Seq().append(
            conv(
                in_channels=down_conv_nn[0],
                out_channels=conv1_output,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
            )).append(snn.BatchNorm(conv1_output)).append(snn.ReLU()))

        if N > 0:
            self.blocks = Seq()
            for _ in range(N):
                self.blocks.append(block(conv1_output, down_conv_nn[1], conv))
                conv1_output = down_conv_nn[1]
        else:
            self.blocks = None

    def forward(self, x):
        out = self.conv_in(x)
        if self.blocks:
            out = self.blocks(out)
        return out
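
A subtlety shared by both ResNetDown variants: the width of the first convolution depends on the stride. A two-line check of that branch, with hypothetical channel values:

down_conv_nn = [32, 64]  # hypothetical [in_channels, out_channels]
for stride in (2, 1):
    conv1_output = down_conv_nn[0] if stride > 1 else down_conv_nn[1]
    print(stride, conv1_output)  # stride 2 -> 32 (the blocks widen to 64); stride 1 -> 64
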
Example no. 3
class MS_SparseConv3d(BaseMS_SparseConv3d):
    def __init__(self, option, model_type, dataset, modules):
        # Last Layer
        BaseMS_SparseConv3d.__init__(self, option, model_type, dataset,
                                     modules)
        option_unet = option.option_unet
        num_scales = option_unet.num_scales
        self.unet = nn.ModuleList()
        for i in range(num_scales):
            module = UnetMSparseConv3d(
                option_unet.backbone,
                input_nc=option_unet.input_nc,
                grid_size=option_unet.grid_size[i],
                pointnet_nn=option_unet.pointnet_nn,
                post_mlp_nn=option_unet.post_mlp_nn,
                pre_mlp_nn=option_unet.pre_mlp_nn,
                add_pos=option_unet.add_pos,
                add_pre_x=option_unet.add_pre_x,
                aggr=option_unet.aggr,
                backend=option.backend,
            )
            self.unet.add_module(name=str(i), module=module)
        # Last MLP layer
        assert option.mlp_cls is not None
        last_mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Sequential(*[
                    Linear(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bias=False),
                    FastBatchNorm1d(last_mlp_opt.nn[i],
                                    momentum=last_mlp_opt.bn_momentum),
                    LeakyReLU(0.2),
                ]))

    def apply_nn(self, input):
        # inputs = self.compute_scales(input)
        outputs = []
        for i in range(len(self.unet)):
            out = self.unet[i](input.clone())
            out.x = out.x / (torch.norm(out.x, p=2, dim=1, keepdim=True) +
                             1e-20)
            outputs.append(out)
        x = torch.cat([o.x for o in outputs], 1)
        out_feat = self.FC_layer(x)
        if self.normalize_feature:
            out_feat = out_feat / (
                torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        return out_feat
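
apply_nn runs one unet per scale, L2-normalizes each per-scale feature matrix, and concatenates along the channel axis before the FC head. A minimal stand-alone sketch of that fusion, with random tensors standing in for the unet outputs:

import torch

scale_feats = [torch.randn(100, 32) for _ in range(3)]  # stand-ins for self.unet[i](...).x
normed = [f / (torch.norm(f, p=2, dim=1, keepdim=True) + 1e-20) for f in scale_feats]
fused = torch.cat(normed, dim=1)
print(fused.shape)  # torch.Size([100, 96]) -> the input width of the first FC layer
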
Example no. 4
class MinkowskiFragment(BaseMinkowski, UnwrappedUnetBasedModel):
    def __init__(self, option, model_type, dataset, modules):
        UnwrappedUnetBasedModel.__init__(self, option, model_type, dataset,
                                         modules)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None),
            getattr(option, "miner", None))
        # Last Layer

        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    str(i),
                    Sequential(*[
                        Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                        FastBatchNorm1d(last_mlp_opt.nn[i],
                                        momentum=last_mlp_opt.bn_momentum),
                        LeakyReLU(0.2),
                    ]),
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def apply_nn(self, input):
        x = input
        stack_down = []
        for i in range(len(self.down_modules) - 1):
            x = self.down_modules[i](x)
            stack_down.append(x)

        x = self.down_modules[-1](x)
        stack_down.append(None)

        for i in range(len(self.up_modules)):
            x = self.up_modules[i](x, stack_down.pop())
        out_feat = self.FC_layer(x.F)
        # out_feat = x.F
        if self.normalize_feature:
            return out_feat / (torch.norm(out_feat, p=2, dim=1, keepdim=True) +
                               1e-20)
        else:
            return out_feat
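
The skip wiring in apply_nn is a plain stack: every down module except the last pushes its output, None is pushed for the innermost level, and each up module pops its partner. A toy trace with string stand-ins for the modules:

stack_down = []
x = "x0"
downs = ["d0", "d1", "d2"]  # stand-ins for self.down_modules
for d in downs[:-1]:
    x = f"{d}({x})"
    stack_down.append(x)    # saved as the skip for the matching up module
x = f"{downs[-1]}({x})"
stack_down.append(None)     # deepest level: the first up module gets no skip

for u in ["u0", "u1", "u2"]:  # stand-ins for self.up_modules
    x = f"{u}({x}, skip={stack_down.pop()})"
print(x)  # u2(u1(u0(d2(d1(d0(x0))), skip=None), skip=d1(d0(x0))), skip=d0(x0))
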
Example no. 5
class PointGroup(BaseModel):
    __REQUIRED_DATA__ = [
        "pos",
    ]

    __REQUIRED_LABELS__ = list(PanopticLabels._fields)

    def __init__(self, option, model_type, dataset, modules):
        super(PointGroup, self).__init__(option)
        backbone_options = option.get("backbone", {"architecture": "unet"})
        self.Backbone = Minkowski(
            backbone_options.architecture,
            input_nc=dataset.feature_dimension,
            num_layers=4,
            config=backbone_options.config,
        )
        self.BackboneHead = Seq().append(FastBatchNorm1d(self.Backbone.output_nc)).append(torch.nn.ReLU())

        self._scorer_is_encoder = option.scorer.architecture == "encoder"
        self._activate_scorer = option.scorer.activate
        self.Scorer = Minkowski(
            option.scorer.architecture, input_nc=self.Backbone.output_nc, num_layers=option.scorer.depth
        )
        self.ScorerHead = Seq().append(torch.nn.Linear(self.Scorer.output_nc, 1)).append(torch.nn.Sigmoid())

        self.Offset = Seq().append(MLP([self.Backbone.output_nc, self.Backbone.output_nc], bias=False))
        self.Offset.append(torch.nn.Linear(self.Backbone.output_nc, 3))

        self.Semantic = (
            Seq()
            .append(MLP([self.Backbone.output_nc, self.Backbone.output_nc], bias=False))
            .append(torch.nn.Linear(self.Backbone.output_nc, dataset.num_classes))
            .append(torch.nn.LogSoftmax(dim=-1))
        )
        self.loss_names = ["loss", "offset_norm_loss", "offset_dir_loss", "semantic_loss", "score_loss"]
        stuff_classes = dataset.stuff_classes
        if is_list(stuff_classes):
            stuff_classes = torch.Tensor(stuff_classes).long()
        self._stuff_classes = torch.cat([torch.tensor([IGNORE_LABEL]), stuff_classes])

    def set_input(self, data, device):
        self.raw_pos = data.pos.to(device)
        self.input = data
        self.labels = data.y.to(device)
        all_labels = {l: data[l].to(device) for l in self.__REQUIRED_LABELS__}
        self.labels = PanopticLabels(**all_labels)

    def forward(self, epoch=-1, **kwargs):
        # Backbone
        backbone_features = self.BackboneHead(self.Backbone(self.input).x)

        # Semantic and offset heads
        semantic_logits = self.Semantic(backbone_features)
        offset_logits = self.Offset(backbone_features)

        # Grouping and scoring
        cluster_scores = None
        all_clusters = None
        cluster_type = None
        if epoch == -1 or epoch > self.opt.prepare_epoch:  # Active by default
            all_clusters, cluster_type = self._cluster(semantic_logits, offset_logits)
            if len(all_clusters):
                cluster_scores = self._compute_score(all_clusters, backbone_features, semantic_logits)

        self.output = PanopticResults(
            semantic_logits=semantic_logits,
            offset_logits=offset_logits,
            clusters=all_clusters,
            cluster_scores=cluster_scores,
            cluster_type=cluster_type,
        )

        # Sets visual data for debugging
        with torch.no_grad():
            self._dump_visuals(epoch)

        # Compute loss
        self._compute_loss()

    def _cluster(self, semantic_logits, offset_logits):
        """ Compute clusters from positions and votes """
        predicted_labels = torch.max(semantic_logits, 1)[1]
        clusters_pos = region_grow(
            self.raw_pos,
            predicted_labels,
            self.input.batch.to(self.device),
            ignore_labels=self._stuff_classes.to(self.device),
            radius=self.opt.cluster_radius_search,
        )
        clusters_votes = region_grow(
            self.raw_pos + offset_logits,
            predicted_labels,
            self.input.batch.to(self.device),
            ignore_labels=self._stuff_classes.to(self.device),
            radius=self.opt.cluster_radius_search,
            nsample=200,
        )

        all_clusters = clusters_pos + clusters_votes
        all_clusters = [c.to(self.device) for c in all_clusters]
        cluster_type = torch.zeros(len(all_clusters), dtype=torch.uint8).to(self.device)
        cluster_type[len(clusters_pos) :] = 1
        return all_clusters, cluster_type

    def _compute_score(self, all_clusters, backbone_features, semantic_logits):
        """ Score the clusters """
        if self._activate_scorer:
            x = []
            coords = []
            batch = []
            for i, cluster in enumerate(all_clusters):
                x.append(backbone_features[cluster])
                coords.append(self.input.coords[cluster])
                batch.append(i * torch.ones(cluster.shape[0]))
            batch_cluster = Data(x=torch.cat(x).cpu(), coords=torch.cat(coords).cpu(), batch=torch.cat(batch).cpu(),)
            score_backbone_out = self.Scorer(batch_cluster)
            if self._scorer_is_encoder:
                cluster_feats = score_backbone_out.x
            else:
                cluster_feats = scatter(
                    score_backbone_out.x, score_backbone_out.batch.long().to(self.device), dim=0, reduce="max"
                )
            cluster_scores = self.ScorerHead(cluster_feats).squeeze(-1)
        else:
            # Use semantic certainty as cluster confidence
            with torch.no_grad():
                cluster_semantic = []
                batch = []
                for i, cluster in enumerate(all_clusters):
                    cluster_semantic.append(semantic_logits[cluster, :])
                    batch.append(i * torch.ones(cluster.shape[0]))
                cluster_semantic = torch.cat(cluster_semantic)
                batch = torch.cat(batch)
                cluster_semantic = scatter(cluster_semantic, batch.long().to(self.device), dim=0, reduce="mean")
                cluster_scores = torch.max(cluster_semantic, 1)[0]
        return cluster_scores

    def _compute_loss(self):
        # Semantic loss
        self.semantic_loss = torch.nn.functional.nll_loss(
            self.output.semantic_logits, self.labels.y, ignore_index=IGNORE_LABEL
        )
        self.loss = self.opt.loss_weights.semantic * self.semantic_loss

        # Offset loss
        self.input.instance_mask = self.input.instance_mask.to(self.device)
        self.input.vote_label = self.input.vote_label.to(self.device)
        offset_losses = offset_loss(
            self.output.offset_logits[self.input.instance_mask],
            self.input.vote_label[self.input.instance_mask],
            torch.sum(self.input.instance_mask),
        )
        for loss_name, loss in offset_losses.items():
            setattr(self, loss_name, loss)
            self.loss += self.opt.loss_weights[loss_name] * loss

        # Score loss
        if self.output.cluster_scores is not None and self._activate_scorer:
            self.score_loss = instance_iou_loss(
                self.output.clusters,
                self.output.cluster_scores,
                self.input.instance_labels.to(self.device),
                self.input.batch.to(self.device),
                min_iou_threshold=self.opt.min_iou_threshold,
                max_iou_threshold=self.opt.max_iou_threshold,
            )
            self.loss += self.score_loss * self.opt.loss_weights["score_loss"]

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        self.loss.backward()

    def _dump_visuals(self, epoch):
        if random.random() < self.opt.vizual_ratio:
            if not hasattr(self, "visual_count"):
                self.visual_count = 0
            data_visual = Data(
                pos=self.raw_pos, y=self.input.y, instance_labels=self.input.instance_labels, batch=self.input.batch
            )
            data_visual.semantic_pred = torch.max(self.output.semantic_logits, -1)[1]
            data_visual.vote = self.output.offset_logits
            nms_idx = self.output.get_instances()
            if self.output.clusters is not None:
                data_visual.clusters = [self.output.clusters[i].cpu() for i in nms_idx]
                data_visual.cluster_type = self.output.cluster_type[nms_idx]
            if not os.path.exists("viz"):
                os.mkdir("viz")
            torch.save(data_visual.to("cpu"), "viz/data_e%i_%i.pt" % (epoch, self.visual_count))
            self.visual_count += 1
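
_cluster grows regions twice: on the raw positions and on positions shifted by the predicted offsets. The offsets are trained via offset_loss above to point toward instance centroids, so the shifted points of one instance collapse together and are easy to group. A toy illustration, assuming a perfect vote prediction:

import torch

# four points forming two instances, with hand-made "perfect" votes
pos = torch.tensor([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [11.0, 0.0]])
centroids = torch.tensor([[0.5, 0.0], [0.5, 0.0], [10.5, 0.0], [10.5, 0.0]])
offset_logits = centroids - pos  # what the Offset head is trained to predict
print(pos + offset_logits)       # two tight groups, at x=0.5 and x=10.5
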
Example no. 6

class BaseMinkowski(BaseModel):
    def __init__(self, option, model_type, dataset, modules):
        BaseModel.__init__(self, option)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None), getattr(option, "miner", None)
        )
        # Last Layer

        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    str(i),
                    Sequential(
                        *[
                            Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                            FastBatchNorm1d(last_mlp_opt.nn[i], momentum=last_mlp_opt.bn_momentum),
                            LeakyReLU(0.2),
                        ]
                    ),
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def set_input(self, data, device):
        coords = torch.cat([data.batch.unsqueeze(-1).int(), data.pos.int()], -1)
        self.input = ME.SparseTensor(data.x, coords=coords).to(device)
        self.xyz = torch.stack((data.pos_x, data.pos_y, data.pos_z), 0).T.to(device)
        if hasattr(data, "pos_target"):
            coords_target = torch.cat([data.batch_target.unsqueeze(-1).int(), data.pos_target.int()], -1)
            self.input_target = ME.SparseTensor(data.x_target, coords=coords_target).to(device)
            self.xyz_target = torch.stack((data.pos_x_target, data.pos_y_target, data.pos_z_target), 0).T.to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            self.match = None

    def compute_loss_match(self):
        self.loss_reg = self.metric_loss_module(
            self.output, self.output_target, self.match[:, :2], self.xyz, self.xyz_target
        )
        self.loss = self.loss_reg

    def compute_loss_label(self):
        """
        Compute the loss, keeping the mining step separate from the loss itself;
        each matched point pair corresponds to one label.
        """
        output = torch.cat([self.output[self.match[:, 0]], self.output_target[self.match[:, 1]]], 0)
        rang = torch.arange(0, len(self.match), dtype=torch.long, device=self.match.device)
        labels = torch.cat([rang, rang], 0)
        hard_pairs = None
        if self.miner_module is not None:
            hard_pairs = self.miner_module(output, labels)
        # loss
        self.loss_reg = self.metric_loss_module(output, labels, hard_pairs)
        self.loss = self.loss_reg

    def apply_nn(self, input):
        raise NotImplementedError("Model still not defined")

    def forward(self):
        self.output = self.apply_nn(self.input)
        if self.match is None:
            return self.output

        self.output_target = self.apply_nn(self.input_target)
        if self.mode == "match":
            self.compute_loss_match()
        elif self.mode == "label":
            self.compute_loss_label()
        else:
            raise NotImplementedError("The mode for the loss is incorrect")

        return self.output

    def backward(self):
        if hasattr(self, "loss"):
            self.loss.backward()

    def get_output(self):
        if self.match is not None:
            return self.output, self.output_target
        else:
            return self.output

    def get_ind(self):
        if self.match is not None:
            return self.match[:, 0], self.match[:, 1], self.size_match
        else:
            return None

    def get_xyz(self):
        if self.match is not None:
            return self.xyz, self.xyz_target
        else:
            return self.xyz

    def get_batch(self):
        if self.match is not None:
            batch = self.input.C[:, 0]
            batch_target = self.input_target.C[:, 0]
            return batch, batch_target
        else:
            return None
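
In compute_loss_label, matched source/target points receive the same synthetic class id, so a standard metric-learning loss treats each correspondence as a positive pair. The labelling reduces to:

import torch

n_matches = 4  # hypothetical number of correspondences
rang = torch.arange(n_matches)
labels = torch.cat([rang, rang], 0)
print(labels)  # tensor([0, 1, 2, 3, 0, 1, 2, 3]): rows i and n_matches + i are positives
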
Example no. 7
class MS_SparseConv3d_Shared(BaseMS_SparseConv3d):
    def __init__(self, option, model_type, dataset, modules):
        BaseMS_SparseConv3d.__init__(self, option, model_type, dataset,
                                     modules)
        option_unet = option.option_unet
        self.grid_size = option_unet.grid_size
        self.unet = UnetMSparseConv3d(
            option_unet.backbone,
            input_nc=option_unet.input_nc,
            pointnet_nn=option_unet.pointnet_nn,
            post_mlp_nn=option_unet.post_mlp_nn,
            pre_mlp_nn=option_unet.pre_mlp_nn,
            add_pos=option_unet.add_pos,
            add_pre_x=option_unet.add_pre_x,
            aggr=option_unet.aggr,
            backend=option.backend,
        )
        assert option.mlp_cls is not None
        last_mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Sequential(*[
                    Linear(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bias=False),
                    FastBatchNorm1d(last_mlp_opt.nn[i],
                                    momentum=last_mlp_opt.bn_momentum),
                    LeakyReLU(0.2),
                ]))

        # Intermediate loss
        if option.intermediate_loss is not None:
            int_loss_option = option.intermediate_loss
            self.int_metric_loss, _ = FragmentBaseModel.get_metric_loss_and_miner(
                getattr(int_loss_option, "metric_loss", None),
                getattr(int_loss_option, "miner", None))
            self.int_weights = int_loss_option.weights
            for i in range(len(int_loss_option.weights)):
                self.loss_names += ["loss_intermediate_loss_{}".format(i)]
        else:
            self.int_metric_loss = None

    def compute_intermediate_loss(self, outputs, outputs_target):
        assert len(outputs) == len(outputs_target)
        if self.int_metric_loss is not None:
            assert len(outputs) == len(self.int_weights)
            for i, w in enumerate(self.int_weights):
                xyz = self.input.pos
                xyz_target = self.input_target.pos
                loss_i = self.int_metric_loss(outputs[i].x,
                                              outputs_target[i].x,
                                              self.match[:, :2], xyz,
                                              xyz_target)
                self.loss += w * loss_i
                setattr(self, "loss_intermediate_loss_{}".format(i), loss_i)

    def apply_nn(self, input):
        # inputs = self.compute_scales(input)
        outputs = []
        for i in range(len(self.grid_size)):
            self.unet.set_grid_size(self.grid_size[i])
            out = self.unet(input.clone())
            out.x = out.x / (torch.norm(out.x, p=2, dim=1, keepdim=True) +
                             1e-20)
            outputs.append(out)
        x = torch.cat([o.x for o in outputs], 1)
        out_feat = self.FC_layer(x)
        if self.normalize_feature:
            out_feat = out_feat / (
                torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        return out_feat, outputs

    def forward(self, *args, **kwargs):
        self.output, outputs = self.apply_nn(self.input)
        if self.match is None:
            return self.output

        self.output_target, outputs_target = self.apply_nn(self.input_target)
        self.compute_loss()

        self.compute_intermediate_loss(outputs, outputs_target)

        return self.output
Example no. 8
class MS_SparseConvModel(APIModel):
    def __init__(self, option, model_type, dataset, modules):
        BaseModel.__init__(self, option)
        option_unet = option.option_unet
        self.normalize_feature = option.normalize_feature
        self.grid_size = option_unet.grid_size
        self.unet = UnetMSparseConv3d(
            option_unet.backbone,
            input_nc=option_unet.input_nc,
            pointnet_nn=option_unet.pointnet_nn,
            post_mlp_nn=option_unet.post_mlp_nn,
            pre_mlp_nn=option_unet.pre_mlp_nn,
            add_pos=option_unet.add_pos,
            add_pre_x=option_unet.add_pre_x,
            aggr=option_unet.aggr,
            backend=option.backend,
        )
        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls

            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    nn.Sequential(*[
                        nn.Linear(last_mlp_opt.nn[i - 1],
                                  last_mlp_opt.nn[i],
                                  bias=False),
                        FastBatchNorm1d(last_mlp_opt.nn[i],
                                        momentum=last_mlp_opt.bn_momentum),
                        nn.LeakyReLU(0.2),
                    ]))
            if last_mlp_opt.dropout:
                self.FC_layer.append(nn.Dropout(p=last_mlp_opt.dropout))
        else:
            self.FC_layer = torch.nn.Identity()
        self.head = nn.Sequential(
            nn.Linear(option.output_nc, dataset.num_classes))
        self.loss_names = ["loss_seg"]

    def apply_nn(self, input):

        outputs = []
        for i in range(len(self.grid_size)):
            self.unet.set_grid_size(self.grid_size[i])
            out = self.unet(input.clone())
            out.x = out.x / (torch.norm(out.x, p=2, dim=1, keepdim=True) +
                             1e-20)
            outputs.append(out)
        x = torch.cat([o.x for o in outputs], 1)
        out_feat = self.FC_layer(x)
        if self.normalize_feature:
            out_feat = out_feat / (
                torch.norm(out_feat, p=2, dim=1, keepdim=True) + 1e-20)
        out_feat = self.head(out_feat)
        return out_feat, outputs

    def forward(self, *args, **kwargs):
        logits, _ = self.apply_nn(self.input)
        self.output = F.log_softmax(logits, dim=-1)
        if self.labels is not None:
            self.loss_seg = F.nll_loss(self.output,
                                       self.labels,
                                       ignore_index=IGNORE_LABEL)

    def backward(self):
        self.loss_seg.backward()
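
forward pairs F.log_softmax with F.nll_loss, which together compute cross-entropy; a quick numerical sanity check:

import torch
import torch.nn.functional as F

logits = torch.randn(5, 3)
y = torch.tensor([0, 2, 1, 0, 2])
assert torch.allclose(F.nll_loss(F.log_softmax(logits, dim=-1), y),
                      F.cross_entropy(logits, y))
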
Example no. 9
class BaseMinkowski(FragmentBaseModel):
    def __init__(self, option, model_type, dataset, modules):
        FragmentBaseModel.__init__(self, option)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = FragmentBaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None),
            getattr(option, "miner", None))
        # Last Layer

        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    str(i),
                    Sequential(*[
                        Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                        FastBatchNorm1d(last_mlp_opt.nn[i],
                                        momentum=last_mlp_opt.bn_momentum),
                        LeakyReLU(0.2),
                    ]),
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def set_input(self, data, device):
        coords = torch.cat([data.batch.unsqueeze(-1).int(),
                            data.pos.int()], -1)
        self.input = ME.SparseTensor(data.x, coords=coords).to(device)
        self.xyz = torch.stack((data.pos_x, data.pos_y, data.pos_z),
                               0).T.to(device)
        if hasattr(data, "pos_target"):
            coords_target = torch.cat(
                [data.batch_target.unsqueeze(-1).int(),
                 data.pos_target.int()], -1)
            self.input_target = ME.SparseTensor(
                data.x_target, coords=coords_target).to(device)
            self.xyz_target = torch.stack(
                (data.pos_x_target, data.pos_y_target, data.pos_z_target),
                0).T.to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            self.match = None

    def get_batch(self):
        if self.match is not None:
            batch = self.input.C[:, 0]
            batch_target = self.input_target.C[:, 0]
            return batch, batch_target
        else:
            return None, None

    def get_input(self):
        if self.match is not None:
            input = Data(pos=self.xyz,
                         ind=self.match[:, 0],
                         size=self.size_match)
            input_target = Data(pos=self.xyz_target,
                                ind=self.match[:, 1],
                                size=self.size_match)
            return input, input_target
        else:
            input = Data(pos=self.xyz)
            return input, None
Example no. 10
class PointGroup(BaseModel):
    __REQUIRED_DATA__ = [
        "pos",
    ]

    __REQUIRED_LABELS__ = list(PanopticLabels._fields)

    def __init__(self, option, model_type, dataset, modules):
        super(PointGroup, self).__init__(option)
        self.Backbone = Minkowski("unet",
                                  input_nc=dataset.feature_dimension,
                                  num_layers=4)

        self._scorer_is_encoder = option.scorer.architecture == "encoder"
        self.Scorer = Minkowski(option.scorer.architecture,
                                input_nc=self.Backbone.output_nc,
                                num_layers=2)
        self.ScorerHead = Seq().append(
            torch.nn.Linear(self.Scorer.output_nc,
                            1)).append(torch.nn.Sigmoid())

        self.Offset = Seq().append(
            MLP([self.Backbone.output_nc, self.Backbone.output_nc],
                bias=False))
        self.Offset.append(torch.nn.Linear(self.Backbone.output_nc, 3))

        self.Semantic = (Seq().append(
            torch.nn.Linear(self.Backbone.output_nc,
                            dataset.num_classes)).append(
                                torch.nn.LogSoftmax(dim=-1)))
        self.loss_names = [
            "loss", "offset_norm_loss", "offset_dir_loss", "semantic_loss",
            "score_loss"
        ]
        self._stuff_classes = torch.cat(
            [torch.tensor([IGNORE_LABEL]), dataset.stuff_classes])

    def set_input(self, data, device):
        self.raw_pos = data.pos.to(device)
        self.input = data
        self.labels = data.y.to(device)
        all_labels = {l: data[l].to(device) for l in self.__REQUIRED_LABELS__}
        self.labels = PanopticLabels(**all_labels)

    def forward(self, epoch=-1, **kwargs):
        # Backbone
        backbone_features = self.Backbone(self.input).x

        # Semantic and offset heads
        semantic_logits = self.Semantic(backbone_features)
        offset_logits = self.Offset(backbone_features)

        # Grouping and scoring
        cluster_scores = None
        all_clusters = None
        cluster_type = None
        if epoch == -1 or epoch > self.opt.prepare_epoch:  # Active by default
            predicted_labels = torch.max(semantic_logits, 1)[1]
            clusters_pos = region_grow(
                self.raw_pos.cpu(),
                predicted_labels.cpu(),
                self.input.batch.cpu(),
                ignore_labels=self._stuff_classes.cpu(),
                radius=self.opt.cluster_radius_search,
            )
            clusters_votes = region_grow(
                self.raw_pos.cpu() + offset_logits.cpu(),
                predicted_labels.cpu(),
                self.input.batch.cpu(),
                ignore_labels=self._stuff_classes.cpu(),
                radius=self.opt.cluster_radius_search,
            )

            all_clusters = clusters_pos + clusters_votes
            all_clusters = [c.to(self.device) for c in all_clusters]
            cluster_type = torch.zeros(len(all_clusters),
                                       dtype=torch.uint8).to(self.device)
            cluster_type[len(clusters_pos):] = 1

            if len(all_clusters):
                x = []
                coords = []
                batch = []
                for i, cluster in enumerate(all_clusters):
                    x.append(backbone_features[cluster])
                    coords.append(self.input.coords[cluster])
                    batch.append(i * torch.ones(cluster.shape[0]))
                batch_cluster = Data(x=torch.cat(x).cpu(),
                                     coords=torch.cat(coords).cpu(),
                                     batch=torch.cat(batch).cpu())
                score_backbone_out = self.Scorer(batch_cluster)
                if self._scorer_is_encoder:
                    cluster_feats = score_backbone_out.x
                else:
                    # torch_scatter's scatter_max returns (values, argmax);
                    # keep the values and make the index tensor long
                    cluster_feats = scatter_max(score_backbone_out.x,
                                                score_backbone_out.batch.long(),
                                                dim=0)[0]
                cluster_scores = self.ScorerHead(cluster_feats).squeeze(-1)

        self.output = PanopticResults(
            semantic_logits=semantic_logits,
            offset_logits=offset_logits,
            clusters=all_clusters,
            cluster_scores=cluster_scores,
            cluster_type=cluster_type,
        )

        # Sets visual data for debugging
        self._dump_visuals(epoch)

        # Compute loss
        self._compute_loss()

    def _compute_loss(self):
        # Semantic loss
        self.semantic_loss = torch.nn.functional.nll_loss(
            self.output.semantic_logits,
            self.labels.y,
            ignore_index=IGNORE_LABEL)
        self.loss = self.opt.loss_weights.semantic * self.semantic_loss

        # Offset loss
        offset_losses = self._offset_loss(self.labels, self.output)

        # Score loss
        if self.output.cluster_scores is not None:
            ious = instance_iou(self.output.clusters,
                                self.labels.instance_labels.to(self.device),
                                self.input.batch.to(self.device)).max(1)[0]
            lower_mask = ious < self.opt.min_iou_threshold
            higher_mask = ious > self.opt.max_iou_threshold
            middle_mask = torch.logical_and(torch.logical_not(lower_mask),
                                            torch.logical_not(higher_mask))
            assert torch.sum(lower_mask + higher_mask +
                             middle_mask) == ious.shape[0]
            shat = torch.zeros_like(ious)
            iou_middle = ious[middle_mask]
            shat[higher_mask] = 1
            shat[middle_mask] = (iou_middle - self.opt.min_iou_threshold) / (
                self.opt.max_iou_threshold - self.opt.min_iou_threshold)
            self.score_loss = torch.nn.functional.binary_cross_entropy(
                self.output.cluster_scores, shat)
            self.loss += self.score_loss * self.opt.loss_weights["score_loss"]

        for loss_name, loss in offset_losses.items():
            setattr(self, loss_name, loss)
            self.loss += self.opt.loss_weights[loss_name] * loss

    @staticmethod
    def _offset_loss(data_labels: PanopticLabels, result: PanopticResults):
        instance_mask = data_labels.instance_mask
        pt_offsets = result.offset_logits[instance_mask, :]

        gt_offsets = data_labels.vote_label[instance_mask, :]
        pt_diff = pt_offsets - gt_offsets
        pt_dist = torch.sum(torch.abs(pt_diff), dim=-1)
        offset_norm_loss = torch.sum(pt_dist) / (torch.sum(instance_mask) +
                                                 1e-6)

        gt_offsets_norm = torch.norm(gt_offsets, p=2, dim=1)  # (N), float
        gt_offsets_ = gt_offsets / (gt_offsets_norm.unsqueeze(-1) + 1e-8)
        pt_offsets_norm = torch.norm(pt_offsets, p=2, dim=1)
        pt_offsets_ = pt_offsets / (pt_offsets_norm.unsqueeze(-1) + 1e-8)
        direction_diff = -(gt_offsets_ * pt_offsets_).sum(-1)  # (N)
        offset_dir_loss = torch.sum(direction_diff) / (
            torch.sum(instance_mask) + 1e-6)

        return {
            "offset_norm_loss": offset_norm_loss,
            "offset_dir_loss": offset_dir_loss
        }

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        self.loss.backward()

    def _dump_visuals(self, epoch):
        if random.random() < self.opt.vizual_ratio:
            if not hasattr(self, "visual_count"):
                self.visual_count = 0
            data_visual = Data(pos=self.raw_pos,
                               y=self.input.y,
                               instance_labels=self.input.instance_labels,
                               batch=self.input.batch)
            data_visual.semantic_pred = torch.max(self.output.semantic_logits,
                                                  -1)[1]
            data_visual.vote = self.output.offset_logits
            if self.output.clusters is not None:
                data_visual.clusters = [c.cpu() for c in self.output.clusters]
                data_visual.cluster_type = self.output.cluster_type

            if not os.path.exists("viz"):
                os.mkdir("viz")
            torch.save(data_visual.to("cpu"),
                       "viz/data_e%i_%i.pt" % (epoch, self.visual_count))
            self.visual_count += 1
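
The score target shat built in _compute_loss is a linear ramp of each cluster's best IoU, clamped between the two thresholds: 0 below min_iou_threshold, 1 above max_iou_threshold. An equivalent compact form, with hypothetical threshold values:

import torch

min_iou, max_iou = 0.25, 0.75  # hypothetical option values
ious = torch.tensor([0.10, 0.40, 0.90])
shat = ((ious - min_iou) / (max_iou - min_iou)).clamp(0.0, 1.0)
print(shat)  # tensor([0.0000, 0.3000, 1.0000]) -> BCE targets for the cluster scores
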
Example no. 11
class SparseConv3D(FragmentBaseModel):
    def __init__(self, option, model_type, dataset, modules):
        FragmentBaseModel.__init__(self, option)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = FragmentBaseModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None),
            getattr(option, "miner", None))
        # Unet
        self.backbone = SparseConv3d("unet",
                                     dataset.feature_dimension,
                                     config=option.backbone,
                                     backend=option.get(
                                         "backend", "minkowski"))
        # Last Layer
        if option.mlp_cls is not None:
            last_mlp_opt = option.mlp_cls
            in_feat = last_mlp_opt.nn[0]
            self.FC_layer = Seq()
            for i in range(1, len(last_mlp_opt.nn)):
                self.FC_layer.append(
                    str(i),
                    Sequential(*[
                        Linear(in_feat, last_mlp_opt.nn[i], bias=False),
                        FastBatchNorm1d(last_mlp_opt.nn[i],
                                        momentum=last_mlp_opt.bn_momentum),
                        LeakyReLU(0.2),
                    ]),
                )
                in_feat = last_mlp_opt.nn[i]

            if last_mlp_opt.dropout:
                self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

            self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
        else:
            self.FC_layer = torch.nn.Identity()

    def set_input(self, data, device):
        self.input = Batch(pos=data.pos, x=data.x, batch=data.batch).to(device)
        if hasattr(data, "pos_target"):
            self.input_target = Batch(pos=data.pos_target,
                                      x=data.x_target,
                                      batch=data.batch_target).to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            # no target fragment, hence no correspondences
            self.match = None

    def get_batch(self):
        if self.match is not None:
            batch = self.input.batch
            batch_target = self.input_target.batch
            return batch, batch_target
        else:
            return None, None

    def get_input(self):
        if self.match is not None:
            inp = Data(pos=self.input.pos,
                       ind=self.match[:, 0],
                       size=self.size_match)
            inp_target = Data(pos=self.input_target.pos,
                              ind=self.match[:, 1],
                              size=self.size_match)
            return inp, inp_target
        else:
            return self.input

    def apply_nn(self, input):

        out_feat = self.backbone(input).x
        out_feat = self.FC_layer(out_feat)
        if self.normalize_feature:
            return out_feat / (torch.norm(out_feat, p=2, dim=1, keepdim=True) +
                               1e-20)
        else:
            return out_feat
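
The same epsilon-guarded row normalization closes almost every apply_nn in these examples; in isolation it is just:

import torch

feats = torch.randn(8, 32)
unit = feats / (torch.norm(feats, p=2, dim=1, keepdim=True) + 1e-20)
print(unit.norm(dim=1))  # ~1.0 per row; the 1e-20 guards all-zero rows
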