def v(depth, nPlanes):
    """Recursively build one level of a constant-width sparse U-Net.

    Relies on names not defined in this function (presumably from an
    enclosing scope): ``reps``, ``res``, ``dropout_p``, ``dropout_width``
    and ``dimension``.

    At ``depth == 1`` the level is just ``reps`` residual blocks. Otherwise
    it is assembled from three parts:
      1. ``reps`` residual blocks;
      2. a ConcatTable joining an identity branch with a
         downsample -> recurse(depth - 1) -> upsample branch;
      3. a JoinTable, then ``reps`` residual blocks, the first of which takes
         ``2 * nPlanes`` input planes (both branches' outputs).

    Args:
        depth: number of resolution levels to build, including this one;
            must be >= 1.
        nPlanes: feature planes used at every level.

    Returns:
        The ``scn.Sequential`` holding this level (and all levels below it).

    Raises:
        ValueError: if ``depth < 1`` (the recursion would never terminate).
    """
    if depth < 1:
        raise ValueError('depth must be >= 1, got %r' % (depth,))

    m = scn.Sequential()
    for _ in range(reps):
        res(m, nPlanes, nPlanes, dropout_p)
    if depth == 1:
        return m

    # Build the ConcatTable before the lower branch so modules are created
    # in the same order as before (keeps parameter initialisation stable).
    concat = scn.ConcatTable().add(scn.Identity())

    lower = scn.Sequential()
    lower.add(scn.BatchNormReLU(nPlanes))
    # Strided convolution in place of max-pooling.
    lower.add(scn.Convolution(dimension, nPlanes, nPlanes, 2, 2, False))
    if dropout_width:
        lower.add(scn.Dropout(dropout_p))
    lower.add(v(depth - 1, nPlanes))
    lower.add(scn.BatchNormReLU(nPlanes))
    lower.add(scn.Deconvolution(dimension, nPlanes, nPlanes, 2, 2, False))
    concat.add(lower)

    m.add(concat)
    m.add(scn.JoinTable())
    for i in range(reps):
        # First block consumes the joined identity + lower-branch features.
        res(m, 2 * nPlanes if i == 0 else nPlanes, nPlanes, dropout_p)
    return m
def __init__(self, cfg, name='yresnet_encoder'):
    """Read encoder options from ``cfg[name]`` and assemble the encoder.

    Args:
        cfg: configuration mapping. ``cfg[name]`` supplies the options
            ``reps``, ``kernel_size``, ``num_strides``, ``filters`` and
            ``dropout_prob`` (each has a default).
        name: key of this encoder's section inside ``cfg``.

    Two ``scn.Sequential`` containers are built, with one entry per level:
      * ``encoding_block[level]``: ``reps`` residual blocks at
        ``nPlanes[level]`` planes.
      * ``encoding_conv[level]``: BatchNormLeakyReLU -> strided Convolution
        to ``nPlanes[level + 1]`` -> Dropout. It is left empty at the deepest
        level so both containers have the same length.
    """
    # The base class is initialised with name='network_base' regardless of ``name``.
    super(YResNetEncoder, self).__init__(cfg, name='network_base')
    self.model_config = cfg[name]
    option = self.model_config.get

    # YResNet hyper-parameters.
    self.reps = option('reps', 2)  # residual blocks per level
    self.kernel_size = option('kernel_size', 2)
    self.num_strides = option('num_strides', 5)  # number of levels
    self.num_filters = option('filters', 16)
    # Widths grow linearly: 1*filters, 2*filters, ..., num_strides*filters.
    self.nPlanes = [
        level * self.num_filters for level in range(1, self.num_strides + 1)
    ]
    # Downsampling convolution parameters: [filter size, filter stride].
    self.downsample = [self.kernel_size, 2]
    dropout_prob = option('dropout_prob', 0.5)

    self.encoding_block = scn.Sequential()
    self.encoding_conv = scn.Sequential()
    deepest = self.num_strides - 1
    for level in range(self.num_strides):
        width = self.nPlanes[level]

        stack = scn.Sequential()
        for _ in range(self.reps):
            self._resnet_block(stack, width, width)
        self.encoding_block.add(stack)

        transition = scn.Sequential()
        if level != deepest:
            transition.add(
                scn.BatchNormLeakyReLU(width, leakiness=self.leakiness))
            transition.add(
                scn.Convolution(self.dimension, width,
                                self.nPlanes[level + 1],
                                self.downsample[0], self.downsample[1],
                                self.allow_bias))
            transition.add(scn.Dropout(p=dropout_prob))
        self.encoding_conv.add(transition)
