def __init__(self, in_planes, out_planes, stride):
    """Build a bias-free 1x1 Conv2d projection followed by BatchNorm2d.

    Args:
        in_planes: input channel count of the convolution.
        out_planes: output channel count of the convolution and the norm.
        stride: stored unchanged on ``self.stride``; only its first two
            entries are handed to the Conv2d (presumably it also carries a
            third component for a 3D input — TODO confirm with callers).
    """
    super(projfeat3d, self).__init__()
    self.flagTS = GLOBAL.torch_batch_normal_track_stat()
    self.stride = stride

    # The projection is a 2D convolution, so only two stride entries apply.
    spatial_stride = stride[:2]
    self.conv1 = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=(1, 1),
        stride=spatial_stride,
        padding=(0, 0),
        bias=False,
    )
    self.bn = nn.BatchNorm2d(out_planes, track_running_stats=self.flagTS)
def __init__(self, inCh, trs=None):
    """Wrap a BatchNorm2d over ``inCh`` channels.

    Args:
        inCh: number of channels; asserted to be positive.
        trs: ``track_running_stats`` for the BatchNorm2d. When ``None``, the
            value returned by ``GLOBAL.torch_batch_normal_track_stat()`` is
            used instead.
    """
    super(FeatureNormalization, self).__init__()
    assert inCh > 0

    track_flag = GLOBAL.torch_batch_normal_track_stat() if trs is None else trs
    self.model = nn.BatchNorm2d(inCh, track_running_stats=track_flag)
def __init__(self, initialChannels=32, freeze=False):
    """Create the encoder, residual, pyramid-pooling and decoder layers.

    Args:
        initialChannels: stored as ``self.inplanes`` before the residual
            blocks are built. Presumably ``_make_layer`` reads it as the input
            channel count of ``res_block3``. NOTE(review): ``convbnrelu1_3``
            always emits 32 channels, so a value other than 32 may not match.
            Confirm against ``_make_layer``.
        freeze: forwarded to the base-class constructor.

    Layers are created in a fixed order. Keep that order: it determines
    parameter registration order and the sequence of weight initialisations.
    """
    super(UNet, self).__init__(freeze=freeze)
    self.flagTS = GLOBAL.torch_batch_normal_track_stat()
    self.inplanes = initialChannels

    def cbr(in_ch, out_ch, k_size=3, padding=1, stride=1):
        # Every conv-BN-ReLU unit in this network is built without bias.
        return conv2DBatchNormRelu(
            in_channels=in_ch,
            k_size=k_size,
            n_filters=out_ch,
            padding=padding,
            stride=stride,
            bias=False,
        )

    def up_cbr(in_ch, out_ch):
        # nn.Upsample(scale_factor=2) followed by a 3x3 conv-BN-ReLU.
        return nn.Sequential(nn.Upsample(scale_factor=2), cbr(in_ch, out_ch))

    # Encoder.
    self.convbnrelu1_1 = cbr(3, 16, stride=2)
    self.convbnrelu1_2 = cbr(16, 16)
    self.convbnrelu1_3 = cbr(16, 32)

    # Vanilla residual blocks.
    self.res_block3 = self._make_layer(residualBlock, 64, 1, stride=2)
    self.res_block5 = self._make_layer(residualBlock, 128, 1, stride=2)
    self.res_block6 = self._make_layer(residualBlock, 128, 1, stride=2)
    self.res_block7 = self._make_layer(residualBlock, 128, 1, stride=2)
    self.pyramid_pooling = pyramidPooling(
        128, None, fusion_mode='sum', model_name='icnet')

    # Iconvs.
    self.upconv6 = up_cbr(128, 64)
    self.iconv5 = cbr(192, 128)
    self.upconv5 = up_cbr(128, 64)
    self.iconv4 = cbr(192, 128)
    self.upconv4 = up_cbr(128, 64)
    self.iconv3 = cbr(128, 64)

    # 1x1 projections.
    self.proj6 = cbr(128, 32, k_size=1, padding=0)
    self.proj5 = cbr(128, 16, k_size=1, padding=0)
    self.proj4 = cbr(128, 16, k_size=1, padding=0)
    self.proj3 = cbr(64, 16, k_size=1, padding=0)

    # Must be called at the end of __init__().
    self.update_freeze()
