예제 #1
0
    def __init__(self, option, model_type, dataset, modules):
        """Build a UNet backbone followed by a Conv1D segmentation head.

        Parameters
        ----------
        option:
            Model options. ``option.mlp_cls`` gives the head's channel sizes
            (``nn``) and ``dropout``; ``option.use_category`` (default False)
            appends a one-hot category encoding to the head's input.
        model_type, dataset, modules:
            Forwarded to ``UnetBasedModel.__init__``. ``dataset`` must also
            expose ``num_classes``, ``weight_classes`` and, when categories
            are used, ``class_to_segments``.

        Raises
        ------
        ValueError
            If ``use_category`` is set but the dataset has no
            ``class_to_segments`` mapping.
        """
        import copy  # local import keeps this block self-contained

        # call the initialization method of UnetBasedModel
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)
        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError(
                    "The dataset needs to specify a class_to_segments property when using category information for segmentation"
                )
            self._num_categories = len(dataset.class_to_segments.keys())
            log.info("Using category information for the predictions with %i categories", self._num_categories)
        else:
            self._num_categories = 0

        # Last MLP. Work on a copy: widening nn[0] in place would mutate the
        # shared option object and add the categories again on every
        # construction from the same config.
        last_mlp_opt = copy.deepcopy(option.mlp_cls)

        self.FC_layer = Seq()
        # The one-hot category vector is concatenated to the backbone features.
        last_mlp_opt.nn[0] += self._num_categories
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))

        # Final projection to per-point class scores (no activation, no batch-norm).
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self._num_classes, activation=None, bias=True, bn=False))
        self.loss_names = ["loss_seg"]

        self.visual_names = ["data_visual"]
예제 #2
0
class PatchSiamese(BackboneBasedModel):
    def __init__(self, option, model_type, dataset, modules):
        """
        Initialize this model class
        Parameters:
            opt -- training/test options
        A few things can be done here.
        - (required) call the initialization function of BaseModel
        - define loss function, visualization images, model names, and optimizers
        """

        BackboneBasedModel.__init__(self, option, model_type, dataset, modules)
        self.set_last_mlp(option.mlp_cls)
        self.loss_names = ["loss_reg"]

    def set_input(self, data, device):
        data = data.to(device)
        self.input = data
        # TODO multiscale data pre_computed...
        if isinstance(data, MultiScaleBatch):
            self.pre_computed = data.multiscale
            del data.multiscale
        else:
            self.pre_computed = None
        # batch siamese
        self.batch_idx = create_batch_siamese(data.pair, data.batch)

    def set_last_mlp(self, last_mlp_opt):
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Conv1D(last_mlp_opt.nn[i - 1],
                       last_mlp_opt.nn[i],
                       bn=True,
                       bias=False))

    def set_loss(self):
        raise NotImplementedError("Choose a loss for the metric learning")

    def forward(self) -> Any:
        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""

        data = self.input
        for i in range(len(self.down_modules)):
            data = self.down_modules[i](data, precomputed=self.pre_computed)

        x = F.relu(self.lin1(data.x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        self.output = self.lin2(x)

        self.loss_reg = self.loss_module(
            self.output) + self.get_internal_loss()
        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # caculate the intermediate results if necessary; here self.output has been computed during function <forward>
        # calculate loss given the input and intermediate results
        self.loss_reg.backward(
        )  # calculate gradients of network G w.r.t. loss_G
예제 #3
0
    def __init__(self, option, model_type, dataset, modules):
        """Build the UNet backbone, the metric-learning loss and the descriptor head.

        Parameters
        ----------
        option:
            Model options. Reads ``loss_mode``, ``normalize_feature``,
            ``out_channels``, ``mlp_cls`` (``nn`` channel sizes and
            ``dropout``) and, if present, ``metric_loss`` and ``miner``.
        model_type, dataset, modules:
            Forwarded to ``UnetBasedModel.__init__``.
        """
        # call the initialization method of UnetBasedModel
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.out_channels = option.out_channels
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = UnetBasedModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None),
            getattr(option, "miner", None))

        # Last MLP: Conv1D + batch-norm between consecutive sizes of
        # ``mlp_cls.nn``, optional dropout, then a projection to
        # ``out_channels`` with no activation.
        last_mlp_opt = option.mlp_cls

        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Conv1D(last_mlp_opt.nn[i - 1],
                       last_mlp_opt.nn[i],
                       bn=True,
                       bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))

        self.FC_layer.append(
            Conv1D(last_mlp_opt.nn[-1],
                   self.out_channels,
                   activation=None,
                   bias=True,
                   bn=False))
