Пример #1
0
def dummy_train(device, mixed=False):
    """Run ten optimisation steps of a small sparse conv net on random data.

    Args:
        device: Device handed to ``.to(...)`` for the model, the loss,
            the inputs and the targets.
        mixed: Forwarded as ``enabled=`` to both ``torch.cuda.amp.autocast``
            and ``torch.cuda.amp.GradScaler``.
    """
    layers = [
        spnn.Conv3d(4, 32, kernel_size=3, stride=1),
        spnn.BatchNorm(32),
        spnn.ReLU(True),
        spnn.Conv3d(32, 64, kernel_size=2, stride=2),
        spnn.BatchNorm(64),
        spnn.ReLU(True),
        spnn.Conv3d(64, 64, kernel_size=2, stride=2, transpose=True),
        spnn.BatchNorm(64),
        spnn.ReLU(True),
        spnn.Conv3d(64, 32, kernel_size=3, stride=1),
        spnn.BatchNorm(32),
        spnn.ReLU(True),
        spnn.Conv3d(32, 10, kernel_size=1),
    ]
    model = nn.Sequential(*layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)
    scaler = torch.cuda.amp.GradScaler(enabled=mixed)

    print('Starting dummy training...')
    for step in range(10):
        optimizer.zero_grad()

        # A fresh random batch is generated for every step.
        batch = generate_batched_random_point_clouds()
        inputs = batch['lidar'].to(device)
        labels = batch['targets'].F.to(device).long()

        with torch.cuda.amp.autocast(enabled=mixed):
            predictions = model(inputs)
            loss = criterion(predictions.F, labels)

        # Backward pass and optimizer step both go through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        print('[step %d] loss = %f.' % (step, loss.item()))
    print('Finished dummy training!')
Пример #2
0
    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        """Create the main branch, the shortcut branch and the final ReLU.

        Args:
            inc: Channel count fed to the first convolution and the shortcut.
            outc: Channel count produced by both branches.
            ks: Kernel size of the two main-branch convolutions.
            stride: Stride of the first main-branch convolution and of the
                shortcut convolution (the second main-branch convolution
                always uses stride 1).
            dilation: Dilation of the two main-branch convolutions.
        """
        super().__init__()

        first_conv = spnn.Conv3d(inc,
                                 outc,
                                 kernel_size=ks,
                                 dilation=dilation,
                                 stride=stride)
        first_norm = spnn.BatchNorm(outc)
        first_act = spnn.ReLU(True)
        second_conv = spnn.Conv3d(outc,
                                  outc,
                                  kernel_size=ks,
                                  dilation=dilation,
                                  stride=1)
        second_norm = spnn.BatchNorm(outc)
        self.net = nn.Sequential(first_conv, first_norm, first_act,
                                 second_conv, second_norm)

        # An empty Sequential is used when channels and stride are unchanged;
        # otherwise a 1x1 convolution + batch norm projects the input.
        if inc == outc and stride == 1:
            shortcut = nn.Sequential()
        else:
            shortcut = nn.Sequential(
                spnn.Conv3d(inc,
                            outc,
                            kernel_size=1,
                            dilation=1,
                            stride=stride),
                spnn.BatchNorm(outc),
            )
        self.downsample = shortcut

        self.relu = spnn.ReLU(True)
    def __init__(self, **kwargs):
        """Assemble the stem, four strided stages, pooling and classifier.

        Keyword Args:
            cr: Multiplier applied to every base channel width
                (defaults to ``1.0``; results are truncated with ``int``).
            pres, vres: Stored on the instance only when both are given.
            input_channel: Channel count of the input features (required).
            num_classes: Output size of the linear classifier (required).
        """
        super().__init__()

        ratio = kwargs.get('cr', 1.0)
        base_widths = (32, 32, 64, 128, 256, 256, 128, 96, 96)
        cs = [int(ratio * width) for width in base_widths]

        if all(key in kwargs for key in ('pres', 'vres')):
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        def down_stage(cin, cout):
            # Stride-2 convolution block followed by two residual blocks.
            return nn.Sequential(
                BasicConvolutionBlock(cin, cin, ks=2, stride=2, dilation=1),
                ResidualBlock(cin, cout, ks=3, stride=1, dilation=1),
                ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
            )

        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['input_channel'],
                        cs[0],
                        kernel_size=3,
                        stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
        )

        self.stage1 = down_stage(cs[0], cs[1])
        self.stage2 = down_stage(cs[1], cs[2])
        self.stage3 = down_stage(cs[2], cs[3])
        self.stage4 = down_stage(cs[3], cs[4])

        self.avg_pool = spnn.GlobalAveragePooling()
        self.classifier = nn.Sequential(
            nn.Linear(cs[4], kwargs['num_classes']))

        point_branch = nn.Sequential(
            nn.Linear(cs[0], cs[4]),
            nn.BatchNorm1d(cs[4]),
            nn.ReLU(True),
        )
        self.point_transforms = nn.ModuleList([point_branch])

        # Weights are initialised before the dropout layer is attached.
        self.weight_initialization()
        self.dropout = nn.Dropout(0.3, True)