def __init__(self, in_channels, n_filters, k_size, stride, padding,
             bias=True, dilation=1, with_bn=True):
    """Build ``self.cb_unit``: a Conv2d, optionally followed by BatchNorm2d.

    Args:
        in_channels: input channel count (cast to ``int``).
        n_filters: output channel count (cast to ``int``).
        k_size, stride, padding, bias: passed to the Conv2d unchanged.
        dilation: used as given when greater than 1, otherwise replaced by 1.
        with_bn: when true, append a BatchNorm2d whose
            ``track_running_stats`` comes from the global setting.
    """
    super(conv2DBatchNorm, self).__init__()
    self.flagTS = GLOBAL.torch_batch_normal_track_stat()

    # Dilation values of 1 or less are clamped to 1.
    effective_dilation = dilation if dilation > 1 else 1
    layers = [
        nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=effective_dilation,
        )
    ]
    if with_bn:
        layers.append(
            nn.BatchNorm2d(int(n_filters), track_running_stats=self.flagTS))
    self.cb_unit = nn.Sequential(*layers)
def disable_batch_norm_track_running_stat(self):
    """Log the choice and pass ``False`` to the BatchNorm tracking setting.

    Calls ``self.check_frame()`` before using ``self.frame.logger``. The call
    ``modelGLOBAL.torch_batch_normal_track_stat(False)`` presumably sets the
    flag that layers later read through ``torch_batch_normal_track_stat()``
    (TODO confirm).
    """
    self.check_frame()
    log = self.frame.logger
    log.info('Use track_running_stats=False.')
    modelGLOBAL.torch_batch_normal_track_stat(False)
def __init__(self, inCh, trs=None):
    """Wrap a BatchNorm3d over ``inCh`` channels.

    Args:
        inCh: number of channels.
        trs: ``track_running_stats`` for the BatchNorm3d. When ``None``, the
            value returned by ``GLOBAL.torch_batch_normal_track_stat()`` is
            used instead.
    """
    super(FeatureNorm3D, self).__init__()
    if trs is None:
        trs = GLOBAL.torch_batch_normal_track_stat()
    self.model = nn.BatchNorm3d(inCh, track_running_stats=trs)