예제 #4
0
 def set_last_mlp(self, last_mlp_opt):
     """Rebuild ``self.FC_layer`` as a stack of Conv1D layers (batch-norm, no bias).

     One layer is added per consecutive pair of channel sizes in
     ``last_mlp_opt.nn``; none is added when it has fewer than two entries.
     """
     channels = last_mlp_opt.nn
     head = Seq()
     self.FC_layer = head
     for idx in range(1, len(channels)):
         head.append(Conv1D(channels[idx - 1], channels[idx], bn=True, bias=False))
예제 #5
0
    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
        """Initialise the dense RSConv backbone and, optionally, a projection head.

        Keyword arguments:
            default_output_nc -- backbone output width (defaults to 384)
            output_nc -- if given, a Conv1D head (batch-norm, no bias) mapping
                         ``default_output_nc`` to ``output_nc`` is created as
                         ``self.mlp``
        """
        super(RSConvBase, self).__init__(model_config, model_type, dataset, modules)

        default_output_nc = kwargs.get("default_output_nc", 384)
        self._has_mlp_head = "output_nc" in kwargs
        self._output_nc = kwargs["output_nc"] if self._has_mlp_head else default_output_nc
        if self._has_mlp_head:
            self.mlp = Seq()
            self.mlp.append(Conv1D(default_output_nc, self._output_nc, bn=True, bias=False))
예제 #6
0
    def __init__(self, option, model_type, dataset, modules):
        """Initialise the backbone and the Conv1D descriptor head.

        ``option.mlp_cls.nn`` lists the head's channel sizes; its last entry
        becomes the descriptor dimension ``self._dim_output``.
        """
        BackboneBasedModel.__init__(self, option, model_type, dataset, modules)

        channels = option.mlp_cls.nn
        self._dim_output = channels[-1]

        self.FC_layer = Seq()
        for idx in range(1, len(channels)):
            in_nc, out_nc = channels[idx - 1], channels[idx]
            self.FC_layer.append(Conv1D(in_nc, out_nc, bn=True, bias=False))

        self.loss_names = ["loss_patch_desc"]
예제 #7
0
    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
        """Initialise the dense PointNet++ backbone and, optionally, a projection head.

        The backbone output width is read from ``model_config`` via
        ``extract_output_nc``; if that fails it falls back to -1 and a warning
        is logged. Passing ``output_nc`` as a keyword argument creates
        ``self.mlp``, a Conv1D (batch-norm, no bias) from that width to
        ``output_nc``.
        """
        super(BasePointnet2, self).__init__(model_config, model_type, dataset, modules)

        try:
            default_output_nc = extract_output_nc(model_config)
        except Exception:
            # Best effort fallback; unlike a bare ``except:`` this no longer
            # swallows KeyboardInterrupt / SystemExit.
            default_output_nc = -1
            log.warning("Could not resolve number of output channels")

        self._has_mlp_head = False
        self._output_nc = default_output_nc
        if "output_nc" in kwargs:
            self._has_mlp_head = True
            self._output_nc = kwargs["output_nc"]
            self.mlp = Seq()
            self.mlp.append(Conv1D(default_output_nc, self._output_nc, bn=True, bias=False))
예제 #8
0
class RSConvBase(UnwrappedUnetBasedModel):
    """Dense RSConv backbone with an optional Conv1D projection head.

    Passing ``output_nc`` as a keyword argument adds ``self.mlp``, a single
    Conv1D (batch-norm, no bias) from ``default_output_nc`` channels (384
    unless overridden via kwargs) to ``output_nc`` channels.
    """

    CONV_TYPE = "dense"

    def __init__(self, model_config, model_type, dataset, modules, *args,
                 **kwargs):
        super(RSConvBase, self).__init__(model_config, model_type, dataset,
                                         modules)

        default_output_nc = kwargs.get("default_output_nc", 384)
        self._has_mlp_head = "output_nc" in kwargs
        self._output_nc = (kwargs["output_nc"] if self._has_mlp_head
                           else default_output_nc)
        if self._has_mlp_head:
            self.mlp = Seq()
            self.mlp.append(
                Conv1D(default_output_nc, self._output_nc, bn=True,
                       bias=False))

    @property
    def has_mlp_head(self):
        """True when an ``output_nc`` projection head (``self.mlp``) exists."""
        return self._has_mlp_head

    @property
    def output_nc(self):
        """Output width: ``output_nc`` if a head was requested, else the default."""
        return self._output_nc

    def _set_input(self, data):
        """Move ``data`` to ``self.device`` and make its features channel-first.

        ``data.pos`` must be 3-D (asserted), presumably [B, N, 3]. When
        present, ``data.x`` has its last two axes swapped (presumably
        [B, N, C] -> [B, C, N]). The batch is stored as ``self.input``.
        """
        assert len(data.pos.shape) == 3
        data = data.to(self.device)
        data.x = None if data.x is None else data.x.transpose(1, 2).contiguous()
        self.input = data