Пример #4
0
    def __init__(self,
                 inc,
                 outc,
                 cr_bounds=None,
                 ks=3,
                 stride=1,
                 dilation=1):
        """Build a residual block out of channel-dynamic convolution blocks.

        Args:
            inc: Channel count passed to the first main-branch block and to
                the projection shortcut.
            outc: Channel count produced by both branches.
            cr_bounds: Channel-ratio bounds forwarded to every
                ``DynamicConvolutionBlock``. ``None`` (the default) creates a
                fresh ``[0.25, 1.0]`` list for this instance, so instances no
                longer share one default list object.
            ks: Kernel size of the two main-branch convolutions.
            stride: Stride of the block. It is applied by the first
                main-branch convolution and by the projection shortcut.
            dilation: Dilation of the two main-branch convolutions.
        """
        # make sure first run random_sample, then constrain_output_channel
        super().__init__()
        if cr_bounds is None:
            # A per-instance list avoids the shared mutable default pitfall.
            cr_bounds = [0.25, 1.0]
        self.inc = inc
        self.outc = outc
        self.cr_bounds = cr_bounds
        self.stride = stride

        self.use_skip_conn = (inc == outc and stride == 1)
        self.net_depth = None

        # can separate the last layer from self.net
        # NOTE(review): the second block now uses stride 1 and the shortcut
        # uses `stride`, so both branches apply the stride exactly once.
        # Previously the main branch strided twice and the shortcut not at
        # all; presumably the branches are summed in forward() (not visible
        # here), which would mismatch for stride != 1. Behaviour for
        # stride == 1 is unchanged.
        self.net = RandomDepth(
            DynamicConvolutionBlock(inc, outc, cr_bounds, ks, stride, dilation,
                                    False),
            DynamicConvolutionBlock(outc, outc, cr_bounds, ks, 1, dilation,
                                    True),
            depth_min=2)

        if self.use_skip_conn:
            self.downsample = nn.Sequential()
        else:
            self.downsample = DynamicConvolutionBlock(inc,
                                                      outc,
                                                      cr_bounds,
                                                      ks=1,
                                                      stride=stride,
                                                      dilation=1,
                                                      no_relu=True)

        self.relu = spnn.ReLU(True)
        self.runtime_inc = None
