def norm_lrelu_upscale_conv_norm_lrelu(l_in, feat_out):
    """Stack (norm) -> nonlinearity -> 2x upscale -> 3x3x3 conv -> (norm) -> nonlinearity.

    Relies on ``do_norm``, ``axes`` and ``nonlin`` from the enclosing scope:
    both BatchNormLayer steps are skipped when ``do_norm`` is falsy.

    :param l_in: incoming layer.
    :param feat_out: number of filters of the convolution.
    :return: the final NonlinearityLayer of the stack.
    """
    pre_activation = BatchNormLayer(l_in, axes=axes) if do_norm else l_in
    upsampled = Upscale3DLayer(NonlinearityLayer(pre_activation, nonlin), 2)
    # Linear conv: the nonlinearity is applied as a separate layer below.
    convolved = Conv3DLayer(upsampled, feat_out, 3, 1, 'same',
                            nonlinearity=linear, W=HeNormal(gain='relu'))
    post_activation = (BatchNormLayer(convolved, axes=axes)
                       if do_norm else convolved)
    return NonlinearityLayer(post_activation, nonlin)
def Bilinear_3DInterpolation(incoming, upscale_factor, untie_biases=False,
                             nonlinearity=None, pad='same'):
    """3D unpooling followed by a 3D convolution with fixed filters.

    To support multi-channel interpolation without extra effort, the
    upscaled tensor is reshaped into 1-channel feature maps before the
    convolution and reshaped back to the upscaled shape afterwards.

    :param incoming: incoming layer.
    :param upscale_factor: integer upscaling factor; also determines the
        (odd) kernel size ``upscale_factor // 2 * 2 + 1``.
    :param untie_biases: forwarded to the convolution layer.
    :param nonlinearity: forwarded to the convolution layer.
    :param pad: forwarded to the convolution layer.
    :return: a ReshapeLayer whose non-batch dimensions equal those of the
        upscaled layer.
    """
    # New Lasagne API; the old API was Unpool3DLayer(incoming, upscale_factor).
    unpooledLayer = Upscale3DLayer(incoming, upscale_factor, mode='dilate')

    # Floor division keeps the kernel size an int on Python 3 as well; true
    # division would produce a float (e.g. 3.0), which is not a valid size.
    k_size = upscale_factor // 2 * 2 + 1

    # Fold all channels into the batch axis so a single 1->1 channel filter
    # is applied to every feature map independently.
    unpooledLayer_1channel = ReshapeLayer(
        unpooledLayer, shape=(-1, 1) + unpooledLayer.output_shape[-3:])

    # NOTE(review): __W_5D__ is defined elsewhere; presumably it returns the
    # fixed interpolation kernel for the given size -- confirm.
    deconvedLayer = Conv3DDNNLayer(
        unpooledLayer_1channel, 1, (k_size, k_size, k_size),
        nonlinearity=nonlinearity, untie_biases=untie_biases, pad=pad,
        b=None, W=__W_5D__(k_size))

    # Drop the 'trainable' tag so the filter weights are not treated as
    # trainable parameters.
    deconvedLayer.params[deconvedLayer.W].remove('trainable')

    return ReshapeLayer(deconvedLayer,
                        shape=(-1, ) + unpooledLayer.output_shape[1:])
# Downsample by 2, then two 3x3x3 ELU convolutions (96 -> 64 filters).
l = MaxPool3DLayer(l, name='maxpool', pool_size=2)
l = Conv3DLayer(
    l,
    name='conv',
    num_filters=96,
    filter_size=(3, 3, 3),
    nonlinearity=elu,
    pad='same',
)
l = Conv3DLayer(
    l,
    name='conv',
    num_filters=64,
    filter_size=(3, 3, 3),
    nonlinearity=elu,
    pad='same',
)
l = batch_norm(l)