예제 #9
0
class BasePointnet2(UnwrappedUnetBasedModel):
    """Dense PointNet++ backbone with an optional Conv1D projection head."""

    CONV_TYPE = "dense"

    def __init__(self, model_config, model_type, dataset, modules, *args,
                 **kwargs):
        """Initialise the backbone and, if ``output_nc`` is given, ``self.mlp``.

        The backbone output width is read from ``model_config`` via
        ``extract_output_nc``; on failure it falls back to -1 and a warning
        is logged.
        """
        super(BasePointnet2, self).__init__(model_config, model_type, dataset,
                                            modules)

        try:
            default_output_nc = extract_output_nc(model_config)
        except Exception:
            # Best effort fallback; unlike a bare ``except:`` this no longer
            # swallows KeyboardInterrupt / SystemExit.
            default_output_nc = -1
            log.warning("Could not resolve number of output channels")

        self._has_mlp_head = False
        self._output_nc = default_output_nc
        if "output_nc" in kwargs:
            self._has_mlp_head = True
            self._output_nc = kwargs["output_nc"]
            self.mlp = Seq()
            self.mlp.append(
                Conv1D(default_output_nc, self._output_nc, bn=True,
                       bias=False))

    @property
    def has_mlp_head(self):
        """True when an ``output_nc`` projection head (``self.mlp``) exists."""
        return self._has_mlp_head

    @property
    def output_nc(self):
        """Output width: ``output_nc`` if a head was requested, else the default."""
        return self._output_nc

    def _set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        ``data.pos`` must be 3-D (asserted). When present, ``data.x`` has its
        last two axes swapped to make it channel-first. The batch is moved to
        ``self.device`` and stored as ``self.input``.
        """
        assert len(data.pos.shape) == 3
        data = data.to(self.device)
        data.x = data.x.transpose(1, 2).contiguous() if data.x is not None else None
        self.input = data
예제 #10
0
class SiamesePointNet2_D(BackboneBasedModel):
    r"""
    PointNet2 with multi-scale grouping
    metric learning siamese network that uses feature propagation layers
    """

    def __init__(self, option, model_type, dataset, modules):
        """Build the backbone and the Conv1D descriptor head.

        ``option.mlp_cls.nn`` lists the head's channel sizes; its last entry
        is the descriptor dimension ``self._dim_output``.
        """
        BackboneBasedModel.__init__(self, option, model_type, dataset, modules)

        # Last MLP
        last_mlp_opt = option.mlp_cls
        self._dim_output = last_mlp_opt.nn[-1]

        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))

        self.loss_names = ["loss_patch_desc"]

    def set_input(self, data, device):
        """Move ``data`` to ``device`` and store it as ``self.input``.

        ``data.pos`` must be 3-D (asserted). Features are optional: when
        present, their last two axes are swapped to make them channel-first.
        """
        assert len(data.pos.shape) == 3
        data = data.to(device)
        # Transposing a missing feature tensor would raise; keep it None instead.
        x = data.x.transpose(1, 2).contiguous() if data.x is not None else None
        self.input = Data(x=x, pos=data.pos)

    def forward(self):
        r"""
        forward pass of the network

        Computes per-point descriptors of size ``self._dim_output`` and the
        metric loss; returns the descriptors.
        """
        data = self.input
        for i in range(len(self.down_modules)):
            data = self.down_modules[i](data)
        last_feature = data.x
        self.output = self.FC_layer(last_feature).transpose(1, 2).contiguous().view((-1, self._dim_output))

        self.loss_reg = self.loss_module(self.output) + self.get_internal_loss()
        # ``loss_names`` declares "loss_patch_desc"; expose the loss under that
        # name too so it is actually reported.
        self.loss_patch_desc = self.loss_reg
        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # self.loss_reg has been computed during function <forward>
        self.loss_reg.backward()  # calculate gradients of the network w.r.t. loss_reg
예제 #11
0
 def __init__(self, params):
     """Build the backbone and a single Conv1D classification layer.

     The classifier maps ``self._model_opt.output_nc`` channels (presumably
     set by ``_build_backbone``, together with ``self._model``) to
     ``params.data.number`` outputs, with bias, no batch-norm and no
     activation. The assembled model is printed.
     """
     super().__init__()
     self._params = params
     self._build_backbone()
     classifier = Seq()
     self._model["classifier"] = classifier
     classifier.append(
         Conv1D(
             self._model_opt.output_nc,
             self._params.data.number,
             activation=None,
             bias=True,
             bn=False,
         )
     )
     print(self._model)
예제 #12
0
 def __init__(self, params, num_classes):
     """Build the backbone from ``params`` plus a Conv1D classifier.

     ``build_backbone(params)`` returns the backbone module, its options and
     its name. The classifier maps ``output_nc`` channels to ``num_classes``
     with bias, no batch-norm and no activation.
     """
     super(Net, self).__init__()
     self._model = nn.ModuleDict()
     backbone, model_opt, backbone_name = self.build_backbone(params)
     self._model["backbone"] = backbone
     self._model_opt = model_opt
     self._backbone_name = backbone_name
     head = Seq()
     self._model["classifier"] = head
     head.append(Conv1D(model_opt.output_nc, num_classes, activation=None, bias=True, bn=False))
예제 #13
0
    def __init__(self, option, model_type, dataset, modules):
        """Build the PointNet++ backbone and the Conv1D segmentation head.

        Parameters
        ----------
        option:
            Model options. ``option.mlp_cls`` gives the head's channel sizes
            (``nn``) and ``dropout``; ``option.use_category`` (default False)
            appends a one-hot category encoding to the head's input.
        model_type, dataset, modules:
            Forwarded to ``UnetBasedModel.__init__``. ``dataset`` must also
            expose ``num_classes``, ``weight_classes`` and, when categories
            are used, ``class_to_segments``.

        Raises
        ------
        ValueError
            If categories are requested but the dataset lacks
            ``class_to_segments``.
        """
        # Pointnet++ is UnetBased model, call init method of unet model
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)

        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)

        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError("Dataset does not specify needed "
                                 "class_to_segments property")
            self._num_categories = len(dataset.class_to_segments.keys())
            # Lazy %-formatting; the previous f-string printed a stray "$".
            log.info("Using category information for "
                     "the predictions with %i categories",
                     self._num_categories)
        else:
            self._num_categories = 0
            log.info("Category information is not going to be used")

        # ---------------------------------------------------
        # Specification of last MLP based on
        # mlp_cls opt in "mypointnet2" in "pointnet2.yaml".
        # Deep copy so widening nn[0] below does not mutate the shared option.
        last_mlp_opt = copy.deepcopy(option.mlp_cls)

        # A sequential container. Modules will be added to
        # it in the order they are passed in the constructor
        # (Torch classic method)
        self.FC_layer = Seq()
        # The one-hot category vector is concatenated to the backbone features.
        last_mlp_opt.nn[0] += self._num_categories

        # Adding layers specified in pointnet2.yaml - mlp_cls
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Conv1D(last_mlp_opt.nn[i - 1],
                       last_mlp_opt.nn[i],
                       bn=True,
                       bias=False))

        # Specify dropout of last FC layer (mlp_cls)
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))

        # Final projection to per-point class scores (no activation).
        self.FC_layer.append(
            Conv1D(last_mlp_opt.nn[-1],
                   self._num_classes,
                   activation=None,
                   bias=True,
                   bn=False))
        # -------------------------------------------------------------------

        # Name specs.
        self.loss_names = ["loss_seg"]
        self.visual_names = ["data_visual"]

        self.input = None
        self.labels = None
        self.batch_idx = None
        self.category = None

        self.loss_seg = None
        self.data_visual = None
예제 #14
0
class MyPointNet2(UnetBasedModel):
    """PointNet++ segmentation model: UNet backbone plus a Conv1D head.

    Optionally conditions the head on a one-hot encoding of the object
    category (``option.use_category``).
    """

    def __init__(self, option, model_type, dataset, modules):
        """Build the backbone and the segmentation head.

        Parameters
        ----------
        option:
            Model options. ``option.mlp_cls`` gives the head's channel sizes
            (``nn``) and ``dropout``; ``option.use_category`` (default False)
            enables category conditioning.
        model_type, dataset, modules:
            Forwarded to ``UnetBasedModel.__init__``. ``dataset`` must also
            expose ``num_classes``, ``weight_classes`` and, when categories
            are used, ``class_to_segments``.

        Raises
        ------
        ValueError
            If categories are requested but the dataset lacks
            ``class_to_segments``.
        """
        # Pointnet++ is UnetBased model, call init method of unet model
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)

        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)

        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError("Dataset does not specify needed "
                                 "class_to_segments property")
            self._num_categories = len(dataset.class_to_segments.keys())
            # Lazy %-formatting; the previous f-string printed a stray "$".
            log.info("Using category information for "
                     "the predictions with %i categories",
                     self._num_categories)
        else:
            self._num_categories = 0
            log.info("Category information is not going to be used")

        # ---------------------------------------------------
        # Specification of last MLP based on
        # mlp_cls opt in "mypointnet2" in "pointnet2.yaml".
        # Deep copy so widening nn[0] below does not mutate the shared option.
        last_mlp_opt = copy.deepcopy(option.mlp_cls)

        self.FC_layer = Seq()
        # The one-hot category vector is concatenated to the backbone features.
        last_mlp_opt.nn[0] += self._num_categories

        # Adding layers specified in pointnet2.yaml - mlp_cls
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Conv1D(last_mlp_opt.nn[i - 1],
                       last_mlp_opt.nn[i],
                       bn=True,
                       bias=False))

        # Specify dropout of last FC layer (mlp_cls)
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))

        # Final projection to per-point class scores (no activation).
        self.FC_layer.append(
            Conv1D(last_mlp_opt.nn[-1],
                   self._num_classes,
                   activation=None,
                   bias=True,
                   bn=False))
        # -------------------------------------------------------------------

        # Name specs.
        self.loss_names = ["loss_seg"]
        self.visual_names = ["data_visual"]

        self.input = None
        self.labels = None
        self.batch_idx = None
        self.category = None

        self.loss_seg = None
        self.data_visual = None

    def set_input(self, data, device):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            data: batch with ``pos`` (3-D, indexed [B, N, ...]) and optional
                  ``x``, ``y`` and (when categories are used) ``category``.
            device: target device for the batch.
        Sets:
            self.input: Data(x=channel-first features or None, pos=data.pos)
            self.labels: flattened labels [B * N], or None
            self.batch_idx: batch index of every point [B * N]
            self.category: ``data.category`` or None

        Raises:
            ValueError: if ``data.pos`` is not 3-D.
        """
        if len(data.pos.shape) != 3:
            raise ValueError(
                f"Position data shape should be 3, {len(data.pos.shape)} - given!"
            )
        data = data.to(device)

        # Features are optional; the old debug print of data.x.size() crashed
        # before this None check could run.
        x = data.x.transpose(1, 2).contiguous() if data.x is not None else None

        self.input = Data(x=x, pos=data.pos)
        self.labels = (torch.flatten(data.y).long()
                       if data.y is not None else None)  # [B * N]
        self.batch_idx = torch.arange(0, data.pos.shape[0]).view(-1, 1).repeat(
            1, data.pos.shape[1]).view(-1)
        # None (not Ellipsis) when categories are unused, matching __init__.
        self.category = data.category if self._use_category else None

    def forward(self, *args, **kwargs):
        r"""
            Forward pass of the network
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]

            Returns per-point class scores of shape [B * N, num_classes] and,
            when labels are available, sets ``self.loss_seg``.
        """
        data = self.model(self.input)
        last_feature = data.x

        if self._use_category:
            # one-hot encode the category and append it as extra channels
            cat_one_hot = F.one_hot(self.category,
                                    self._num_categories).float().transpose(
                                        1, 2)
            last_feature = torch.cat((last_feature, cat_one_hot), dim=1)

        self.output = self.FC_layer(last_feature).transpose(
            1, 2).contiguous().view((-1, self._num_classes))

        if self._weight_classes is not None:
            self._weight_classes = self._weight_classes.to(self.output.device)

        # Compute loss Cross Entropy
        if self.labels is not None:
            self.loss_seg = F.cross_entropy(self.output,
                                            self.labels,
                                            weight=self._weight_classes,
                                            ignore_index=IGNORE_LABEL)

        self.data_visual = self.input
        if self.labels is not None:
            # Labels are optional (e.g. at inference); reshaping None would raise.
            self.data_visual.y = torch.reshape(self.labels, data.pos.shape[0:2])
        self.data_visual.pred = torch.max(self.output,
                                          -1)[1].reshape(data.pos.shape[0:2])

        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # self.loss_seg has been computed during function <forward>
        self.loss_seg.backward()