Пример #5
0
 def __init__(self,
              inc,
              outc,
              cr_bounds=None,
              ks=3,
              stride=1,
              dilation=1,
              no_relu=False):
     """Build a dynamic conv -> dynamic batch norm -> activation block.

     Args:
         inc: Channel count passed to ``SparseDynamicConv3d`` as input.
         outc: Channel count passed to ``SparseDynamicConv3d`` as output
             and to ``SparseDynamicBatchNorm``.
         cr_bounds: Channel-ratio bounds stored on the instance. ``None``
             (the default) creates a fresh ``[0.25, 1.0]`` list, so
             instances no longer share one default list object.
         ks: Kernel size of the convolution.
         stride: Stride of the convolution.
         dilation: Dilation of the convolution.
         no_relu: When true the ``act`` slot is an empty ``nn.Sequential``
             instead of ``spnn.ReLU``.
     """
     super().__init__()
     if cr_bounds is None:
         # A per-instance list avoids the shared mutable default pitfall.
         cr_bounds = [0.25, 1.0]
     self.inc = inc
     self.outc = outc
     self.ks = ks
     self.s = stride
     self.cr_bounds = cr_bounds
     self.no_relu = no_relu
     self.net = nn.Sequential(
         OrderedDict([
             ('conv',
              SparseDynamicConv3d(inc,
                                  outc,
                                  kernel_size=ks,
                                  dilation=dilation,
                                  stride=stride)),
             ('bn', SparseDynamicBatchNorm(outc)),
             ('act',
              spnn.ReLU(True) if not self.no_relu else nn.Sequential()),
         ]))
     # Runtime channel state starts unset; presumably filled in by methods
     # not visible in this excerpt.
     self.runtime_inc = None
     self.runtime_outc = None
     self.in_channel_constraint = None
Пример #6
0
 def __init__(self,
              inc,
              outc,
              ks=3,
              stride=1,
              dilation=1,
              no_relu=False,
              transpose=False):
     """Build a conv -> batch norm -> activation block and init weights.

     Args:
         inc: Channel count passed to ``spnn.Conv3d`` as input.
         outc: Channel count passed to ``spnn.Conv3d`` as output and to
             ``spnn.BatchNorm``.
         ks: Kernel size of the convolution.
         stride: Stride of the convolution.
         dilation: Dilation of the convolution.
         no_relu: When true the ``act`` slot is an empty ``nn.Sequential``
             instead of ``spnn.ReLU``.
         transpose: Forwarded to ``spnn.Conv3d`` as ``transpose``.
     """
     super().__init__()
     self.inc = inc
     self.outc = outc
     self.ks = ks
     self.no_relu = no_relu

     conv = spnn.Conv3d(inc,
                        outc,
                        kernel_size=ks,
                        dilation=dilation,
                        stride=stride,
                        transpose=transpose)
     norm = spnn.BatchNorm(outc)
     if self.no_relu:
         activation = nn.Sequential()
     else:
         activation = spnn.ReLU(True)

     # Sub-module names ('conv', 'bn', 'act') are fixed by the OrderedDict.
     named_layers = OrderedDict([
         ('conv', conv),
         ('bn', norm),
         ('act', activation),
     ])
     self.net = nn.Sequential(named_layers)
     self.init_weights()
Пример #7
0
 def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
     """Build a sparse conv -> batch norm -> ReLU stack as ``self.net``.

     Args:
         inc: Channel count passed to ``spnn.Conv3d`` as input.
         outc: Channel count passed to ``spnn.Conv3d`` as output and to
             ``spnn.BatchNorm``.
         ks: Kernel size of the convolution.
         stride: Stride of the convolution.
         dilation: Dilation of the convolution.
     """
     super().__init__()
     conv = spnn.Conv3d(inc,
                        outc,
                        kernel_size=ks,
                        dilation=dilation,
                        stride=stride)
     norm = spnn.BatchNorm(outc)
     activation = spnn.ReLU(True)
     self.net = nn.Sequential(conv, norm, activation)
