Example #1
    def __init__(self,
                 inCh,
                 levels,
                 maxKFactor=0.5,
                 lastActivation=None,
                 flagNearest=False):
        super(SpatialPyramidPooling, self).__init__()

        # Global settings.
        self.flagAlignCorners = GLOBAL.torch_align_corners()
        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        self.inCh = inCh

        # Pooling levels.
        self.levels = int(levels)
        assert (self.levels > 0)

        # Kernel size factor.
        assert (0 < maxKFactor <= 1)
        self.maxKFactor = maxKFactor

        # Convolution layers for the pooling levels.
        self.poolingConvs = self.make_pooling_convs()

        # Last activation.
        self.lastActivation = lastActivation

        # The interpolation mode.
        self.interMode = 'nearest' if flagNearest else self.get_interpolate_mode()
Example #2
    def __init__(self, s, flagNearest=False):
        super(Interpolate2D_FixedScale, self).__init__()

        self.s = s
        self.flagAlignCorners = GLOBAL.torch_align_corners()

        self.mode = 'nearest' if flagNearest else 'bilinear'
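A minimal sketch of the matching forward pass, assuming Interpolate2D_FixedScale simply wraps torch.nn.functional.interpolate (imported as F); the body below is an assumption, not the original code:

    def forward(self, x):
        # 'nearest' does not accept align_corners; pass it only for the
        # interpolating modes such as 'bilinear'.
        if self.mode == 'nearest':
            return F.interpolate(x, scale_factor=self.s, mode=self.mode)
        return F.interpolate(x, scale_factor=self.s, mode=self.mode,
                             align_corners=self.flagAlignCorners)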
Example #3
    def __init__(self,
                 inChs,
                 interChs,
                 stride=1,
                 downsample=None,
                 dilation=1,
                 lastActivation=None):
        super(ResidualBlock, self).__init__()

        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        if dilation > 1:
            padding = dilation
        else:
            padding = 1

        self.conv0 = Conv(inChs,
                          interChs,
                          k=3,
                          s=stride,
                          p=padding,
                          d=dilation,
                          normLayer=FeatureNormalization(interChs),
                          activation=nn.ReLU(inplace=self.flagReLUInplace))
        self.conv1 = Conv_W(interChs,
                            interChs,
                            normLayer=FeatureNormalization(interChs))

        self.downsample = downsample
        self.stride = stride

        self.lastActivation = lastActivation
Example #4
    def __init__(self, in_planes, out_planes, stride):
        super(projfeat3d, self).__init__()
        self.flagTS = GLOBAL.torch_batch_normal_track_stat()

        self.stride = stride
        self.conv1 = nn.Conv2d(in_planes, out_planes, (1,1), padding=(0,0), stride=stride[:2], bias=False)
        self.bn = nn.BatchNorm2d(out_planes, track_running_stats=self.flagTS)
Example #5
    def __init__(self,
        nConvs, inCh, interCh, outCh,
        baseStride=(1,1,1), nStrides=1, 
        outputUpSampledFeat=False, pooling=False ):
        super(DecoderBlock, self).__init__()

        # Get the global settings.
        self.flagAlignCorners = GLOBAL.torch_align_corners()
        self.flagReLUInplace  = GLOBAL.torch_relu_inplace()

        # Prepare the list of strides.
        assert( nConvs >= nStrides )
        strideList = [baseStride] * nStrides + [(1,1,1)] * (nConvs - nStrides)

        # Create the convolution layers.
        convs = [ SepConv3DBlock( inCh, interCh, stride=strideList[0] ) ]
        for i in range(1, nConvs):
            convs.append( SepConv3DBlock( interCh, interCh, stride=strideList[i] ) )
        self.entryConvs = WrappedModule( nn.Sequential(*convs) )
        self.append_init_here( self.entryConvs )

        # Classification layer.
        self.classify = WrappedModule(
            nn.Sequential(
                cm3d.Conv3D_W( interCh, interCh, 
                    normLayer=cm3d.FeatureNorm3D(interCh), 
                    activation=nn.ReLU(inplace=self.flagReLUInplace) ), 
                cm3d.Conv3D_W(interCh, outCh, bias=True) ) )
        self.append_init_here(self.classify)

        # Feature up-sample setting.
        self.featUpSampler = None
        if outputUpSampledFeat:
            self.featUpSampler = WrappedModule(
                nn.Sequential(
                    cm3d.Interpolate3D_FixedScale(2),
                    cm3d.Conv3D_W( interCh, interCh//2, 
                        normLayer=cm3d.FeatureNorm3D(interCh//2), 
                        activation=nn.ReLU(inplace=self.flagReLUInplace) ) ) )
            self.append_init_here(self.featUpSampler)

        # Pooling.
        if pooling:
            self.spp = SPP3D( interCh, levels=4 )
            self.append_init_here(self.spp)
        else:
            self.spp = None
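For illustration, a hedged usage sketch of DecoderBlock; the channel counts and strides below are placeholder assumptions, not values taken from the original model:

# Six separable 3D convolutions, the first with stride (2, 2, 2),
# plus an up-sampled feature output and SPP3D pooling.
block = DecoderBlock(nConvs=6, inCh=64, interCh=32, outCh=1,
                     baseStride=(2, 2, 2), nStrides=1,
                     outputUpSampledFeat=True, pooling=True)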
Example #6
    def __init__(self, 
        flagMaskedLoss=True, 
        flagIntNearest=False):
        super(LossComputer, self).__init__()

        self.flagAlignCorners = modelGLOBAL.torch_align_corners()
        self.flagIntNearest   = flagIntNearest
        self.flagMaskedLoss   = flagMaskedLoss # Set to True to mask true disparities larger than self.trueDispMask.
Example #7
    def __init__(self, workingDir, conf, frame=None, modelName='Stereo'):
        self.conf = conf # The configuration dictionary.
        
        self.wd        = workingDir
        self.frame     = frame
        self.modelName = modelName

        # NN.
        self.countTrain = 0
        self.countTest  = 0

        self.flagAlignCorners = modelGLOBAL.torch_align_corners()
        self.flagIntNearest   = False

        self.trainIntervalAccWrite = 10    # The interval to write the accumulated values.
        self.trainIntervalAccPlot  = 1     # The interval to plot the accumulated values.
        self.flagUseIntPlotter     = False # The flag of intermittent plotter.

        self.flagCPU   = False
        self.multiGPUs = False

        self.readModelString     = ""
        self.readOptimizerString = ""
        self.autoSaveModelLoops  = 0   # The number of loops between model auto-saves. 0 to disable.
        self.autoSnapLoops       = 100 # The number of loops between auto-snaps.

        self.optimizer = None

        self.flagTest  = False # Should be set to True when testing.

        self.flagRandomSeedSet = False

        # Specified by conf.
        self.trueDispMask        = conf['tt']['trueDispMask']
        self.model               = None # make_object later.
        self.dataloader          = None # make_object later.
        self.optType             = conf['tt']['optType'] # The optimizer type: 'adam' or 'sgd'.
        # Learning rate scheduler.
        self.flagUseLRS          = conf['tt']['flagUseLRS']
        self.learningRate        = conf['tt']['lr']
        self.lrs                 = None # make_object later.
        # True value and loss.
        self.trueValueGenerator  = None # make_object later.
        self.lossComputer        = None # make_object later.
        self.testResultSubfolder = conf['tt']['testResultSubfolder']
        # Test figure generator.
        self.testFigGenerator = None

        # Temporary values during training and testing.
        self.ctxInputs     = None
        self.ctxOutputs    = None
        self.ctxTrueValues = None
        self.ctxLossValues = None

        # InfoUpdaters.
        self.trainingInfoUpdaters = []
        self.testingInfoUpdaters  = []
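The constructor reads a nested configuration dictionary; below is a minimal conf sketch holding just the 'tt' keys accessed above (all values are placeholder assumptions):

# Hedged sketch: only the keys read by the constructor are shown.
conf = {
    'tt': {
        'trueDispMask': 192,                   # placeholder threshold
        'optType': 'adam',                     # 'adam' or 'sgd'
        'flagUseLRS': False,                   # learning rate scheduler on/off
        'lr': 1e-3,                            # placeholder learning rate
        'testResultSubfolder': 'test_results', # placeholder folder name
    }
}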
Example #8
    def __init__(self,
                 nconvs,
                 inchannelF,
                 channelF,
                 stride=(1, 1, 1),
                 up=False,
                 nstride=1,
                 pool=False):
        super(decoderBlock, self).__init__()

        self.flagAlignCorners = GLOBAL.torch_align_corners()
        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        self.pool = pool
        stride = [stride] * nstride + [(1, 1, 1)] * (nconvs - nstride)
        self.convs = [sepConv3dBlock(inchannelF, channelF, stride=stride[0])]
        for i in range(1, nconvs):
            self.convs.append(
                sepConv3dBlock(channelF, channelF, stride=stride[i]))
        self.convs = nn.Sequential(*self.convs)

        self.classify = nn.Sequential(
            sepConv3d(channelF, channelF, 3, (1, 1, 1), 1),
            nn.ReLU(inplace=self.flagReLUInplace),
            sepConv3d(channelF, 1, 3, (1, 1, 1), 1, bias=True))

        self.up = False
        if up:
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=(2, 2, 2),
                            mode='trilinear',
                            align_corners=self.flagAlignCorners),
                sepConv3d(channelF, channelF // 2, 3, (1, 1, 1), 1,
                          bias=False), nn.ReLU(inplace=self.flagReLUInplace))

        if pool:
            self.pool_convs = torch.nn.ModuleList([
                sepConv3d(channelF, channelF, 1, (1, 1, 1), 0),
                sepConv3d(channelF, channelF, 1, (1, 1, 1), 0),
                sepConv3d(channelF, channelF, 1, (1, 1, 1), 0),
                sepConv3d(channelF, channelF, 1, (1, 1, 1), 0)
            ])
Example #9
    def __init__(self, in_planes, out_planes, stride=(1,1,1)):
        super(sepConv3dBlock, self).__init__()

        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        if in_planes == out_planes and stride==(1,1,1):
            self.downsample = None
        else:
            self.downsample = projfeat3d(in_planes, out_planes,stride)
        self.conv1 = sepConv3d(in_planes, out_planes, 3, stride, 1)
        self.conv2 = sepConv3d(out_planes, out_planes, 3, (1,1,1), 1)
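Only the constructor is shown; a minimal sketch of the residual forward pass it implies, assuming torch.nn.functional is imported as F (the body is an assumption, not the original code):

    def forward(self, x):
        # conv1 -> ReLU -> conv2, then add the (possibly projected) input.
        out = F.relu(self.conv1(x), inplace=self.flagReLUInplace)
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(x + self.conv2(out), inplace=self.flagReLUInplace)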
Example #10
    def __init__(self, in_channels, pool_sizes, model_name='pspnet', fusion_mode='cat', with_bn=True):
        super(pyramidPooling, self).__init__()

        self.flagAlignCorners = GLOBAL.torch_align_corners()
        self.flagReLUInplace  = GLOBAL.torch_relu_inplace()

        bias = not with_bn

        self.paths = []
        if pool_sizes is None:
            for i in range(4):
                self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, bias=bias, with_bn=with_bn))
        else:
            for i in range(len(pool_sizes)):
                self.paths.append(conv2DBatchNormRelu(in_channels, int(in_channels / len(pool_sizes)), 1, 1, 0, bias=bias, with_bn=with_bn))

        self.path_module_list = nn.ModuleList(self.paths)
        self.pool_sizes = pool_sizes
        self.model_name = model_name
        self.fusion_mode = fusion_mode
Example #11
    def __init__(self, inCh, trs=None):
        """
        trs stands for track_running_stats.
        """
        super(InstanceNormalization, self).__init__()

        if (trs is None):
            self.model = nn.InstanceNorm2d(
                inCh,
                track_running_stats=GLOBAL.torch_inst_normal_track_stat())
        else:
            self.model = nn.InstanceNorm2d(inCh, track_running_stats=trs)
Example #12
    def __init__(self, in_channels, n_filters, k_size,  stride, padding, bias=True, dilation=1, with_bn=True):
        super(conv2DBatchNormRelu, self).__init__()

        self.flagTS = GLOBAL.torch_batch_normal_track_stat()
        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        # When dilation == 1 this reduces to an ordinary convolution, so a
        # single call covers both cases.
        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)

        if with_bn:
            self.cbr_unit = nn.Sequential(conv_mod,
                                          nn.BatchNorm2d(int(n_filters), track_running_stats=self.flagTS),
                                          nn.LeakyReLU(0.1, inplace=self.flagReLUInplace),)
        else:
            self.cbr_unit = nn.Sequential(conv_mod,
                                          nn.LeakyReLU(0.1, inplace=self.flagReLUInplace),)
Example #13
    def __init__(self, in_channels, n_filters, stride=1, downsample=None,dilation=1):
        super(residualBlock, self).__init__()

        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        if dilation > 1:
            padding = dilation
        else:
            padding = 1
        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3,  stride, padding, bias=False,dilation=dilation)
        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
        self.downsample = downsample
        self.stride = stride
        self.relu = nn.ReLU(inplace=self.flagReLUInplace)
Example #14
    def __init__(self, inCh, trs=None):
        """
        trs stands for track_running_stats.
        """
        super(FeatureNormalization, self).__init__()

        assert inCh > 0

        if (trs is None):
            self.model = nn.BatchNorm2d(
                inCh,
                track_running_stats=GLOBAL.torch_batch_normal_track_stat())
        else:
            self.model = nn.BatchNorm2d(inCh, track_running_stats=trs)
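For illustration, a hedged usage sketch of the trs fallback above: the default follows the global switch, while an explicit bool overrides it (the channel count is an assumption):

norm_global = FeatureNormalization(64)             # follows GLOBAL.torch_batch_normal_track_stat()
norm_fixed  = FeatureNormalization(64, trs=False)  # always track_running_stats=False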
Example #15
    def __init__(self, inCh, outCh, stride=(1,1,1)):
        super(SepConv3DBlock, self).__init__()

        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        if inCh == outCh and stride==(1,1,1):
            self.downsample = None
        else:
            self.downsample = ProjFeat3D(inCh, outCh, stride)

        self.conv0 = cm3d.Conv3D( inCh, outCh, s=stride, 
            normLayer=cm3d.FeatureNorm3D(outCh), 
            activation=nn.ReLU(inplace=self.flagReLUInplace) )
        self.conv1 = cm3d.Conv3D_W( outCh, outCh, 
            normLayer=cm3d.FeatureNorm3D(outCh), 
            activation=nn.ReLU(inplace=self.flagReLUInplace) )
Example #16
    def __init__(self, initialChannels=32, freeze=False):
        super(UNet, self).__init__(freeze=freeze)

        self.flagTS = GLOBAL.torch_batch_normal_track_stat()

        self.inplanes = initialChannels

        # Encoder
        self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
                                                 padding=1, stride=2, bias=False)
        self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
                                                 padding=1, stride=1, bias=False)
        self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
                                                 padding=1, stride=1, bias=False)
        # Vanilla Residual Blocks
        self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
        self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)

        self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
        self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
        self.pyramid_pooling = pyramidPooling(128, None,  fusion_mode='sum', model_name='icnet')
        # Iconvs
        self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False))
        self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                                 padding=1, stride=1, bias=False)
        self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False))
        self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
                                                 padding=1, stride=1, bias=False)
        self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
                                     conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False))
        self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
                                                 padding=1, stride=1, bias=False)

        self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=32, padding=0,stride=1,bias=False)
        self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=16, padding=0,stride=1,bias=False)
        self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=16, padding=0,stride=1,bias=False)
        self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=16, padding=0,stride=1,bias=False)

        # Must be called at the end of __init__().
        self.update_freeze()
Example #17
    def __init__(self, nLayers=3, intermediateChannels=8):
        super(HalfSizeExtractor, self).__init__()

        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        modelList = [ 
            cm.Conv_Half( 3, intermediateChannels, 
                normLayer=cm.FeatureNormalization(intermediateChannels),
                activation=nn.ReLU(inplace=self.flagReLUInplace) ) ]

        for i in range(nLayers):
            if ( i == nLayers - 1):
                modelList.append(
                    cm.Conv_W( intermediateChannels, intermediateChannels, 
                    normLayer=cm.FeatureNormalization(intermediateChannels),
                    activation=nn.ReLU(inplace=self.flagReLUInplace) ) )
            else:
                modelList.append(
                    cm.Conv_W( intermediateChannels, intermediateChannels, 
                    normLayer=None,
                    activation=nn.ReLU(inplace=self.flagReLUInplace) ) )

        self.model = nn.Sequential( *modelList )
Example #18
def get_padding_module(padding):
    global PADDING_MODULE_TYPES
    paddingMode = GLOBAL.padding_mode()
    paddingType = PADDING_MODULE_TYPES[paddingMode]
    return paddingType(padding=padding)
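PADDING_MODULE_TYPES itself is not shown; a plausible sketch, assuming it maps mode names to the standard PyTorch padding modules (the key names are assumptions):

import torch.nn as nn

# Hedged sketch: each entry is a class that accepts padding=... as called
# in get_padding_module() above.
PADDING_MODULE_TYPES = {
    'zeros':     nn.ZeroPad2d,
    'reflect':   nn.ReflectionPad2d,
    'replicate': nn.ReplicationPad2d,
}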
Example #19
    def enable_last_regression_kernel_size_one(self):
        self.check_frame()
        self.frame.logger.info("Use last regression kernel size one. ")
        modelGLOBAL.last_regression_kernel_size_one(True)
Example #20
    def set_padding_mode(self, mode):
        self.check_frame()
        self.frame.logger.info(f'Set padding mode to {mode}')
        modelGLOBAL.padding_mode(mode)
Example #21
    def disable_batch_norm_track_running_stat(self):
        self.check_frame()
        self.frame.logger.info('Use track_running_stats=False.')
        modelGLOBAL.torch_batch_normal_track_stat(False)
Example #22
    def disable_align_corners(self):
        self.check_frame()
        self.frame.logger.info("Use align_corners=False.")
        modelGLOBAL.torch_align_corners(False)
        self.flagAlignCorners = False
Example #23
    def enable_align_corners(self):
        self.check_frame()
        self.frame.logger.info("Use align_corners=True.")
        modelGLOBAL.torch_align_corners(True)
        self.flagAlignCorners = True
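The setters above imply that modelGLOBAL exposes combined getter/setters. A minimal sketch of that pattern, consistent with calls like torch_align_corners() and torch_align_corners(True); the internal storage shown here is an assumption:

# Hedged sketch: read when called with no argument, write when given a bool.
_TORCH_ALIGN_CORNERS = False

def torch_align_corners(value=None):
    global _TORCH_ALIGN_CORNERS
    if value is None:
        return _TORCH_ALIGN_CORNERS
    _TORCH_ALIGN_CORNERS = bool(value)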
Example #24
    def __init__(self, inCh, trs=None):
        super(FeatureNorm3D, self).__init__()
        self.model = nn.BatchNorm3d( inCh, track_running_stats=GLOBAL.torch_batch_normal_track_stat() ) \
            if trs is None \
            else nn.BatchNorm3d( inCh, track_running_stats=trs )
Example #25
    def __init__(self, initialChannels=32, flagEntryPool=True, freeze=False):
        super(UNet, self).__init__(levels=[8, 16, 32, 64], freeze=freeze)

        # self.flagTS = GLOBAL.torch_batch_normal_track_stat()
        self.flagReLUInplace = GLOBAL.torch_relu_inplace()

        self.flagEntryPool = flagEntryPool
        if (not self.flagEntryPool):
            self.levels = [level // 2 for level in self.levels]

        self.inplanes = initialChannels

        # Encoder
        self.convBnReLU1_1 = cm.Conv_Half(
            3,
            16,
            normLayer=cm.FeatureNormalization(16),
            activation=nn.ReLU(inplace=self.flagReLUInplace))
        self.convBnReLU1_2 = cm.Conv_W(
            16,
            16,
            normLayer=cm.FeatureNormalization(16),
            activation=nn.ReLU(inplace=self.flagReLUInplace))
        self.convBnReLU1_3 = cm.Conv_W(
            16,
            32,
            normLayer=cm.FeatureNormalization(32),
            activation=nn.ReLU(inplace=self.flagReLUInplace))

        # Vanilla Residual Blocks
        self.resBlock3 = self._make_layer(cm.ResidualBlock,
                                          interCh=64,
                                          stride=2)
        self.resBlock5 = self._make_layer(cm.ResidualBlock,
                                          interCh=128,
                                          stride=2)
        self.resBlock6 = self._make_layer(cm.ResidualBlock,
                                          interCh=128,
                                          stride=2)
        self.resBlock7 = self._make_layer(cm.ResidualBlock,
                                          interCh=128,
                                          stride=2)

        # This does not make sense for very small feature sizes.
        self.pyramidPooling = \
            pooling.SPP2D(128, levels=4,
                lastActivation=nn.ReLU(inplace=self.flagReLUInplace))

        # iConvs.
        self.upConv6 = nn.Sequential(
            cm.Interpolate2D_FixedScale(2),
            cm.Conv_W(128,
                      64,
                      normLayer=cm.FeatureNormalization(64),
                      activation=nn.ReLU(inplace=self.flagReLUInplace)))
        self.iConv5 = cm.Conv_W(
            192,
            128,
            normLayer=cm.FeatureNormalization(128),
            activation=nn.ReLU(inplace=self.flagReLUInplace))

        self.upConv5 = nn.Sequential(
            cm.Interpolate2D_FixedScale(2),
            cm.Conv_W(128,
                      64,
                      normLayer=cm.FeatureNormalization(64),
                      activation=nn.ReLU(inplace=self.flagReLUInplace)))
        self.iConv4 = cm.Conv_W(
            192,
            128,
            normLayer=cm.FeatureNormalization(128),
            activation=nn.ReLU(inplace=self.flagReLUInplace))

        self.upConv4 = nn.Sequential(
            cm.Interpolate2D_FixedScale(2),
            cm.Conv_W(128,
                      64,
                      normLayer=cm.FeatureNormalization(64),
                      activation=nn.ReLU(inplace=self.flagReLUInplace)))
        self.iConv3 = cm.Conv_W(
            128,
            64,
            normLayer=cm.FeatureNormalization(64),
            activation=nn.ReLU(inplace=self.flagReLUInplace))

        self.proj6 = cm.Conv_W(
            128,
            32,
            k=1,
            normLayer=cm.FeatureNormalization(32),
            activation=nn.ReLU(inplace=self.flagReLUInplace))
        self.proj5 = cm.Conv_W(
            128,
            16,
            k=1,
            normLayer=cm.FeatureNormalization(16),
            activation=nn.ReLU(inplace=self.flagReLUInplace))
        self.proj4 = cm.Conv_W(
            128,
            16,
            k=1,
            normLayer=cm.FeatureNormalization(16),
            activation=nn.ReLU(inplace=self.flagReLUInplace))
        self.proj3 = cm.Conv_W(
            64,
            16,
            k=1,
            normLayer=cm.FeatureNormalization(16),
            activation=nn.ReLU(inplace=self.flagReLUInplace))

        # Must be called at the end of __init__().
        self.update_freeze()
Example #26
        os.path.dirname(
            os.path.dirname(
                os.path.dirname( _CF ) ) ) ) )

print(f'Adding {_PKG_PATH} to the package search path. ')

import sys
sys.path.insert(0, _PKG_PATH)

# Import the package.
import stereo
from stereo.models.globals import GLOBAL

# Configure the global settings for testing.
# Not the right setting. Only for testing.
GLOBAL.torch_align_corners(True)

if __name__ == '__main__':
    print(f'Hello, {os.path.basename(__file__)}! ')

    # Show the global settings.
    print(f'GLOBAL.torch_align_corners() = {GLOBAL.torch_align_corners()}')

    # The dummy dictionary.
    maxDisp = 192
    d = dict(type='HSMNet',
             maxdisp=maxDisp,
             clean=-1,
             featExtConfig=dict(type='UNet', initialChannels=32),
             costVolConfig=dict(type='CVDiff', refIsRight=False),
             dispRegConfigs=[
Example #27
def selected_relu(x):
    # return F.selu(x, inplace=False)
    return F.leaky_relu(x, 0.1, inplace=GLOBAL.torch_relu_inplace())
Example #28
    def __init__(self, 
        maxDisp=192,
        featExtConfig=None, 
        costVolConfig=None,
        costPrcConfig=None,
        dispRegConfigs=None,
        uncertainty=False,
        freeze=False):

        super(CostVolPrimitive, self).__init__(freeze=freeze)

        # Global settings.
        self.flagAlignCorners = GLOBAL.torch_align_corners()

        # ========== Module definition. ==========
        self.maxDisp = maxDisp

        # Uncertainty setting.
        self.uncertainty = uncertainty

        # Feature extractor.
        if ( featExtConfig is None ):
            featExtConfig = UNet.get_default_init_args()

        self.featureExtractor = make_object(FEAT_EXT, featExtConfig)
        self.append_init_impl(self.featureExtractor)
        
        # Cost volume.
        if ( costVolConfig is None ):
            costVolConfig = CVDiff.get_default_init_args()
        
        self.costVol = make_object(COST_VOL, costVolConfig)
        self.append_init_impl(self.costVol)

        # Cost volume processing/regularization layers.
        if ( costPrcConfig is None ):
            costPrcConfig = C3D_SPP_CLS.get_default_init_args()
        self.costProcessor = make_object(COST_PRC, costPrcConfig)
        self.append_init_impl( self.costProcessor )

        # Disparity regressions.
        nLevels = self.featureExtractor.n_levels()
        if ( dispRegConfigs is None ):
            dispRegConfigs = [ ClsLinearCombination.get_default_init_args() for _ in range(nLevels)]
        else:
            assert( isinstance( dispRegConfigs, (tuple, list) ) ), \
                f'dispRegConfigs must be a tuple or list. It is {type(dispRegConfigs)}'
            assert( len(dispRegConfigs) == nLevels ), \
                f'len(dispRegConfigs) = {len(dispRegConfigs)}, nLevels = {nLevels}'

        self.dispRegList = nn.ModuleList()
        for config in dispRegConfigs:
            dispReg = make_object(DISP_REG, config)
            self.dispRegList.append( dispReg )
            self.append_init_impl( dispReg )

        # Uncertainty computer.
        self.uncertaintyComputer = ClassifiedCostVolumeEpistemic() \
            if self.uncertainty else None

        # Must be called at the end of __init__().
        self.update_freeze()
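The make_object() calls above follow a registry pattern: a 'type' key selects the class and the remaining keys become constructor arguments (compare the dummy dictionary in Example #26). A hedged sketch of such a helper; the real implementation is not shown:

def make_object(registry, config):
    # Copy so the caller's config is not mutated, then pop the class key.
    kwargs = dict(config)
    cls = registry[kwargs.pop('type')]
    return cls(**kwargs)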
Example #29
    def __init__(self):
        super(ClassifiedCostVolumeEpistemic, self).__init__()

        self.flagAlignCorners = modelGLOBAL.torch_align_corners()
Example #30
    def __init__(self, trueDispMax, flagIntNearest=False):
        super(TrueValueGenerator, self).__init__()

        self.trueDispMax      = trueDispMax
        self.flagAlignCorners = modelGLOBAL.torch_align_corners()
        self.flagIntNearest   = flagIntNearest