def __init__(self, cfg, name='yresnet_decoder'):
    """Read decoder options from ``cfg[name]`` and assemble the decoder.

    Args:
        cfg: configuration mapping. ``cfg[name]`` supplies ``reps``,
            ``kernel_size``, ``num_strides``, ``filters``, ``dropout_prob``
            and, optionally, ``encoder_num_filters`` (defaults to
            ``filters``).
        name: key of this decoder's section inside ``cfg``.

    For each level ``i`` from ``num_strides - 2`` down to 0, two entries
    are appended:
      * ``decoding_conv``: BatchNormLeakyReLU -> strided Deconvolution to
        ``nPlanes[i]``. The first step reads ``encoder_nPlanes[i + 1]``
        planes and has no Dropout. Later steps read ``nPlanes[i + 1]``
        planes and end with Dropout.
      * ``decoding_block``: ``reps`` residual blocks. The first one takes
        ``nPlanes[i] + encoder_nPlanes[i]`` input planes, presumably the
        upsampled features joined with the encoder skip at level ``i``
        (confirm in ``forward``).
    """
    # The base class is initialised with name='network_base' regardless of ``name``.
    super(YResNetDecoder, self).__init__(cfg, name='network_base')
    self.model_config = cfg[name]
    self.reps = self.model_config.get('reps', 2)  # Conv block repetition factor
    self.kernel_size = self.model_config.get('kernel_size', 2)
    self.num_strides = self.model_config.get('num_strides', 5)
    self.num_filters = self.model_config.get('filters', 16)
    self.nPlanes = [
        i * self.num_filters for i in range(1, self.num_strides + 1)
    ]
    self.downsample = [self.kernel_size, 2]  # [filter size, filter stride]
    # Not used in this method; presumably consumed by forward().
    self.concat = scn.JoinTable()
    self.add = scn.AddTable()
    dropout_prob = self.model_config.get('dropout_prob', 0.5)

    # The encoder may use a different base width than the decoder.
    self.encoder_num_filters = self.model_config.get(
        'encoder_num_filters', None)
    if self.encoder_num_filters is None:
        self.encoder_num_filters = self.num_filters
    self.encoder_nPlanes = [
        i * self.encoder_num_filters for i in range(1, self.num_strides + 1)
    ]

    # Define Sparse YResNet Decoder.
    self.decoding_block = scn.Sequential()
    self.decoding_conv = scn.Sequential()
    for idx, i in enumerate(range(self.num_strides - 2, -1, -1)):
        first_step = idx == 0
        # The first upsampling step consumes the encoder's deepest features.
        in_planes = (self.encoder_nPlanes[i + 1] if first_step
                     else self.nPlanes[i + 1])

        m = scn.Sequential()
        m.add(scn.BatchNormLeakyReLU(in_planes, leakiness=self.leakiness))
        m.add(scn.Deconvolution(self.dimension, in_planes, self.nPlanes[i],
                                self.downsample[0], self.downsample[1],
                                self.allow_bias))
        if not first_step:
            m.add(scn.Dropout(p=dropout_prob))
        self.decoding_conv.add(m)

        m = scn.Sequential()
        for j in range(self.reps):
            # Only the first residual block sees the extra skip-connection planes.
            extra = self.encoder_nPlanes[i] if j == 0 else 0
            self._resnet_block(m, self.nPlanes[i] + extra, self.nPlanes[i])
        self.decoding_block.add(m)
def _res_dropout(m, a, b, p):
    """Append a residual block with dropout to the ``scn.Sequential`` ``m``.

    The block sums two parallel branches (ConcatTable followed by AddTable):
      * shortcut: ``scn.Identity()`` when ``a == b``, otherwise
        ``scn.NetworkInNetwork(a, b, False)`` to change the plane count;
      * residual: (BatchNormReLU -> Dropout -> SubmanifoldConvolution, size 3)
        applied twice, mapping ``a`` -> ``b`` -> ``b`` planes.

    Relies on ``dimension`` being defined outside this function.

    Args:
        m: container the block is appended to (mutated in place).
        a: input planes.
        b: output planes.
        p: dropout probability used by both Dropout layers.
    """
    residual = scn.ConcatTable()
    residual.add(scn.Identity() if a == b
                 else scn.NetworkInNetwork(a, b, False))
    residual.add(
        scn.Sequential()
        .add(scn.BatchNormReLU(a))
        # Previously a bare ``Dropout(p)``; use the scn layer like below.
        .add(scn.Dropout(p))
        .add(scn.SubmanifoldConvolution(dimension, a, b, 3, False))
        .add(scn.BatchNormReLU(b))
        .add(scn.Dropout(p))
        .add(scn.SubmanifoldConvolution(dimension, b, b, 3, False)))
    m.add(residual)
    m.add(scn.AddTable())