def __init__(self, layercount=18, inchannel=3, outchannel=64,
             first_initpad=False, pooling=True):
    """Build a stack of ``layercount`` ``irevnet_block`` layers.

    Args:
        layercount: number of ``irevnet_block`` instances appended to
            ``self.block_list``.
        inchannel: input channel count handed to the first block; also
            stored unchanged as ``self.inchannel``.
        outchannel: ``out_ch`` passed to every block. Every block after the
            first is built with ``2 * outchannel`` input channels.
        first_initpad: accepted for interface compatibility; not used by
            this constructor.
        pooling: stored as ``self.pooling``; how it is consumed is not
            visible in this constructor.
    """
    super(SRResirevBlock, self).__init__()
    self.inchannel = inchannel
    self.block_list = nn.ModuleList()
    self.psi = psi(2)
    self.pooling = pooling

    # Only block 0 sees the raw input width; every later block receives
    # 2 * outchannel channels and is built with first=False.
    block_inputs = [inchannel if idx == 0 else 2 * outchannel
                    for idx in range(layercount)]
    for idx, block_in in enumerate(block_inputs):
        self.block_list.append(
            irevnet_block(block_in, outchannel, first=(idx == 0)))

    # `first` is left True only when no block was built (layercount <= 0),
    # and False once at least one block exists.
    self.first = not block_inputs
def __init__(self, in_ch, out_ch, stride=1, first=False,
             featureIncrease=False, dropout_rate=0., affineBN=True, mult=3,
             first_initpad=False):
    """Build an invertible bottleneck block.

    The bottleneck is ``[ReLU] -> Conv3x3(stride) -> ReLU -> Conv3x3 ->
    ReLU -> Conv3x3`` with no normalization layers. The leading ReLU is
    omitted when ``first`` is True.

    Args:
        in_ch: channel count of the block input; the bottleneck consumes
            ``in_ch // 2`` channels.
        out_ch: channel count produced by the bottleneck.
        stride: stride of the first convolution; also passed to ``psi``.
        first: when True, skip the leading ReLU.
        featureIncrease: when True, also create
            ``self.CI = ConcatPad(in_ch // 2, in_ch)``.
        dropout_rate: accepted for interface compatibility; unused here.
        affineBN: accepted for interface compatibility; unused here.
        mult: divisor for the hidden width (``int(out_ch // mult)``).
        first_initpad: accepted for interface compatibility; unused here.
    """
    super(irevnet_block, self).__init__()
    self.first = first
    self.featureIncrease = featureIncrease
    if self.featureIncrease:
        self.CI = ConcatPad(in_ch // 2, in_ch)
    self.stride = stride
    self.psi = psi(stride)
    # NOTE(review): only initialised here; its use is not visible in this
    # constructor.
    self.inpx = None

    hidden = int(out_ch // mult)
    layers = []
    if not first:
        # Must NOT be in-place: this ReLU is the very first op applied to the
        # bottleneck's input tensor, so inplace=True would overwrite the
        # caller's tensor. In an i-RevNet coupling that tensor is presumably
        # also passed through unchanged (confirm against forward()), so an
        # in-place clamp would corrupt the identity path and break
        # invertibility. ReLU has no parameters, so Sequential indices and
        # state_dict keys are unaffected by this change.
        layers.append(nn.ReLU(inplace=False))
    layers.append(
        nn.Conv2d(in_ch // 2, hidden, kernel_size=3, stride=stride,
                  padding=1, bias=False))
    # These ReLUs follow a conv, so they only modify freshly created
    # activations and in-place is safe.
    layers.append(nn.ReLU(inplace=True))
    layers.append(
        nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, bias=False))
    layers.append(nn.ReLU(inplace=True))
    layers.append(
        nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1, bias=False))
    self.bottleneck_block = nn.Sequential(*layers)
def __init__(self, nBlocks, nStrides, nClasses, nChannels=None, init_ds=2,
             dropout_rate=0., affineBN=True, in_shape=None, mult=4):
    """Build the i-RevNet network.

    Args:
        nBlocks: blocks per stage; ``sum(nBlocks) * 3 + 1`` is the reported
            depth.
        nStrides: per-stage strides; the number of 2s determines how far the
            spatial size is reduced (``self.ds``).
        nClasses: output size of the final linear layer.
        nChannels: per-stage channel counts. If falsy, the default is
            ``[c, 4c, 16c, 64c]`` with ``c = self.in_ch // 2``.
        init_ds: power passed to the initial ``psi``; input channels are
            multiplied by ``2 ** init_ds``.
        dropout_rate, affineBN, mult: forwarded to ``self.irevnet_stack``.
        in_shape: required input shape. Index 0 is the channel count and
            index 2 is a spatial size (presumably ``(C, H, W)`` — confirm
            with callers).

    Raises:
        TypeError: if ``in_shape`` is not provided.
    """
    super(iRevNet, self).__init__()
    if in_shape is None:
        # The signature defaults in_shape to None, but it is mandatory;
        # fail with an actionable message instead of a bare subscript error.
        raise TypeError("iRevNet requires in_shape (got None)")
    self.ds = in_shape[2] // 2**(nStrides.count(2) + init_ds // 2)
    self.init_ds = init_ds
    self.in_ch = in_shape[0] * 2**self.init_ds
    self.nBlocks = nBlocks
    self.first = True

    print('')
    print(' == Building iRevNet %d == ' % (sum(nBlocks) * 3 + 1))
    if not nChannels:
        base = self.in_ch // 2
        nChannels = [base * 4**k for k in range(4)]

    self.init_psi = psi(self.init_ds)
    self.stack = self.irevnet_stack(irevnet_block, nChannels, nBlocks,
                                    nStrides, dropout_rate=dropout_rate,
                                    affineBN=affineBN, in_ch=self.in_ch,
                                    mult=mult)
    # NOTE(review): PyTorch's BatchNorm momentum is the weight of the *new*
    # batch statistic (running = (1 - m) * running + m * batch), so 0.9
    # tracks the current batch closely. If a TF-style 0.9 decay was intended
    # this would be 0.1; kept as-is to preserve existing behavior/checkpoints.
    self.bn1 = nn.BatchNorm2d(nChannels[-1] * 2, momentum=0.9)
    self.linear = nn.Linear(nChannels[-1] * 2, nClasses)