# Upsample by 2, reduce to 32 filters, and merge with the skip layer `li2`
# (defined earlier, outside this fragment).
l = Upscale3DLayer(l, name='upscale', scale_factor=2)
l = Conv3DLayer(
    l,
    name='conv',
    num_filters=32,
    filter_size=(3, 3, 3),
    nonlinearity=elu,
    pad='same',
)
l = ConcatLayer([l, li2])
l = Conv3DLayer(
    l,
    name='conv',
    num_filters=32,
    filter_size=(3, 3, 3),
    nonlinearity=elu,
    pad='same',
)
l = batch_norm(l)
def build_network(self):
    """Assemble the 3D encoder/decoder segmentation network.

    The contracting path has four pooling stages and the expanding path has
    four upscaling stages, all with factor ``(1, 2, 2)``; each expanding
    stage is concatenated (center-cropped) with the matching contracting
    stage. Class-score maps from ``expand_2_2`` and ``expand_3_2`` are
    upscaled and summed into the final 1x1x1 class-score convolution.

    Sets the following attributes and returns nothing:

    * ``self.input_var`` / ``self.output_var`` -- symbolic input/target.
    * ``self.input_layer`` -- the InputLayer.
    * ``self.seg_layer`` -- summed class-score layer (before softmax).
    * ``self.output_layer`` -- softmax over scores reshaped to
      ``(batch * rows * cols * z, num_classes)``.

    The local ``net`` dictionary of named layers is not stored on ``self``.
    """
    self.input_var = tensor5()
    self.output_var = matrix()
    net = OrderedDict()

    # batch_norm is used in both settings; only the normalised axes differ.
    norm_fct = batch_norm
    norm_kwargs = {'axes': (2, 3, 4) if self.instance_norm else 'auto'}

    base = self.base_n_filters
    # Pooling/upscaling factor: the first spatial axis is left untouched.
    plane_factor = (1, 2, 2)

    def conv_norm(incoming, n_filters):
        # 3x3x3 convolution followed by the configured normalisation.
        conv = Conv3DLayer(incoming, n_filters, 3,
                           nonlinearity=self.nonlinearity, pad=self.pad,
                           W=lasagne.init.HeNormal(gain="relu"))
        return norm_fct(conv, **norm_kwargs)

    def with_dropout(layer):
        # Dropout is optional; pass the layer through when disabled.
        if self.dropout is None:
            return layer
        return DropoutLayer(layer, p=self.dropout)

    def merge(upsampled, skip):
        # Concatenate along channels, center-cropping the spatial axes.
        return ConcatLayer(
            [upsampled, skip],
            cropping=(None, None, "center", "center", "center"))

    def class_scores(layer):
        # Linear 1x1x1 projection to per-class score maps.
        return Conv3DLayer(layer, self.num_classes, 1, 1, 'same',
                           nonlinearity=lasagne.nonlinearities.linear,
                           W=lasagne.init.HeNormal(gain='relu'))

    net['input'] = InputLayer(
        (self.batch_size, self.n_input_channels,
         self.input_dim[0], self.input_dim[1], self.input_dim[2]),
        self.input_var)
    self.input_layer = net['input']

    # ---- contracting path -------------------------------------------------
    net['contr_1_1'] = conv_norm(net['input'], base)
    net['contr_1_2'] = conv_norm(net['contr_1_1'], base)
    net['pool1'] = Pool3DLayer(net['contr_1_2'], plane_factor)

    net['contr_2_1'] = conv_norm(net['pool1'], base * 2)
    net['contr_2_2'] = conv_norm(net['contr_2_1'], base * 2)
    net['pool2'] = Pool3DLayer(net['contr_2_2'], plane_factor)

    net['contr_3_1'] = conv_norm(with_dropout(net['pool2']), base * 4)
    net['contr_3_2'] = conv_norm(net['contr_3_1'], base * 4)
    net['pool3'] = Pool3DLayer(net['contr_3_2'], plane_factor)

    net['contr_4_1'] = conv_norm(with_dropout(net['pool3']), base * 8)
    net['contr_4_2'] = conv_norm(net['contr_4_1'], base * 8)
    net['pool4'] = Pool3DLayer(net['contr_4_2'], plane_factor)

    # ---- bottleneck -------------------------------------------------------
    net['encode_1'] = conv_norm(with_dropout(net['pool4']), base * 16)
    net['encode_2'] = conv_norm(net['encode_1'], base * 16)

    # ---- expanding path ---------------------------------------------------
    net['upscale1'] = Upscale3DLayer(net['encode_2'], plane_factor)
    net['concat1'] = merge(net['upscale1'], net['contr_4_2'])
    net['expand_1_1'] = conv_norm(with_dropout(net['concat1']), base * 8)
    net['expand_1_2'] = conv_norm(net['expand_1_1'], base * 8)

    net['upscale2'] = Upscale3DLayer(net['expand_1_2'], plane_factor)
    net['concat2'] = merge(net['upscale2'], net['contr_3_2'])
    net['expand_2_1'] = conv_norm(with_dropout(net['concat2']), base * 4)
    net['expand_2_2'] = conv_norm(net['expand_2_1'], base * 4)

    net['upscale3'] = Upscale3DLayer(net['expand_2_2'], plane_factor)
    net['concat3'] = merge(net['upscale3'], net['contr_2_2'])
    net['expand_3_1'] = conv_norm(with_dropout(net['concat3']), base * 2)
    net['expand_3_2'] = conv_norm(net['expand_3_1'], base * 2)

    # The last stage applies no dropout after the concatenation.
    net['upscale4'] = Upscale3DLayer(net['expand_3_2'], plane_factor)
    net['concat4'] = merge(net['upscale4'], net['contr_1_2'])
    net['expand_4_1'] = conv_norm(net['concat4'], base)
    net['expand_4_2'] = conv_norm(net['expand_4_1'], base)

    net['output_segmentation'] = Conv3DLayer(
        net['expand_4_2'], self.num_classes, 1, nonlinearity=None)

    # ---- auxiliary score maps from lower-resolution stages ----------------
    coarse_scores = Upscale3DLayer(class_scores(net['expand_2_2']),
                                   plane_factor)
    fused_scores = ElemwiseSumLayer(
        (coarse_scores, class_scores(net['expand_3_2'])))
    aux_scores = Upscale3DLayer(fused_scores, plane_factor)
    self.seg_layer = ElemwiseSumLayer(
        (net['output_segmentation'], aux_scores))

    # ---- flatten to (voxels, classes) and apply softmax -------------------
    net['dimshuffle'] = DimshuffleLayer(self.seg_layer, (0, 2, 3, 4, 1))
    batch_size, n_z, n_rows, n_cols, _ = lasagne.layers.get_output(
        net['dimshuffle']).shape
    n_voxels = batch_size * n_rows * n_cols * n_z
    net['reshapeSeg'] = ReshapeLayer(net['dimshuffle'],
                                     (n_voxels, self.num_classes))
    net['output_flattened'] = NonlinearityLayer(
        net['reshapeSeg'], nonlinearity=lasagne.nonlinearities.softmax)
    self.output_layer = net['output_flattened']