예제 #15
0
class FragmentPointNet2_D(UnetBasedModel, FragmentBaseModel):
    r"""
        PointNet2 with multi-scale grouping
        descriptors network for registration that uses feature propagation layers

        A UNet backbone followed by a Conv1D head produces per-point
        descriptors of size ``option.out_channels``. They are L2-normalised
        when ``option.normalize_feature`` is set.
    """
    def __init__(self, option, model_type, dataset, modules):
        """Build the backbone, the metric-learning loss and the descriptor head.

        Parameters
        ----------
        option:
            Model options. Reads ``loss_mode``, ``normalize_feature``,
            ``out_channels``, ``mlp_cls`` (``nn`` channel sizes and
            ``dropout``) and, if present, ``metric_loss`` and ``miner``.
        model_type, dataset, modules:
            Forwarded to ``UnetBasedModel.__init__``.
        """
        # call the initialization method of UnetBasedModel
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.out_channels = option.out_channels
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = UnetBasedModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None),
            getattr(option, "miner", None))

        # Last MLP
        last_mlp_opt = option.mlp_cls

        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Conv1D(last_mlp_opt.nn[i - 1],
                       last_mlp_opt.nn[i],
                       bn=True,
                       bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))

        # Final projection to the descriptor size (no activation).
        self.FC_layer.append(
            Conv1D(last_mlp_opt.nn[-1],
                   self.out_channels,
                   activation=None,
                   bias=True,
                   bn=False))

    def set_input(self, data, device):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Sets:
            self.input: Data(x=channel-first features or None, pos=data.pos)
            self.input_target, self.match, self.size_match: only when the
                batch carries a target fragment (``pos_target``); otherwise
                ``self.match`` is None.
        """
        assert len(data.pos.shape) == 3

        x = data.x.transpose(1, 2).contiguous() if data.x is not None else None
        self.input = Data(x=x, pos=data.pos).to(device)

        if hasattr(data, "pos_target"):
            # Target features are optional; tolerate batches without x_target.
            x_target = getattr(data, "x_target", None)
            if x_target is not None:
                x_target = x_target.transpose(1, 2).contiguous()
            self.input_target = Data(x=x_target, pos=data.pos_target).to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            self.match = None

    def apply_nn(self, input):
        """Return per-point descriptors [num_points, out_channels] for ``input``."""
        last_feature = self.model(input).x
        output = self.FC_layer(last_feature).transpose(1, 2).contiguous().view(
            (-1, self.out_channels))
        if self.normalize_feature:
            # small epsilon avoids division by zero for null descriptors
            return output / (torch.norm(output, p=2, dim=1, keepdim=True) +
                             1e-5)
        return output

    def get_input(self):
        """Return flattened source (and target, when matched) position batches."""
        if self.match is not None:
            source = Data(pos=self.input.pos.view(-1, 3),
                          ind=self.match[:, 0],
                          size=self.size_match)
            target = Data(pos=self.input_target.pos.view(-1, 3),
                          ind=self.match[:, 1],
                          size=self.size_match)
            return source, target
        return Data(pos=self.input.pos.view(-1, 3))

    @staticmethod
    def _batch_vector(pos, device):
        """Batch index of every point of ``pos`` (indexed [B, N, ...]), length B * N."""
        num_batches, num_points = pos.shape[0], pos.shape[1]
        return torch.arange(0, num_batches).view(-1, 1).repeat(
            1, num_points).view(-1).to(device)

    def get_batch(self):
        """Return (source, target) batch vectors, or (None, None) without a match."""
        if self.match is None:
            return None, None
        device = self.input.pos.device
        return (self._batch_vector(self.input.pos, device),
                self._batch_vector(self.input_target.pos, device))
예제 #16
0
class PointNet2_D(UnetBasedModel):
    r"""
        PointNet2 with multi-scale grouping
        Semantic segmentation network that uses feature propagation layers

        A UNet backbone is followed by a Conv1D head producing per-point
        class scores. When ``option.use_category`` is set, a one-hot encoding
        of ``data.category`` is appended to the head's input.
    """
    def __init__(self, option, model_type, dataset, modules):
        """Build the backbone and the segmentation head.

        Parameters
        ----------
        option:
            Model options. ``option.mlp_cls`` gives the head's channel sizes
            (``nn``) and ``dropout``; ``option.use_category`` (default False)
            enables category conditioning.
        model_type, dataset, modules:
            Forwarded to ``UnetBasedModel.__init__``. ``dataset`` must also
            expose ``num_classes``, ``weight_classes`` and, when categories
            are used, ``class_to_segments``.

        Raises
        ------
        ValueError
            If categories are requested but the dataset lacks
            ``class_to_segments``.
        """
        import copy  # local import keeps this block self-contained

        # call the initialization method of UnetBasedModel
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)
        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError(
                    "The dataset needs to specify a class_to_segments property when using category information for segmentation"
                )
            self._num_categories = len(dataset.class_to_segments.keys())
            log.info(
                "Using category information for the predictions with %i categories",
                self._num_categories)
        else:
            self._num_categories = 0

        # Last MLP. Work on a copy: widening nn[0] in place would mutate the
        # shared option object and add the categories again on every
        # construction from the same config.
        last_mlp_opt = copy.deepcopy(option.mlp_cls)

        self.FC_layer = Seq()
        last_mlp_opt.nn[0] += self._num_categories
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Conv1D(last_mlp_opt.nn[i - 1],
                       last_mlp_opt.nn[i],
                       bn=True,
                       bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))

        # Final projection to per-point class scores (no activation).
        self.FC_layer.append(
            Conv1D(last_mlp_opt.nn[-1],
                   self._num_classes,
                   activation=None,
                   bias=True,
                   bn=False))
        self.loss_names = ["loss_seg"]

        self.visual_names = ["data_visual"]

    def set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        The batch is moved to ``self.device``.
        Sets:
            self.input: Data(x=channel-first features or None, pos=data.pos)
            self.labels: flattened labels [B * N], or None
            self.batch_idx: batch index of every point [B * N]
            self.category: ``data.category`` (only when categories are used)
        """
        assert len(data.pos.shape) == 3
        device = self.device
        data = data.to(device)
        x = data.x.transpose(1, 2).contiguous() if data.x is not None else None
        self.input = Data(x=x, pos=data.pos)
        self.labels = torch.flatten(data.y).long() if data.y is not None else None  # [B * N]
        self.batch_idx = torch.arange(0, data.pos.shape[0]).view(-1, 1).repeat(
            1, data.pos.shape[1]).view(-1)
        if self._use_category:
            self.category = data.category

    def forward(self, *args, **kwargs):
        r"""
            Forward pass of the network on ``kwargs['data']``.
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]

            Returns per-point class scores of shape [B * N, num_classes] and,
            when labels are available, sets ``self.loss_seg``.
        """
        self.set_input(kwargs['data'])
        data = self.model(self.input)
        last_feature = data.x
        if self._use_category:
            cat_one_hot = F.one_hot(self.category,
                                    self._num_categories).float().transpose(
                                        1, 2)
            last_feature = torch.cat((last_feature, cat_one_hot), dim=1)

        self.output = self.FC_layer(last_feature).transpose(
            1, 2).contiguous().view((-1, self._num_classes))

        if self._weight_classes is not None:
            self._weight_classes = self._weight_classes.to(self.output.device)
        if self.labels is not None:
            self.loss_seg = F.cross_entropy(self.output,
                                            self.labels,
                                            weight=self._weight_classes,
                                            ignore_index=IGNORE_LABEL)

        self.data_visual = self.input
        if self.labels is not None:
            # Labels are optional (e.g. at inference); reshaping None would raise.
            self.data_visual.y = torch.reshape(self.labels, data.pos.shape[0:2])
        self.data_visual.pred = torch.max(self.output,
                                          -1)[1].reshape(data.pos.shape[0:2])
        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # self.loss_seg has been computed during function <forward>
        self.loss_seg.backward()