Пример #8
0
 def __init__(self, inc, outc, cr_bounds=None, ks=3, stride=1):
     """Build a transposed dynamic conv -> dynamic batch norm -> ReLU block.

     Args:
         inc: Channel count passed to ``SparseDynamicConv3d`` as input.
         outc: Channel count passed to ``SparseDynamicConv3d`` as output
             and to ``SparseDynamicBatchNorm``.
         cr_bounds: Channel-ratio bounds stored on the instance. ``None``
             (the default) creates a fresh ``[0.25, 1.0]`` list, so
             instances no longer share one default list object.
         ks: Kernel size of the transposed convolution.
         stride: Stride of the transposed convolution.
     """
     super().__init__()
     if cr_bounds is None:
         # A per-instance list avoids the shared mutable default pitfall.
         cr_bounds = [0.25, 1.0]
     self.inc = inc
     self.outc = outc
     self.ks = ks
     self.s = stride
     self.cr_bounds = cr_bounds
     self.net = nn.Sequential(
         OrderedDict([
             ('conv',
              SparseDynamicConv3d(inc,
                                  outc,
                                  kernel_size=ks,
                                  stride=stride,
                                  transpose=True)),
             ('bn', SparseDynamicBatchNorm(outc)),
             ('act', spnn.ReLU(True)),
         ]))
     # Runtime channel state starts unset; presumably filled in by methods
     # not visible in this excerpt.
     self.runtime_inc = None
     self.runtime_outc = None
     self.in_channel_constraint = None
Пример #9
0
 def __init__(self, net, downsample):
     """Wrap pre-built main and shortcut branches and add a ReLU.

     Args:
         net: Module stored as the main branch (``self.net``).
         downsample: Module stored as the shortcut branch
             (``self.downsample``).
     """
     # NOTE(review): the class header is not visible in this excerpt. If the
     # class derives from torch.nn.Module, the base initializer must run
     # before any sub-module is assigned, otherwise the assignments below
     # raise AttributeError. For a plain object base this call is a no-op.
     super().__init__()
     self.net = net
     self.downsample = downsample
     self.relu = spnn.ReLU(True)