def build_net(input_var=None, input_shape=(128, 128, 128),
              num_output_classes=4, num_input_channels=4, base_n_filter=8,
              do_instance_norm=True, batch_size=None, dropout_p=0.3,
              do_norm=True):
    """Build a residual 3D encoder/decoder segmentation network.

    The encoder has one full-resolution residual block and four stride-2
    residual blocks (``base_n_filter`` times 2, 4, 8 and 16 filters). The
    decoder upscales by 2 four times, concatenating each result with the
    matching encoder skip layer. Class-score maps from two intermediate
    decoder stages are upscaled and summed into the final score map.

    :param input_var: symbolic input variable handed to the InputLayer.
    :param input_shape: the three spatial sizes of the input.
    :param num_output_classes: number of class-score channels.
    :param num_input_channels: number of input channels.
    :param base_n_filter: filter count of the first stage.
    :param do_instance_norm: if truthy, normalise over axes ``(2, 3, 4)``;
        otherwise use ``'auto'``.
    :param batch_size: batch dimension of the InputLayer (may be None).
    :param dropout_p: dropout probability inside each residual block.
    :param do_norm: if falsy, every BatchNormLayer is omitted.
    :return: tuple ``(output_layer, seg_layer)`` where ``seg_layer`` is the
        summed class-score layer and ``output_layer`` is a softmax over the
        scores reshaped to ``(batch * rows * cols * z, num_output_classes)``.
    """
    nonlin = lasagne.nonlinearities.leaky_rectify
    axes = (2, 3, 4) if do_instance_norm else 'auto'

    # ---- small composable building blocks ---------------------------------
    def conv(incoming, n_filters, filter_size=3, stride=1):
        # Linear 'same' convolution; activations are separate layers.
        return Conv3DLayer(incoming, n_filters, filter_size, stride, 'same',
                           nonlinearity=linear, W=HeNormal(gain='relu'))

    def norm(incoming):
        # Normalisation is optional; pass through when disabled.
        return BatchNormLayer(incoming, axes=axes) if do_norm else incoming

    def lrelu(incoming):
        return NonlinearityLayer(incoming, nonlin)

    def conv_norm_lrelu(incoming, n_filters):
        return lrelu(norm(conv(incoming, n_filters)))

    def norm_lrelu_conv(incoming, n_filters):
        return conv(lrelu(norm(incoming)), n_filters)

    def upsample_module(incoming, n_filters):
        # (norm) -> lrelu -> 2x upscale -> conv -> (norm) -> lrelu
        upsampled = Upscale3DLayer(lrelu(norm(incoming)), 2)
        return lrelu(norm(conv(upsampled, n_filters)))

    def down_residual(incoming, n_filters):
        # Stride-2 conv whose output is also the residual shortcut.
        shortcut = conv(incoming, n_filters, 3, 2)
        branch = norm_lrelu_conv(shortcut, n_filters)
        branch = DropoutLayer(branch, dropout_p)
        branch = norm_lrelu_conv(branch, n_filters)
        return ElemwiseSumLayer((branch, shortcut))

    def merge(skip, incoming):
        # Channel concatenation with center cropping on the listed axes.
        return ConcatLayer((skip, incoming),
                           cropping=[None, None, 'center', 'center'])

    l_in = InputLayer(
        shape=(batch_size, num_input_channels,
               input_shape[0], input_shape[1], input_shape[2]),
        input_var=input_var)

    # ---- encoder, level 1 (full resolution, no norm inside the block) -----
    shortcut = conv(l_in, base_n_filter)
    l = conv(lrelu(shortcut), base_n_filter)
    l = DropoutLayer(l, dropout_p)
    l = conv(lrelu(l), base_n_filter)
    l = ElemwiseSumLayer((l, shortcut))
    # The level-1 skip is taken before normalisation, unlike deeper levels.
    skip1 = lrelu(l)
    l = lrelu(norm(l))

    # ---- encoder, levels 2-5 ----------------------------------------------
    l = skip2 = lrelu(norm(down_residual(l, base_n_filter * 2)))
    l = skip3 = lrelu(norm(down_residual(l, base_n_filter * 4)))
    l = skip4 = lrelu(norm(down_residual(l, base_n_filter * 8)))
    l = down_residual(l, base_n_filter * 16)

    # ---- decoder ----------------------------------------------------------
    l = upsample_module(l, base_n_filter * 8)
    l = lrelu(norm(conv(l, base_n_filter * 8, 1)))
    l = conv_norm_lrelu(merge(skip4, l), base_n_filter * 16)
    l = conv(l, base_n_filter * 8, 1)

    l = upsample_module(l, base_n_filter * 4)
    ds2 = conv_norm_lrelu(merge(skip3, l), base_n_filter * 8)
    l = conv(ds2, base_n_filter * 4, 1)

    l = upsample_module(l, base_n_filter * 2)
    ds3 = conv_norm_lrelu(merge(skip2, l), base_n_filter * 4)
    l = conv(ds3, base_n_filter * 2, 1)

    l = upsample_module(l, base_n_filter)
    l = conv_norm_lrelu(merge(skip1, l), base_n_filter * 2)
    l_pred = Conv3DLayer(l, num_output_classes, 1, pad='same',
                         nonlinearity=None)

    # ---- auxiliary score maps from ds2 / ds3 ------------------------------
    coarse_scores = Upscale3DLayer(conv(ds2, num_output_classes, 1), 2)
    fused_scores = ElemwiseSumLayer(
        (coarse_scores, conv(ds3, num_output_classes, 1)))
    seg_layer = ElemwiseSumLayer((l_pred, Upscale3DLayer(fused_scores, 2)))

    # ---- flatten to (voxels, classes) and apply softmax -------------------
    channels_last = DimshuffleLayer(seg_layer, (0, 2, 3, 4, 1))
    n_batch, n_rows, n_cols, n_z, _ = lasagne.layers.get_output(
        channels_last).shape
    flat = ReshapeLayer(
        channels_last,
        (n_batch * n_rows * n_cols * n_z, num_output_classes))
    out = NonlinearityLayer(flat,
                            nonlinearity=lasagne.nonlinearities.softmax)
    return out, seg_layer