Example #1
    def __init__(self, in_planes, out_planes, stride):
        super(projfeat3d, self).__init__()
        self.flagTS = GLOBAL.torch_batch_normal_track_stat()

        self.stride = stride
        self.conv1 = nn.Conv2d(in_planes, out_planes, (1,1), padding=(0,0), stride=stride[:2],bias=False)
        self.bn = nn.BatchNorm2d(out_planes, track_running_stats=self.flagTS)
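Every snippet on this page reads a process-wide flag and forwards it to track_running_stats. As a reminder of what that switch does in PyTorch itself (this part is standard library behavior, not project code):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8, track_running_stats=False)
print(bn.running_mean)         # None: no running statistics are kept

bn.eval()
x = torch.randn(4, 8, 16, 16)
y = bn(x)                      # eval mode still normalizes with batch statistics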
Example #2
    def __init__(self, inCh, trs=None):
        """
        trs stands for track_running_stats.
        """
        super(FeatureNormalization, self).__init__()

        assert inCh > 0

        if trs is None:
            self.model = nn.BatchNorm2d(
                inCh,
                track_running_stats=GLOBAL.torch_batch_normal_track_stat())
        else:
            self.model = nn.BatchNorm2d(inCh, track_running_stats=trs)
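A short usage sketch for the wrapper above (only the class itself is shown here, so the import path is an assumption):

norm_default = FeatureNormalization(64)             # follows the global flag
norm_fixed   = FeatureNormalization(64, trs=False)  # explicit per-layer override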
Example #3
    def __init__(self, initialChannels=32, freeze=False):
        super(UNet, self).__init__(freeze=freeze)

        self.flagTS = GLOBAL.torch_batch_normal_track_stat()

        self.inplanes = initialChannels

        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
                                                 padding=1, stride=2, bias=False)
        self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
                                                 padding=1, stride=1, bias=False)
        self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
                                                 padding=1, stride=1, bias=False)
        # Vanilla Residual Blocks
        self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
        self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)

        self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
        self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
        self.pyramid_pooling = pyramidPooling(128, None,  fusion_mode='sum', model_name='icnet')
        # Iconvs
        self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False))
        self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                                 padding=1, stride=1, bias=False)
        self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False))
        self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                                 padding=1, stride=1, bias=False)
        self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False))
        self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False)

        self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=32, padding=0,stride=1,bias=False)
        self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=16, padding=0,stride=1,bias=False)
        self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=16, padding=0,stride=1,bias=False)
        self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=16, padding=0,stride=1,bias=False)

        # Must be called at the end of __init__().
        self.update_freeze()
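The base class and update_freeze() are not shown in this example. A minimal sketch of the pattern the closing comment refers to, assuming the base class simply records the flag and update_freeze() propagates it once all submodules exist (names and logic below are a reconstruction, not the project's actual code):

import torch.nn as nn

class FreezableModule(nn.Module):  # hypothetical base class
    def __init__(self, freeze=False):
        super(FreezableModule, self).__init__()
        self.freeze = freeze

    def update_freeze(self):
        # Called at the end of a subclass __init__() so that the flag
        # reaches every parameter created by the subclass.
        if self.freeze:
            for p in self.parameters():
                p.requires_grad = False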
Example #4
    def __init__(self, in_channels, n_filters, k_size,  stride, padding, bias=True, dilation=1, with_bn=True):
        super(conv2DBatchNorm, self).__init__()

        self.flagTS = GLOBAL.torch_batch_normal_track_stat()

        # dilation defaults to 1, so a single call covers both the dilated
        # and the plain case.
        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)

        if with_bn:
            self.cb_unit = nn.Sequential(conv_mod,
                                         nn.BatchNorm2d(int(n_filters), track_running_stats=self.flagTS),)
        else:
            self.cb_unit = nn.Sequential(conv_mod,)
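A construction sketch for the block above, using the signature exactly as defined (the channel values are arbitrary):

cb_bn    = conv2DBatchNorm(in_channels=32, n_filters=64, k_size=3,
                           stride=1, padding=1, bias=False)   # conv + BatchNorm
cb_plain = conv2DBatchNorm(32, 64, 3, 1, 1, with_bn=False)    # conv only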
Example #5
    def disable_batch_norm_track_running_stat(self):
        self.check_frame()
        self.frame.logger.info('Use track_running_stats=False.')
        GLOBAL.torch_batch_normal_track_stat(False)
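The GLOBAL module itself is not shown anywhere on this page. A minimal reconstruction consistent with how the examples call it (no argument reads the flag, an argument writes it) could look like this; it is an assumption, not the project's actual file:

# GLOBAL.py (hypothetical sketch)
_TRACK_STAT = True

def torch_batch_normal_track_stat(flag=None):
    # Read the flag when called without an argument, set it otherwise.
    global _TRACK_STAT
    if flag is not None:
        _TRACK_STAT = bool(flag)
    return _TRACK_STAT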
Example #6
    def __init__(self, inCh, trs=None):
        super(FeatureNorm3D, self).__init__()
        if trs is None:
            self.model = nn.BatchNorm3d(
                inCh,
                track_running_stats=GLOBAL.torch_batch_normal_track_stat())
        else:
            self.model = nn.BatchNorm3d(inCh, track_running_stats=trs)
Example #7
    def __init__(self, 
        edChannels = [
            [  3, 16, 16, 16, 16 ],
            [ 16, 16, 16, 16, 16 ],
            [ 16, 16, 16, 16, 16 ],
            [ 16, 16, 16, 16, 16 ]
        ],
        freeze=False):
        '''
        edChannels (list of lists): Channel specification for every layer. 
            [ eIn, eOut, dOut, up, out ]
        '''
        levels = [2, 4, 8, 16]
        super(UNetOneHalf, self).__init__(
            levels=levels,
            freeze=freeze)

        N = len(levels)
        assert N == len(edChannels), \
            f'Wrong level and channel specification. levels = {levels}, edChannels = {edChannels}'

        self.flagTS = GLOBAL.torch_batch_normal_track_stat()
        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        # Encoders.
        self.encoders = nn.ModuleList()
        for i in range(N):
            inCh  = edChannels[i][CH_IDX_E_IN]
            outCh = edChannels[i][CH_IDX_E_OUT]
            self.encoders.append(
                cm.ResidualBlock( inCh, outCh, 
                    stride=2, downsample=
                        cm.Conv( inCh, outCh, k=1, s=2, p=0, 
                            normLayer=cm.FeatureNormalization( outCh ) )
                )
            )

        # Decoders.
        self.decoders = nn.ModuleList()
        for i in range(N-1, -1, -1):
            if ( i == N-1 ):
                inCh = edChannels[i][CH_IDX_E_OUT]
            else:
                # Skip connection from encoder level i plus the up-sampled
                # features coming from the deeper level i+1.
                inCh = edChannels[i][CH_IDX_E_OUT] \
                     + edChannels[i+1][CH_IDX_U_OUT]

            outCh = edChannels[i][CH_IDX_D_OUT]

            self.decoders.append(
                cm.Conv_W( inCh, outCh, 
                    normLayer=cm.FeatureNormalization( outCh ),
                    activation=nn.ReLU(inplace=self.flagReLUInplace) ) )

        self.decoders = self.decoders[::-1]

        # Up-feature layers.
        self.ups = nn.ModuleList()
        for i in range( N-1, 0, -1 ):
            inCh  = edChannels[i][CH_IDX_D_OUT]
            outCh = edChannels[i][CH_IDX_U_OUT]
            # ModuleList.append() takes a single module, so wrap the
            # up-sampling and the convolution in nn.Sequential.
            self.ups.append( nn.Sequential(
                cm.Interpolate2D_FixedScale(2),
                cm.Conv_W( inCh, outCh, 
                    normLayer=cm.FeatureNormalization( outCh ),
                    activation=nn.ReLU(inplace=self.flagReLUInplace) ) ) )

        self.ups = self.ups[::-1]

        # Final layers.
        self.finals = nn.ModuleList()
        for i in range( N-1, -1, -1 ):
            inCh  = edChannels[i][CH_IDX_U_OUT]
            outCh = edChannels[i][CH_IDX_F_OUT]
            self.finals.append( 
                cm.Conv_W( inCh, outCh, 
                    normLayer=cm.FeatureNormalization( outCh ),
                    activation=nn.ReLU(inplace=self.flagReLUInplace) ) )

        self.finals = self.finals[::-1]

        # Middle.
        self.middle = cm.ResidualBlock(
            edChannels[-1][CH_IDX_E_OUT], edChannels[-1][CH_IDX_E_OUT], 
                lastActivation=nn.ReLU(inplace=self.flagReLUInplace) )
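Example #7 indexes edChannels with CH_IDX_* constants defined elsewhere in the module. Given the docstring's [ eIn, eOut, dOut, up, out ] layout, the mapping is presumably:

# Hypothetical reconstruction from the docstring; the real module defines these.
CH_IDX_E_IN  = 0   # encoder input channels      (eIn)
CH_IDX_E_OUT = 1   # encoder output channels     (eOut)
CH_IDX_D_OUT = 2   # decoder output channels     (dOut)
CH_IDX_U_OUT = 3   # up-sampled feature channels (up)
CH_IDX_F_OUT = 4   # final output channels       (out)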