Пример #10
0
    def __init__(self, option, model_type, dataset, modules):
        """Assemble the stem, four encoder stages, four decoder stages,
        the classifier and the point-wise transforms.

        Args:
            option: Configuration object; ``option.cr`` scales the channel
                widths and ``option.vres`` is stored as ``self.vres``.
            model_type: Not used in this constructor.
            dataset: Supplies ``num_classes`` and ``feature_dimension``.
            modules: Not used in this constructor.
        """
        super(PVCNN, self).__init__()

        ratio = option.cr
        self.vres = option.vres
        self.num_classes = dataset.num_classes
        self.num_features = dataset.feature_dimension

        base_widths = (32, 32, 64, 128, 256, 256, 128, 96, 96)
        cs = [int(ratio * width) for width in base_widths]

        def encoder_stage(cin, cout):
            # Stride-2 convolution block followed by two residual blocks.
            return nn.Sequential(
                BasicConvolutionBlock(cin, cin, ks=2, stride=2, dilation=1),
                ResidualBlock(cin, cout, ks=3, stride=1, dilation=1),
                ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
            )

        def decoder_stage(cin, cout, skip):
            # [0] is the stride-2 deconvolution; [1] takes `cout + skip`
            # input channels and runs two residual blocks.
            return nn.ModuleList([
                BasicDeconvolutionBlock(cin, cout, ks=2, stride=2),
                nn.Sequential(
                    ResidualBlock(cout + skip, cout, ks=3, stride=1,
                                  dilation=1),
                    ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
                ),
            ])

        def point_mlp(cin, cout):
            # Linear -> BatchNorm1d -> ReLU applied to per-point features.
            return nn.Sequential(
                nn.Linear(cin, cout),
                nn.BatchNorm1d(cout),
                nn.ReLU(True),
            )

        self.stem = nn.Sequential(
            spnn.Conv3d(self.num_features, cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
        )

        self.stage1 = encoder_stage(cs[0], cs[1])
        self.stage2 = encoder_stage(cs[1], cs[2])
        self.stage3 = encoder_stage(cs[2], cs[3])
        self.stage4 = encoder_stage(cs[3], cs[4])

        self.up1 = decoder_stage(cs[4], cs[5], cs[3])
        self.up2 = decoder_stage(cs[5], cs[6], cs[2])
        self.up3 = decoder_stage(cs[6], cs[7], cs[1])
        self.up4 = decoder_stage(cs[7], cs[8], cs[0])

        self.classifier = nn.Sequential(nn.Linear(cs[8], self.num_classes))

        self.point_transforms = nn.ModuleList([
            point_mlp(cs[0], cs[4]),
            point_mlp(cs[4], cs[6]),
            point_mlp(cs[6], cs[8]),
        ])

        # Weights are initialised before the dropout layer is attached.
        self.weight_initialization()
        self.dropout = nn.Dropout(0.3, True)

        self.loss_names = ["loss_seg"]
Пример #11
0
    def __init__(self, cfg: OmegaConf):
        """Build the network and its task-specific heads from ``cfg``.

        ``cfg`` is merged into ``self.hparams``; the optimizer/scheduler
        factories and the loss criteria are instantiated through hydra.
        ``hparams.task`` must be ``'instance'``, ``'semantic'`` or
        ``'panoptic'`` and selects which heads are created: the semantic
        head (``c_*``) for semantic/panoptic, the embedding head (``e_*``)
        for instance/panoptic.

        Args:
            cfg: Configuration providing ``optimizer``, ``scheduler``,
                ``task``, ``model`` (``cr``, ``cs``, ``embed_dim``),
                ``dataset`` (``num_features``, ``num_classes``) and the
                criterion entries used by the selected task.
        """
        super().__init__()
        self.hparams.update(cfg)
        if is_rank_zero():
            self.save_hyperparameters(cfg)

        # Disabled overrides kept for reference:
        #self.hparams.optimizer._target_ = 'calo_cluster.training.optimizers.adam_factory'
        #self.hparams.scheduler._target_ = 'calo_cluster.training.schedulers.one_cycle_lr_factory'
        self.optimizer_factory = hydra.utils.instantiate(
            self.hparams.optimizer)
        self.scheduler_factory = hydra.utils.instantiate(
            self.hparams.scheduler)

        task = self.hparams.task
        assert task in ('instance', 'semantic', 'panoptic')
        with_embedding = task in ('instance', 'panoptic')
        with_semantic = task in ('semantic', 'panoptic')

        if with_embedding:
            self.embed_criterion = hydra.utils.instantiate(
                self.hparams.embed_criterion)
        if with_semantic:
            self.semantic_criterion = hydra.utils.instantiate(
                self.hparams.semantic_criterion)

        cs = [
            int(self.hparams.model.cr * width)
            for width in self.hparams.model.cs
        ]

        def encoder_stage(cin, cout):
            # Stride-2 convolution block followed by two residual blocks.
            return nn.Sequential(
                BasicConvolutionBlock(cin, cin, ks=2, stride=2, dilation=1),
                ResidualBlock(cin, cout, ks=3, stride=1, dilation=1),
                ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
            )

        def decoder_stage(cin, cout, skip):
            # [0] is the stride-2 deconvolution; [1] takes `cout + skip`
            # input channels and runs two residual blocks.
            return nn.ModuleList([
                BasicDeconvolutionBlock(cin, cout, ks=2, stride=2),
                nn.Sequential(
                    ResidualBlock(cout + skip, cout, ks=3, stride=1,
                                  dilation=1),
                    ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
                ),
            ])

        def point_mlp(cin, cout):
            # Linear -> BatchNorm1d -> ReLU applied to per-point features.
            return nn.Sequential(
                nn.Linear(cin, cout),
                nn.BatchNorm1d(cout),
                nn.ReLU(True),
            )

        self.stem = nn.Sequential(
            spnn.Conv3d(self.hparams.dataset.num_features,
                        cs[0],
                        kernel_size=3,
                        stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
        )

        self.stage1 = encoder_stage(cs[0], cs[1])
        self.stage2 = encoder_stage(cs[1], cs[2])
        self.stage3 = encoder_stage(cs[2], cs[3])
        self.stage4 = encoder_stage(cs[3], cs[4])

        self.up1 = decoder_stage(cs[4], cs[5], cs[3])
        self.up2 = decoder_stage(cs[5], cs[6], cs[2])
        self.up3 = decoder_stage(cs[6], cs[7], cs[1])

        # Each enabled head owns its own last decoder stage, point
        # transform and linear projection.
        if with_semantic:
            self.c_up4 = decoder_stage(cs[7], cs[8], cs[0])
            self.c_point_transform = point_mlp(cs[6], cs[8])
            self.c_lin = nn.Sequential(
                nn.Linear(cs[8], self.hparams.dataset.num_classes))
        if with_embedding:
            self.e_up4 = decoder_stage(cs[7], cs[8], cs[0])
            self.e_point_transform = point_mlp(cs[6], cs[8])
            self.e_lin = nn.Sequential(
                nn.Linear(cs[8], self.hparams.model.embed_dim))

        self.point_transforms = nn.ModuleList([
            point_mlp(cs[0], cs[4]),
            point_mlp(cs[4], cs[6]),
        ])

        # Weights are initialised before the dropout layer is attached.
        self.weight_initialization()
        self.dropout = nn.Dropout(0.3, True)
Пример #12
0
    def __init__(self, num_classes, macro_depth_constraint, **kwargs):
        """Assemble the stem, dynamic down/up stages, point transforms and
        classifier.

        ``self.base_channels``, ``self.output_channels``,
        ``self.output_channels_lb`` and ``self.num_down_stages`` are read
        from the instance; presumably they are class-level attributes
        defined outside this excerpt.

        Args:
            num_classes: Output size of the final ``DynamicLinear``.
            macro_depth_constraint: ``depth_min`` passed to every stage-level
                ``RandomDepth``.

        Keyword Args:
            pres, vres: Stored on the instance; each defaults to ``0.05``.
            cr_bounds, up_cr_bounds, trans_cr_bounds: Channel-ratio bounds
                for the down-stage residual blocks, the up stages and the
                down-stage transitions; each defaults to ``[0.125, 1.0]``.
            output_channels_ub: Defaults to ``self.output_channels``.
            output_channels_lb: Overrides ``self.output_channels_lb`` when
                given.
        """
        super().__init__()
        self.pres = kwargs.get('pres', 0.05)
        self.vres = kwargs.get('vres', 0.05)
        self.cr_bounds = kwargs.get('cr_bounds', [0.125, 1.0])
        self.up_cr_bounds = kwargs.get('up_cr_bounds', [0.125, 1.0])
        self.trans_cr_bounds = kwargs.get('trans_cr_bounds', [0.125, 1.0])

        if 'output_channels_ub' in kwargs:
            self.output_channels_ub = kwargs['output_channels_ub']
        else:
            self.output_channels_ub = self.output_channels

        if 'output_channels_lb' in kwargs:
            self.output_channels_lb = kwargs['output_channels_lb']

        width = self.base_channels
        widths = self.output_channels
        # Not used below; the lookup is kept so that a missing attribute
        # still fails at this point.
        output_channels_lb = self.output_channels_lb

        self.stem = nn.Sequential(
            spnn.Conv3d(4, width, kernel_size=3, stride=1),
            spnn.BatchNorm(width),
            spnn.ReLU(True),
            spnn.Conv3d(width, width, kernel_size=3, stride=1),
            spnn.BatchNorm(width),
            spnn.ReLU(True),
        )

        n_down = self.num_down_stages

        # `width` tracks the channel count flowing into the next stage.
        encoder = []
        for idx in range(1, n_down + 1):
            transition = DynamicConvolutionBlock(
                width,
                width,
                cr_bounds=self.trans_cr_bounds,
                ks=2,
                stride=2,
                dilation=1)
            feature = RandomDepth(
                DynamicResidualBlock(width,
                                     widths[idx],
                                     cr_bounds=self.cr_bounds,
                                     ks=3,
                                     stride=1,
                                     dilation=1),
                DynamicResidualBlock(widths[idx],
                                     widths[idx],
                                     cr_bounds=self.cr_bounds,
                                     ks=3,
                                     stride=1,
                                     dilation=1),
                depth_min=macro_depth_constraint)
            encoder.append(
                nn.Sequential(
                    OrderedDict([
                        ('transition', transition),
                        ('feature', feature),
                    ])))
            width = widths[idx]

        self.downsample = nn.ModuleList(encoder)

        # take care of weight sharing after concat!
        decoder = []
        for idx in range(1, n_down + 1):
            up_width = widths[n_down + idx]
            transition = DynamicDeconvolutionBlock(
                width,
                up_width,
                cr_bounds=self.up_cr_bounds,
                ks=2,
                stride=2)
            # The first residual block takes the skip width plus `up_width`.
            skip_width = widths[n_down - idx]
            feature = RandomDepth(
                DynamicResidualBlock(skip_width + up_width,
                                     up_width,
                                     cr_bounds=self.up_cr_bounds,
                                     ks=3,
                                     stride=1,
                                     dilation=1),
                DynamicResidualBlock(up_width,
                                     up_width,
                                     cr_bounds=self.up_cr_bounds,
                                     ks=3,
                                     stride=1,
                                     dilation=1),
                depth_min=macro_depth_constraint)
            decoder.append(
                nn.Sequential(
                    OrderedDict([
                        ('transition', transition),
                        ('feature', feature),
                    ])))
            width = up_width

        self.upsample = nn.ModuleList(decoder)

        def point_block(cin, cout):
            # Dynamic linear block with bias, batch norm and ReLU enabled.
            return DynamicLinearBlock(cin,
                                      cout,
                                      bias=True,
                                      no_relu=False,
                                      no_bn=False)

        self.point_transforms = nn.ModuleList([
            point_block(widths[0], widths[n_down]),
            point_block(widths[n_down], widths[n_down + 2]),
            point_block(widths[n_down + 2], widths[-1]),
        ])

        self.classifier = DynamicLinear(widths[-1], num_classes)
        self.classifier.set_output_channel(num_classes)

        # Here the dropout layer is attached before weight initialisation.
        self.dropout = nn.Dropout(0.3, True)
        self.weight_initialization()
Пример #13
0
    def __init__(self, **kwargs):
        """Assemble a two-level encoder/decoder with optional dropout.

        Keyword Args:
            dropout: Required flag. Stored as ``self.dropout`` and, when
                truthy, replaced at the end by ``nn.Dropout(0.3, True)``.
            cr: Multiplier applied to every base channel width
                (defaults to ``1.0``; results are truncated with ``int``).
            pres, vres: Stored on the instance only when both are given.
            in_channels: Channel count of the input features (required).
        """
        super().__init__()

        self.dropout = kwargs['dropout']

        ratio = kwargs.get('cr', 1.0)
        cs = [int(ratio * width) for width in (32, 64, 128, 96, 96)]

        if all(key in kwargs for key in ('pres', 'vres')):
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        def encoder_stage(cin, cout):
            # Stride-2 convolution block followed by two residual blocks.
            return nn.Sequential(
                BasicConvolutionBlock(cin, cin, ks=2, stride=2, dilation=1),
                ResidualBlock(cin, cout, ks=3, stride=1, dilation=1),
                ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
            )

        def decoder_stage(cin, cout, skip):
            # [0] is the stride-2 deconvolution; [1] takes `cout + skip`
            # input channels and runs two residual blocks.
            return nn.ModuleList([
                BasicDeconvolutionBlock(cin, cout, ks=2, stride=2),
                nn.Sequential(
                    ResidualBlock(cout + skip, cout, ks=3, stride=1,
                                  dilation=1),
                    ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
                ),
            ])

        def point_mlp(cin, cout):
            # Linear -> BatchNorm1d -> ReLU applied to per-point features.
            return nn.Sequential(
                nn.Linear(cin, cout),
                nn.BatchNorm1d(cout),
                nn.ReLU(True),
            )

        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['in_channels'],
                        cs[0],
                        kernel_size=3,
                        stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
        )

        self.stage1 = encoder_stage(cs[0], cs[1])
        self.stage2 = encoder_stage(cs[1], cs[2])

        self.up1 = decoder_stage(cs[2], cs[3], cs[1])
        self.up2 = decoder_stage(cs[3], cs[4], cs[0])

        self.point_transforms = nn.ModuleList([
            point_mlp(cs[0], cs[2]),
            point_mlp(cs[2], cs[4]),
        ])

        self.weight_initialization()

        # The flag stored above is swapped for the actual layer only when
        # dropout was requested; otherwise the falsy flag stays in place.
        if self.dropout:
            self.dropout = nn.Dropout(0.3, True)
Пример #14
0
    def __init__(self, **kwargs):
        """Assemble the stem, four encoder stages, four decoder stages,
        the classifier and the point-wise transforms.

        Keyword Args:
            cr: Multiplier applied to every base channel width
                (defaults to ``1.0``; results are truncated with ``int``).
            run_up: Stored as ``self.run_up`` (defaults to ``True``).
            num_classes: Output size of the linear classifier (required).
        """
        super().__init__()

        ratio = kwargs.get('cr', 1.0)
        base_widths = (32, 32, 64, 128, 256, 256, 128, 96, 96)
        cs = [int(ratio * width) for width in base_widths]
        self.run_up = kwargs.get('run_up', True)

        def encoder_stage(cin, cout):
            # Stride-2 convolution block followed by two residual blocks.
            return nn.Sequential(
                BasicConvolutionBlock(cin, cin, ks=2, stride=2, dilation=1),
                ResidualBlock(cin, cout, ks=3, stride=1, dilation=1),
                ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
            )

        def decoder_stage(cin, cout, skip):
            # [0] is the stride-2 deconvolution; [1] takes `cout + skip`
            # input channels and runs two residual blocks.
            return nn.ModuleList([
                BasicDeconvolutionBlock(cin, cout, ks=2, stride=2),
                nn.Sequential(
                    ResidualBlock(cout + skip, cout, ks=3, stride=1,
                                  dilation=1),
                    ResidualBlock(cout, cout, ks=3, stride=1, dilation=1),
                ),
            ])

        def point_mlp(cin, cout):
            # Linear -> BatchNorm1d -> ReLU applied to per-point features.
            return nn.Sequential(
                nn.Linear(cin, cout),
                nn.BatchNorm1d(cout),
                nn.ReLU(True),
            )

        # The stem is built for 4 input channels.
        self.stem = nn.Sequential(
            spnn.Conv3d(4, cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
        )

        self.stage1 = encoder_stage(cs[0], cs[1])
        self.stage2 = encoder_stage(cs[1], cs[2])
        self.stage3 = encoder_stage(cs[2], cs[3])
        self.stage4 = encoder_stage(cs[3], cs[4])

        self.up1 = decoder_stage(cs[4], cs[5], cs[3])
        self.up2 = decoder_stage(cs[5], cs[6], cs[2])
        self.up3 = decoder_stage(cs[6], cs[7], cs[1])
        self.up4 = decoder_stage(cs[7], cs[8], cs[0])

        self.classifier = nn.Sequential(
            nn.Linear(cs[8], kwargs['num_classes']))

        self.point_transforms = nn.ModuleList([
            point_mlp(cs[0], cs[4]),
            point_mlp(cs[4], cs[6]),
            point_mlp(cs[6], cs[8]),
        ])

        # Weights are initialised before the dropout layer is attached.
        self.weight_initialization()
        self.dropout = nn.Dropout(0.3, True)