def __init__(self, edChannels=None, freeze=False):
    '''
    Build the encoders, decoders, up-feature, final and middle layers.

    edChannels (list of lists): Channel specification for every layer,
        one row per level. Row 0 feeds the first encoder (its default eIn
        is 3). Each row is [ eIn, eOut, dOut, up, out ], read through
        CH_IDX_E_IN, CH_IDX_E_OUT, CH_IDX_D_OUT, CH_IDX_U_OUT and
        CH_IDX_F_OUT. Defaults to [3, 16, 16, 16, 16] followed by three
        rows of [16, 16, 16, 16, 16].
    freeze: forwarded to the base-class constructor.

    Raises AssertionError when the number of rows differs from the number
    of levels.
    '''
    if edChannels is None:
        # Build the default on every call instead of sharing one mutable
        # default list across all instances.
        edChannels = [
            [3, 16, 16, 16, 16],
            [16, 16, 16, 16, 16],
            [16, 16, 16, 16, 16],
            [16, 16, 16, 16, 16],
        ]

    # `levels` used to be referenced below without ever being defined
    # (NameError). Keep it local so it is both forwarded and validated.
    levels = [2, 4, 8, 16]
    super(UNetOneHalf, self).__init__(levels=levels, freeze=freeze)

    N = len(levels)
    assert N == len(edChannels), \
        f'Wrong level and channel specification. levels = {levels}, edChannels = {edChannels}'

    self.flagTS = GLOBAL.torch_batch_normal_track_stat()
    self.flagReLUInplace = GLOBAL.torch_relu_inplace()

    # Encoders: stride-2 residual blocks whose downsample branch is a
    # strided 1x1 cm.Conv with feature normalization.
    self.encoders = nn.ModuleList()
    for i in range(N):
        inCh = edChannels[i][CH_IDX_E_IN]
        outCh = edChannels[i][CH_IDX_E_OUT]
        self.encoders.append(
            cm.ResidualBlock(
                inCh, outCh, stride=2,
                downsample=cm.Conv(
                    inCh, outCh, k=1, s=2, p=0,
                    normLayer=cm.FeatureNormalization(outCh))))

    # Decoders: built deepest-first, then reversed so that
    # self.decoders[i] belongs to level i.
    self.decoders = nn.ModuleList()
    for i in range(N - 1, -1, -1):
        if i == N - 1:
            # Deepest level: input has eOut channels (presumably the output
            # of self.middle, which preserves that channel count).
            inCh = edChannels[i][CH_IDX_E_OUT]
        else:
            # Level-i encoder channels plus the up-feature built from level
            # i+1 (self.ups emits U_OUT of level i+1), presumably
            # concatenated in forward(). The former `i-1` row was wrong and
            # wrapped to edChannels[-1] at i == 0.
            inCh = edChannels[i][CH_IDX_E_OUT] \
                + edChannels[i + 1][CH_IDX_U_OUT]
        outCh = edChannels[i][CH_IDX_D_OUT]
        self.decoders.append(
            cm.Conv_W(
                inCh, outCh,
                normLayer=cm.FeatureNormalization(outCh),
                activation=nn.ReLU(inplace=self.flagReLUInplace)))
    self.decoders = self.decoders[::-1]

    # Up-feature layers for levels N-1 .. 1: Interpolate2D_FixedScale(2)
    # (presumably 2x upsampling), then Conv_W from dOut to up channels.
    # ModuleList.append accepts exactly one module, so the pair is wrapped
    # in nn.Sequential. Passing both as separate arguments raised TypeError.
    # After the reversal, self.ups[k] is built from level k+1.
    self.ups = nn.ModuleList()
    for i in range(N - 1, 0, -1):
        inCh = edChannels[i][CH_IDX_D_OUT]
        outCh = edChannels[i][CH_IDX_U_OUT]
        self.ups.append(
            nn.Sequential(
                cm.Interpolate2D_FixedScale(2),
                cm.Conv_W(
                    inCh, outCh,
                    normLayer=cm.FeatureNormalization(outCh),
                    activation=nn.ReLU(inplace=self.flagReLUInplace))))
    self.ups = self.ups[::-1]

    # Final layers, one per level, reversed so self.finals[i] is level i.
    # NOTE(review): inCh is taken from the `up` column, but no layer above
    # emits U_OUT channels for level 0. The decoder output (dOut) may be
    # what forward() feeds here, so confirm before using non-uniform
    # channel specs.
    self.finals = nn.ModuleList()
    for i in range(N - 1, -1, -1):
        inCh = edChannels[i][CH_IDX_U_OUT]
        outCh = edChannels[i][CH_IDX_F_OUT]
        self.finals.append(
            cm.Conv_W(
                inCh, outCh,
                normLayer=cm.FeatureNormalization(outCh),
                activation=nn.ReLU(inplace=self.flagReLUInplace)))
    self.finals = self.finals[::-1]

    # Middle: residual block at the deepest level, channel count preserved.
    self.middle = cm.ResidualBlock(
        edChannels[-1][CH_IDX_E_OUT],
        edChannels[-1][CH_IDX_E_OUT],
        lastActivation=nn.ReLU(inplace=self.flagReLUInplace))

    # NOTE(review): the sibling UNet.__init__ ends with self.update_freeze()
    # ("Must be called at the end of __init__()"). Confirm whether this
    # class's base requires the same call.