예제 #17
0
class RSConvLogicModel(UnwrappedUnetBasedModel):
    """RSConv segmentation model with two global inner branches and a Conv1D head."""

    def __init__(self, option, model_type, dataset, modules):
        """Build the unwrapped UNet modules and the segmentation head.

        Parameters
        ----------
        option:
            Model options. ``option.mlp_cls`` gives the head's channel sizes
            (``nn``) and ``dropout``; ``option.use_category`` (default False)
            appends a one-hot category encoding to the head's input.
        model_type, dataset, modules:
            Forwarded to ``UnwrappedUnetBasedModel.__init__``.

        Raises
        ------
        ValueError
            If categories are requested but the dataset lacks
            ``class_to_segments``.
        """
        import copy  # local import keeps this block self-contained

        # call the initialization method of UnwrappedUnetBasedModel
        UnwrappedUnetBasedModel.__init__(self, option, model_type, dataset,
                                         modules)
        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)
        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError(
                    "The dataset needs to specify a class_to_segments property when using category information for segmentation"
                )
            self._num_categories = len(dataset.class_to_segments.keys())
            log.info(
                "Using category information for the predictions with %i categories",
                self._num_categories)
        else:
            self._num_categories = 0

        # Last MLP. Work on a copy: widening nn[0] in place would mutate the
        # shared option object and add the categories again on every
        # construction from the same config.
        last_mlp_opt = copy.deepcopy(option.mlp_cls)

        self.FC_layer = Seq()
        last_mlp_opt.nn[0] += self._num_categories
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(
                Conv1D(last_mlp_opt.nn[i - 1],
                       last_mlp_opt.nn[i],
                       bn=True,
                       bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))

        # Final projection to per-point class scores (no activation).
        self.FC_layer.append(
            Conv1D(last_mlp_opt.nn[-1],
                   self._num_classes,
                   activation=None,
                   bias=True,
                   bn=False))
        self.loss_names = ["loss_seg"]

        self.visual_names = ["data_visual"]

    def set_input(self, data, device):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            data: batch with ``pos`` indexed [B, N, ...] and optional ``x``,
                  ``y`` and (when categories are used) ``category``.
            device: target device for the batch.
        Sets:
            self.input:
                x -- Features transposed to [B, C, N] (when present)
                pos -- Points [B, N, 3]
            self.labels: flattened labels [B * N], or None
            self.batch_idx: batch index of every point [B * N]
        """
        data = data.to(device)
        if data.x is not None:
            data.x = data.x.transpose(1, 2).contiguous()
        self.input = data
        if data.y is not None:
            self.labels = torch.flatten(data.y).long()  # [B * N]
        else:
            self.labels = data.y
        self.batch_idx = torch.arange(0, data.pos.shape[0]).view(-1, 1).repeat(
            1, data.pos.shape[1]).view(-1)
        if self._use_category:
            self.category = data.category

    def forward(self, *args, **kwargs):
        r"""
            Forward pass of the network
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]

            Returns per-point class scores of shape [B * N, num_classes] and,
            when labels are available, sets ``self.loss_seg``.
        """
        stack_down = []
        queue_up = queue.Queue()

        data = self.input
        stack_down.append(data)

        for i in range(len(self.down_modules) - 1):
            data = self.down_modules[i](data)
            stack_down.append(data)

        data = self.down_modules[-1](data)
        queue_up.put(data)

        assert len(
            self.inner_modules
        ) == 2, "For this segmentation model, we except 2 distinct inner"
        data_inner = self.inner_modules[0](data)
        # NOTE(review): hard-coded skip index 3 assumes at least four entries
        # in stack_down (input + three down levels) — confirm against config.
        data_inner_2 = self.inner_modules[1](stack_down[3])

        for i in range(len(self.up_modules) - 1):
            data = self.up_modules[i]((queue_up.get(), stack_down.pop()))
            queue_up.put(data)

        # Broadcast both global inner features along the point axis and
        # concatenate them with the decoder output.
        last_feature = torch.cat([
            data.x,
            data_inner.x.repeat(1, 1, data.x.shape[-1]),
            data_inner_2.x.repeat(1, 1, data.x.shape[-1])
        ],
                                 dim=1)

        if self._use_category:
            cat_one_hot = F.one_hot(self.category,
                                    self._num_categories).float().transpose(
                                        1, 2)
            last_feature = torch.cat((last_feature, cat_one_hot), dim=1)

        self.output = self.FC_layer(last_feature).transpose(
            1, 2).contiguous().view((-1, self._num_classes))

        # Compute loss
        if self._weight_classes is not None:
            self._weight_classes = self._weight_classes.to(self.output.device)

        if self.labels is not None:
            self.loss_seg = F.cross_entropy(self.output,
                                            self.labels,
                                            weight=self._weight_classes)

        self.data_visual = self.input
        if self.labels is not None:
            # Labels are optional (e.g. at inference); reshaping None would raise.
            self.data_visual.y = torch.reshape(self.labels, data.pos.shape[0:2])
        self.data_visual.pred = torch.max(self.output,
                                          -1)[1].reshape(data.pos.shape[0:2])

        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # self.loss_seg has been computed during function <forward>
        self.loss_seg.backward()