def __init__(self, option, model_type, dataset, modules):
    # call the initialization method of UnetBasedModel
    UnetBasedModel.__init__(self, option, model_type, dataset, modules)
    self._num_classes = dataset.num_classes
    self._weight_classes = dataset.weight_classes
    self._use_category = getattr(option, "use_category", False)
    if self._use_category:
        if not dataset.class_to_segments:
            raise ValueError(
                "The dataset needs to specify a class_to_segments property when using category information for segmentation"
            )
        self._num_categories = len(dataset.class_to_segments.keys())
        log.info("Using category information for the predictions with %i categories", self._num_categories)
    else:
        self._num_categories = 0

    # Last MLP
    last_mlp_opt = option.mlp_cls
    self.FC_layer = Seq()
    last_mlp_opt.nn[0] += self._num_categories
    for i in range(1, len(last_mlp_opt.nn)):
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
    if last_mlp_opt.dropout:
        self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))
    self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self._num_classes, activation=None, bias=True, bn=False))
    self.loss_names = ["loss_seg"]
    self.visual_names = ["data_visual"]
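# A minimal sketch of what the FC_layer head above expands to, assuming
# Conv1D(in, out, bn=True, bias=False) wraps nn.Conv1d(kernel_size=1) +
# BatchNorm1d + ReLU, and the final layer (activation=None, bn=False) is a
# plain 1x1 convolution. The config values (nn=[128, 128], dropout=0.5,
# num_classes=13) are hypothetical.
import torch
from torch import nn

nn_cfg, dropout, num_classes = [128, 128], 0.5, 13
layers = []
for i in range(1, len(nn_cfg)):
    layers += [
        nn.Conv1d(nn_cfg[i - 1], nn_cfg[i], kernel_size=1, bias=False),
        nn.BatchNorm1d(nn_cfg[i]),
        nn.ReLU(),
    ]
layers += [nn.Dropout(p=dropout), nn.Conv1d(nn_cfg[-1], num_classes, kernel_size=1, bias=True)]
head = nn.Sequential(*layers)
print(head(torch.randn(2, 128, 1024)).shape)  # -> torch.Size([2, 13, 1024])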
class PatchSiamese(BackboneBasedModel):
    def __init__(self, option, model_type, dataset, modules):
        """Initialize this model class

        Parameters:
            option -- training/test options

        A few things can be done here:
        - (required) call the initialization function of BaseModel
        - define loss function, visualization images, model names, and optimizers
        """
        BackboneBasedModel.__init__(self, option, model_type, dataset, modules)
        self.set_last_mlp(option.mlp_cls)
        self.loss_names = ["loss_reg"]

    def set_input(self, data, device):
        data = data.to(device)
        self.input = data
        # TODO multiscale data pre_computed...
        if isinstance(data, MultiScaleBatch):
            self.pre_computed = data.multiscale
            del data.multiscale
        else:
            self.pre_computed = None

        # batch siamese
        self.batch_idx = create_batch_siamese(data.pair, data.batch)

    def set_last_mlp(self, last_mlp_opt):
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))

    def set_loss(self):
        raise NotImplementedError("Choose a loss for the metric learning")

    def forward(self) -> Any:
        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
        data = self.input
        for i in range(len(self.down_modules)):
            data = self.down_modules[i](data, precomputed=self.pre_computed)

        # lin1, lin2 and dropout are assumed to be defined by the concrete configuration
        x = F.relu(self.lin1(data.x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        self.output = self.lin2(x)

        self.loss_reg = self.loss_module(self.output) + self.get_internal_loss()
        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # calculate the intermediate results if necessary; here self.output has been computed during <forward>
        # calculate loss given the input and intermediate results
        self.loss_reg.backward()  # calculate gradients of the network w.r.t. loss_reg
def __init__(self, option, model_type, dataset, modules):
    # call the initialization method of UnetBasedModel
    UnetBasedModel.__init__(self, option, model_type, dataset, modules)
    # Last MLP
    self.mode = option.loss_mode
    self.normalize_feature = option.normalize_feature
    self.out_channels = option.out_channels
    self.loss_names = ["loss_reg", "loss"]
    self.metric_loss_module, self.miner_module = UnetBasedModel.get_metric_loss_and_miner(
        getattr(option, "metric_loss", None), getattr(option, "miner", None)
    )

    last_mlp_opt = option.mlp_cls
    self.FC_layer = Seq()
    for i in range(1, len(last_mlp_opt.nn)):
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
    if last_mlp_opt.dropout:
        self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))
    self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self.out_channels, activation=None, bias=True, bn=False))
def set_last_mlp(self, last_mlp_opt):
    self.FC_layer = Seq()
    for i in range(1, len(last_mlp_opt.nn)):
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
    super(RSConvBase, self).__init__(model_config, model_type, dataset, modules)

    default_output_nc = kwargs.get("default_output_nc", 384)
    self._has_mlp_head = False
    self._output_nc = default_output_nc
    if "output_nc" in kwargs:
        self._has_mlp_head = True
        self._output_nc = kwargs["output_nc"]
        self.mlp = Seq()
        self.mlp.append(Conv1D(default_output_nc, self._output_nc, bn=True, bias=False))
def __init__(self, option, model_type, dataset, modules):
    BackboneBasedModel.__init__(self, option, model_type, dataset, modules)

    # Last MLP
    last_mlp_opt = option.mlp_cls
    self._dim_output = last_mlp_opt.nn[-1]
    self.FC_layer = Seq()
    for i in range(1, len(last_mlp_opt.nn)):
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
    self.loss_names = ["loss_patch_desc"]
def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
    super(BasePointnet2, self).__init__(model_config, model_type, dataset, modules)

    try:
        default_output_nc = extract_output_nc(model_config)
    except Exception:
        default_output_nc = -1
        log.warning("Could not resolve number of output channels")

    self._has_mlp_head = False
    self._output_nc = default_output_nc
    if "output_nc" in kwargs:
        self._has_mlp_head = True
        self._output_nc = kwargs["output_nc"]
        self.mlp = Seq()
        self.mlp.append(Conv1D(default_output_nc, self._output_nc, bn=True, bias=False))
class RSConvBase(UnwrappedUnetBasedModel):
    CONV_TYPE = "dense"

    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
        super(RSConvBase, self).__init__(model_config, model_type, dataset, modules)

        default_output_nc = kwargs.get("default_output_nc", 384)
        self._has_mlp_head = False
        self._output_nc = default_output_nc
        if "output_nc" in kwargs:
            self._has_mlp_head = True
            self._output_nc = kwargs["output_nc"]
            self.mlp = Seq()
            self.mlp.append(Conv1D(default_output_nc, self._output_nc, bn=True, bias=False))

    @property
    def has_mlp_head(self):
        return self._has_mlp_head

    @property
    def output_nc(self):
        return self._output_nc

    def _set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.

        Sets:
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]
        """
        assert len(data.pos.shape) == 3
        data = data.to(self.device)
        if data.x is not None:
            data.x = data.x.transpose(1, 2).contiguous()
        else:
            data.x = None
        self.input = data
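# A minimal sketch of the layout change `_set_input` performs above: dense
# models expect channels-first features, so per-point features loaded as
# [B, N, C] are transposed to [B, C, N]. `.contiguous()` re-packs memory
# after the transpose so later `.view(...)` calls are valid. Shapes are
# illustrative.
import torch

B, N, C = 4, 1024, 6
x = torch.randn(B, N, C)                  # features as loaded: one row per point
x_dense = x.transpose(1, 2).contiguous()  # [B, C, N] for Conv1d-style modules
print(x_dense.shape)                      # -> torch.Size([4, 6, 1024])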
class BasePointnet2(UnwrappedUnetBasedModel):
    CONV_TYPE = "dense"

    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
        super(BasePointnet2, self).__init__(model_config, model_type, dataset, modules)

        try:
            default_output_nc = extract_output_nc(model_config)
        except Exception:
            default_output_nc = -1
            log.warning("Could not resolve number of output channels")

        self._has_mlp_head = False
        self._output_nc = default_output_nc
        if "output_nc" in kwargs:
            self._has_mlp_head = True
            self._output_nc = kwargs["output_nc"]
            self.mlp = Seq()
            self.mlp.append(Conv1D(default_output_nc, self._output_nc, bn=True, bias=False))

    @property
    def has_mlp_head(self):
        return self._has_mlp_head

    @property
    def output_nc(self):
        return self._output_nc

    def _set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps."""
        assert len(data.pos.shape) == 3
        data = data.to(self.device)
        if data.x is not None:
            data.x = data.x.transpose(1, 2).contiguous()
        else:
            data.x = None
        self.input = data
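# A self-contained sketch of the optional-head pattern used by RSConvBase and
# BasePointnet2 above, written in plain torch.nn (the real Seq/Conv1D modules
# come from the repo's dense modules; everything here is illustrative): when a
# caller passes output_nc, an extra 1x1-conv MLP projects the backbone's
# default feature width to the requested one.
import torch
from torch import nn

class TinyBackbone(nn.Module):
    def __init__(self, default_output_nc=384, **kwargs):
        super().__init__()
        self._has_mlp_head = "output_nc" in kwargs
        self._output_nc = kwargs.get("output_nc", default_output_nc)
        if self._has_mlp_head:
            # mirrors Conv1D(default_output_nc, output_nc, bn=True, bias=False)
            self.mlp = nn.Sequential(
                nn.Conv1d(default_output_nc, self._output_nc, 1, bias=False),
                nn.BatchNorm1d(self._output_nc),
                nn.ReLU(),
            )

    @property
    def output_nc(self):
        return self._output_nc

print(TinyBackbone().output_nc)              # 384: raw backbone width
print(TinyBackbone(output_nc=64).output_nc)  # 64: projected by the extra head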
class SiamesePointNet2_D(BackboneBasedModel):
    r"""
    PointNet2 with multi-scale grouping metric learning siamese network
    that uses feature propagation layers
    """

    def __init__(self, option, model_type, dataset, modules):
        BackboneBasedModel.__init__(self, option, model_type, dataset, modules)

        # Last MLP
        last_mlp_opt = option.mlp_cls
        self._dim_output = last_mlp_opt.nn[-1]
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
        self.loss_names = ["loss_patch_desc"]

    def set_input(self, data, device):
        assert len(data.pos.shape) == 3
        data = data.to(device)
        self.input = Data(x=data.x.transpose(1, 2).contiguous(), pos=data.pos)

    def forward(self):
        r"""Forward pass of the network"""
        data = self.input
        for i in range(len(self.down_modules)):
            data = self.down_modules[i](data)

        last_feature = data.x
        self.output = self.FC_layer(last_feature).transpose(1, 2).contiguous().view((-1, self._dim_output))
        self.loss_reg = self.loss_module(self.output) + self.get_internal_loss()

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # calculate the intermediate results if necessary; here self.output has been computed during <forward>
        # calculate loss given the input and intermediate results
        self.loss_reg.backward()  # calculate gradients of the network w.r.t. loss_reg
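# A minimal sketch of the output reshape used in `forward` above: the head
# emits descriptors channels-first as [B, D, N]; transpose + view flattens
# them to one row per point, [B * N, D], the layout the metric-learning
# losses consume. Shapes are illustrative.
import torch

B, D, N = 2, 32, 5
out = torch.randn(B, D, N)
flat = out.transpose(1, 2).contiguous().view(-1, D)
print(flat.shape)  # -> torch.Size([10, 32]); row b * N + n is point n of sample b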
def __init__(self, params):
    super().__init__()
    self._params = params
    self._build_backbone()
    self._model["classifier"] = Seq()
    self._model["classifier"].append(
        Conv1D(self._model_opt.output_nc, self._params.data.number, activation=None, bias=True, bn=False)
    )
    print(self._model)
def __init__(self, params, num_classes):
    super(Net, self).__init__()
    self._model = nn.ModuleDict()
    self._model["backbone"], self._model_opt, self._backbone_name = self.build_backbone(params)
    self._model["classifier"] = Seq()
    self._model["classifier"].append(
        Conv1D(self._model_opt.output_nc, num_classes, activation=None, bias=True, bn=False)
    )
def __init__(self, option, model_type, dataset, modules):
    # PointNet++ is a Unet-based model, so call the init method of the Unet model
    UnetBasedModel.__init__(self, option, model_type, dataset, modules)
    self._num_classes = dataset.num_classes
    self._weight_classes = dataset.weight_classes
    self._use_category = getattr(option, "use_category", False)
    if self._use_category:
        if not dataset.class_to_segments:
            raise ValueError("Dataset does not specify the required class_to_segments property")
        self._num_categories = len(dataset.class_to_segments.keys())
        log.info("Using category information for the predictions with %i categories", self._num_categories)
    else:
        self._num_categories = 0
        log.info("Category information is not going to be used")

    # ---------------------------------------------------
    # Specification of the last MLP, based on the
    # mlp_cls opt in "mypointnet2" in "pointnet2.yaml"
    last_mlp_opt = copy.deepcopy(option.mlp_cls)

    # A sequential container. Modules will be added to
    # it in the order they are passed in the constructor
    # (classic Torch pattern)
    self.FC_layer = Seq()
    last_mlp_opt.nn[0] += self._num_categories

    # Adding layers specified in pointnet2.yaml - mlp_cls
    for i in range(1, len(last_mlp_opt.nn)):
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))

    # Specify dropout of the last FC layer (mlp_cls)
    if last_mlp_opt.dropout:
        self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))
    self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self._num_classes, activation=None, bias=True, bn=False))

    # -------------------------------------------------------------------
    # Name specs.
    self.loss_names = ["loss_seg"]
    self.visual_names = ["data_visual"]
    self.input = None
    self.labels = None
    self.batch_idx = None
    self.category = None
    self.loss_seg = None
    self.data_visual = None
class MyPointNet2(UnetBasedModel):
    def __init__(self, option, model_type, dataset, modules):
        # PointNet++ is a Unet-based model, so call the init method of the Unet model
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)
        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError("Dataset does not specify the required class_to_segments property")
            self._num_categories = len(dataset.class_to_segments.keys())
            log.info("Using category information for the predictions with %i categories", self._num_categories)
        else:
            self._num_categories = 0
            log.info("Category information is not going to be used")

        # ---------------------------------------------------
        # Specification of the last MLP, based on the
        # mlp_cls opt in "mypointnet2" in "pointnet2.yaml"
        last_mlp_opt = copy.deepcopy(option.mlp_cls)

        # A sequential container. Modules will be added to
        # it in the order they are passed in the constructor
        # (classic Torch pattern)
        self.FC_layer = Seq()
        last_mlp_opt.nn[0] += self._num_categories

        # Adding layers specified in pointnet2.yaml - mlp_cls
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))

        # Specify dropout of the last FC layer (mlp_cls)
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self._num_classes, activation=None, bias=True, bn=False))

        # -------------------------------------------------------------------
        # Name specs.
        self.loss_names = ["loss_seg"]
        self.visual_names = ["data_visual"]
        self.input = None
        self.labels = None
        self.batch_idx = None
        self.category = None
        self.loss_seg = None
        self.data_visual = None

    def set_input(self, data, device):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.

        Sets:
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]
        """
        print("debug y:", data.y)
        print("debug pos:", data.pos)
        if len(data.pos.shape) != 3:
            raise ValueError(f"Position data should have 3 dimensions, {len(data.pos.shape)} given!")
        data = data.to(device)
        print("debug x:", data.x.size())
        x = data.x.transpose(1, 2).contiguous() if (data.x is not None) else None
        self.input = Data(x=x, pos=data.pos)
        self.labels = torch.flatten(data.y).long() if (data.y is not None) else None  # [B * N]
        self.batch_idx = torch.arange(0, data.pos.shape[0]).view(-1, 1).repeat(1, data.pos.shape[1]).view(-1)
        self.category = data.category if self._use_category else ...
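# A hedged sketch of the mlp_cls block the constructor above reads from
# "pointnet2.yaml". The exact keys of the real config may differ; only `nn`
# (layer widths) and `dropout` are consumed by this code. OmegaConf is used
# here because the repo's option objects behave like OmegaConf nodes; the
# values are hypothetical.
from omegaconf import OmegaConf

option = OmegaConf.create(
    {
        "use_category": False,
        "mlp_cls": {
            "nn": [128, 128],  # widths: nn[0] is widened by _num_categories, then Conv1D blocks follow
            "dropout": 0.5,    # a falsy value (0 / None) skips the Dropout layer
        },
    }
)
print(option.mlp_cls.nn, option.mlp_cls.dropout)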
def forward(self, *args, **kwargs):
    r"""Forward pass of the network

    self.input:
        x -- Features [B, C, N]
        pos -- Points [B, N, 3]
    """
    data = self.model(self.input)
    last_feature = data.x
    if self._use_category:
        # expand categorical labels into one-hot columns
        cat_one_hot = F.one_hot(self.category, self._num_categories).float().transpose(1, 2)
        # concatenate the tensors along the channel dimension
        last_feature = torch.cat((last_feature, cat_one_hot), dim=1)
    self.output = self.FC_layer(last_feature).transpose(1, 2).contiguous().view((-1, self._num_classes))

    if self._weight_classes is not None:
        self._weight_classes = self._weight_classes.to(self.output.device)

    # Compute the cross-entropy loss
    if self.labels is not None:
        self.loss_seg = F.cross_entropy(
            self.output, self.labels, weight=self._weight_classes, ignore_index=IGNORE_LABEL
        )

    self.data_visual = self.input
    self.data_visual.y = torch.reshape(self.labels, data.pos.shape[0:2])
    self.data_visual.pred = torch.max(self.output, -1)[1].reshape(data.pos.shape[0:2])
    return self.output

def backward(self):
    """Calculate losses, gradients, and update network weights; called in every training iteration"""
    # calculate the intermediate results if necessary; here self.output has been computed during <forward>
    # calculate loss given the input and intermediate results
    self.loss_seg.backward()
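# A minimal sketch of the category injection above: per-point category ids
# [B, N] become a one-hot tensor [B, N, K], are transposed to channels-first
# [B, K, N], and are concatenated onto the backbone features along dim=1, so
# the classifier head sees C + K input channels. Values are illustrative.
import torch
import torch.nn.functional as F

B, C, N, K = 2, 128, 6, 4
last_feature = torch.randn(B, C, N)
category = torch.randint(0, K, (B, N))
cat_one_hot = F.one_hot(category, K).float().transpose(1, 2)  # [B, K, N]
fused = torch.cat((last_feature, cat_one_hot), dim=1)
print(fused.shape)  # -> torch.Size([2, 132, 6])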
class FragmentPointNet2_D(UnetBasedModel, FragmentBaseModel):
    r"""
    PointNet2 with multi-scale grouping descriptor network for registration
    that uses feature propagation layers

    Parameters
    ----------
    num_classes: int
        Number of semantic classes to predict over -- size of the softmax classifier
        that runs for each point
    input_channels: int = 6
        Number of input channels in the feature descriptor for each point. If the
        point cloud is Nx9, this value should be 6, as in an Nx9 point cloud 3 of
        the channels are xyz and 6 are feature descriptors
    use_xyz: bool = True
        Whether or not to use the xyz position of a point as a feature
    """

    def __init__(self, option, model_type, dataset, modules):
        # call the initialization method of UnetBasedModel
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)
        # Last MLP
        self.mode = option.loss_mode
        self.normalize_feature = option.normalize_feature
        self.out_channels = option.out_channels
        self.loss_names = ["loss_reg", "loss"]
        self.metric_loss_module, self.miner_module = UnetBasedModel.get_metric_loss_and_miner(
            getattr(option, "metric_loss", None), getattr(option, "miner", None)
        )

        last_mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self.out_channels, activation=None, bias=True, bn=False))

    def set_input(self, data, device):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.

        Sets:
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]
        """
        assert len(data.pos.shape) == 3
        if data.x is not None:
            x = data.x.transpose(1, 2).contiguous()
        else:
            x = None
        self.input = Data(x=x, pos=data.pos).to(device)
        if hasattr(data, "pos_target"):
            if data.x_target is not None:
                x = data.x_target.transpose(1, 2).contiguous()
            else:
                x = None
            self.input_target = Data(x=x, pos=data.pos_target).to(device)
            self.match = data.pair_ind.to(torch.long).to(device)
            self.size_match = data.size_pair_ind.to(torch.long).to(device)
        else:
            self.match = None

    def apply_nn(self, input):
        last_feature = self.model(input).x
        output = self.FC_layer(last_feature).transpose(1, 2).contiguous().view((-1, self.out_channels))
        if self.normalize_feature:
            return output / (torch.norm(output, p=2, dim=1, keepdim=True) + 1e-5)
        else:
            return output

    def get_input(self):
        if self.match is not None:
            input = Data(pos=self.input.pos.view(-1, 3), ind=self.match[:, 0], size=self.size_match)
            input_target = Data(pos=self.input_target.pos.view(-1, 3), ind=self.match[:, 1], size=self.size_match)
            return input, input_target
        else:
            input = Data(pos=self.input.pos.view(-1, 3))
            return input

    def get_batch(self):
        if self.match is not None:
            batch = (
                torch.arange(0, self.input.pos.shape[0])
                .view(-1, 1)
                .repeat(1, self.input.pos.shape[1])
                .view(-1)
                .to(self.input.pos.device)
            )
            batch_target = (
                torch.arange(0, self.input_target.pos.shape[0])
                .view(-1, 1)
                .repeat(1, self.input_target.pos.shape[1])
                .view(-1)
                .to(self.input.pos.device)
            )
            return batch, batch_target
        else:
            return None, None
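# A minimal sketch of the descriptor normalization in `apply_nn` above: each
# row (one descriptor per point) is divided by its L2 norm, with a small
# epsilon guarding against division by zero, so descriptors land on the unit
# sphere before matching. Values are illustrative.
import torch

output = torch.randn(10, 32)  # [num_points, out_channels]
normalized = output / (torch.norm(output, p=2, dim=1, keepdim=True) + 1e-5)
print(normalized.norm(dim=1))  # -> all values ~1.0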
class PointNet2_D(UnetBasedModel):
    r"""
    PointNet2 with multi-scale grouping semantic segmentation network that
    uses feature propagation layers

    Parameters
    ----------
    num_classes: int
        Number of semantic classes to predict over -- size of the softmax classifier
        that runs for each point
    input_channels: int = 6
        Number of input channels in the feature descriptor for each point. If the
        point cloud is Nx9, this value should be 6, as in an Nx9 point cloud 3 of
        the channels are xyz and 6 are feature descriptors
    use_xyz: bool = True
        Whether or not to use the xyz position of a point as a feature
    """

    def __init__(self, option, model_type, dataset, modules):
        # call the initialization method of UnetBasedModel
        UnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)
        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError(
                    "The dataset needs to specify a class_to_segments property when using category information for segmentation"
                )
            self._num_categories = len(dataset.class_to_segments.keys())
            log.info("Using category information for the predictions with %i categories", self._num_categories)
        else:
            self._num_categories = 0

        # Last MLP
        last_mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        last_mlp_opt.nn[0] += self._num_categories
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self._num_classes, activation=None, bias=True, bn=False))
        self.loss_names = ["loss_seg"]
        self.visual_names = ["data_visual"]

    def set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.

        Sets:
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]
        """
        assert len(data.pos.shape) == 3
        device = self.device
        data = data.to(device)
        if data.x is not None:
            x = data.x.transpose(1, 2).contiguous()
        else:
            x = None
        self.input = Data(x=x, pos=data.pos)
        if data.y is not None:
            self.labels = torch.flatten(data.y).long()  # [B * N]
        else:
            self.labels = None
        self.batch_idx = torch.arange(0, data.pos.shape[0]).view(-1, 1).repeat(1, data.pos.shape[1]).view(-1)
        if self._use_category:
            self.category = data.category

    def forward(self, *args, **kwargs):
        r"""Forward pass of the network

        self.input:
            x -- Features [B, C, N]
            pos -- Points [B, N, 3]
        """
        self.set_input(kwargs["data"])
        data = self.model(self.input)
        last_feature = data.x
        if self._use_category:
            cat_one_hot = F.one_hot(self.category, self._num_categories).float().transpose(1, 2)
            last_feature = torch.cat((last_feature, cat_one_hot), dim=1)
        self.output = self.FC_layer(last_feature).transpose(1, 2).contiguous().view((-1, self._num_classes))

        if self._weight_classes is not None:
            self._weight_classes = self._weight_classes.to(self.output.device)
        if self.labels is not None:
            self.loss_seg = F.cross_entropy(
                self.output, self.labels, weight=self._weight_classes, ignore_index=IGNORE_LABEL
            )

        self.data_visual = self.input
        self.data_visual.y = torch.reshape(self.labels, data.pos.shape[0:2])
        self.data_visual.pred = torch.max(self.output, -1)[1].reshape(data.pos.shape[0:2])
        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # calculate the intermediate results if necessary; here self.output has been computed during <forward>
        # calculate loss given the input and intermediate results
        self.loss_seg.backward()
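# A minimal sketch of the batch_idx construction in `set_input` above: for a
# dense batch of B samples with N points each, it produces a flat vector of
# length B * N telling which sample every point belongs to. Values are
# illustrative.
import torch

B, N = 3, 4
batch_idx = torch.arange(0, B).view(-1, 1).repeat(1, N).view(-1)
print(batch_idx)  # -> tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])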
class RSConvLogicModel(UnwrappedUnetBasedModel):
    def __init__(self, option, model_type, dataset, modules):
        # call the initialization method of UnwrappedUnetBasedModel
        UnwrappedUnetBasedModel.__init__(self, option, model_type, dataset, modules)
        self._num_classes = dataset.num_classes
        self._weight_classes = dataset.weight_classes
        self._use_category = getattr(option, "use_category", False)
        if self._use_category:
            if not dataset.class_to_segments:
                raise ValueError(
                    "The dataset needs to specify a class_to_segments property when using category information for segmentation"
                )
            self._num_categories = len(dataset.class_to_segments.keys())
            log.info("Using category information for the predictions with %i categories", self._num_categories)
        else:
            self._num_categories = 0

        # Last MLP
        last_mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        last_mlp_opt.nn[0] += self._num_categories
        for i in range(1, len(last_mlp_opt.nn)):
            self.FC_layer.append(Conv1D(last_mlp_opt.nn[i - 1], last_mlp_opt.nn[i], bn=True, bias=False))
        if last_mlp_opt.dropout:
            self.FC_layer.append(torch.nn.Dropout(p=last_mlp_opt.dropout))
        self.FC_layer.append(Conv1D(last_mlp_opt.nn[-1], self._num_classes, activation=None, bias=True, bn=False))
        self.loss_names = ["loss_seg"]
        self.visual_names = ["data_visual"]

    def set_input(self, data, device):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.

        Sets:
            self.input:
                x -- Features [B, C, N]
                pos -- Points [B, N, 3]
        """
        data = data.to(device)
        if data.x is not None:
            data.x = data.x.transpose(1, 2).contiguous()
        self.input = data
        if data.y is not None:
            self.labels = torch.flatten(data.y).long()  # [B * N]
        else:
            self.labels = data.y
        self.batch_idx = torch.arange(0, data.pos.shape[0]).view(-1, 1).repeat(1, data.pos.shape[1]).view(-1)
        if self._use_category:
            self.category = data.category

    def forward(self, *args, **kwargs):
        r"""Forward pass of the network

        self.input:
            x -- Features [B, C, N]
            pos -- Points [B, N, 3]
        """
        stack_down = []
        queue_up = queue.Queue()

        data = self.input
        stack_down.append(data)
        for i in range(len(self.down_modules) - 1):
            data = self.down_modules[i](data)
            stack_down.append(data)
        data = self.down_modules[-1](data)
        queue_up.put(data)

        assert len(self.inner_modules) == 2, "For this segmentation model, we expect 2 distinct inner modules"
        data_inner = self.inner_modules[0](data)
        data_inner_2 = self.inner_modules[1](stack_down[3])

        for i in range(len(self.up_modules) - 1):
            data = self.up_modules[i]((queue_up.get(), stack_down.pop()))
            queue_up.put(data)

        last_feature = torch.cat(
            [data.x, data_inner.x.repeat(1, 1, data.x.shape[-1]), data_inner_2.x.repeat(1, 1, data.x.shape[-1])],
            dim=1,
        )
        if self._use_category:
            cat_one_hot = F.one_hot(self.category, self._num_categories).float().transpose(1, 2)
            last_feature = torch.cat((last_feature, cat_one_hot), dim=1)
        self.output = self.FC_layer(last_feature).transpose(1, 2).contiguous().view((-1, self._num_classes))

        # Compute loss
        if self._weight_classes is not None:
            self._weight_classes = self._weight_classes.to(self.output.device)
        if self.labels is not None:
            self.loss_seg = F.cross_entropy(self.output, self.labels, weight=self._weight_classes)

        self.data_visual = self.input
        self.data_visual.y = torch.reshape(self.labels, data.pos.shape[0:2])
        self.data_visual.pred = torch.max(self.output, -1)[1].reshape(data.pos.shape[0:2])
        return self.output

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # calculate the intermediate results if necessary; here self.output has been computed during <forward>
        # calculate loss given the input and intermediate results
        self.loss_seg.backward()
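# A minimal sketch of the encoder/decoder bookkeeping in `forward` above,
# with string-building lambdas standing in for the real down/up blocks:
# encoder outputs are pushed onto a stack, and each up module consumes the
# latest decoded tensor (taken from the queue) together with the matching
# skip tensor popped off the stack, so the deepest encoder output meets the
# first up module and shallower ones meet later up modules. The module
# counts and the `- 1` loop bounds mirror the structure of the code above.
import queue

down_modules = [lambda d: d + "|down" for _ in range(3)]
up_modules = [lambda pair: f"up({pair[0]}, skip={pair[1]})" for _ in range(3)]

stack_down, queue_up = [], queue.Queue()
data = "in"
stack_down.append(data)
for i in range(len(down_modules) - 1):
    data = down_modules[i](data)
    stack_down.append(data)
queue_up.put(down_modules[-1](data))

for i in range(len(up_modules) - 1):
    data = up_modules[i]((queue_up.get(), stack_down.pop()))
    queue_up.put(data)
print(data)  # shows the deepest output paired with progressively shallower skips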