def norm_lrelu_upscale_conv_norm_lrelu(l_in, feat_out):
    """Pre-activate, upsample by 2, then convolve, normalize and activate.

    Layer sequence built on top of ``l_in``:
    [BatchNormLayer] -> NonlinearityLayer -> Upscale3DLayer(2)
    -> 3x3x3 'same'-padded linear Conv3DLayer with ``feat_out`` filters
    -> [BatchNormLayer] -> NonlinearityLayer.
    The bracketed BatchNormLayers are only inserted when ``do_norm`` is truthy.

    NOTE(review): ``do_norm``, ``axes`` and ``nonlin`` are free names here; they
    are not defined in this function and must exist in the enclosing module
    scope -- TODO confirm.
    """
    pre = BatchNormLayer(l_in, axes=axes) if do_norm else l_in
    upsampled = Upscale3DLayer(NonlinearityLayer(pre, nonlin), 2)
    out = Conv3DLayer(upsampled,
                      num_filters=feat_out,
                      filter_size=3,
                      stride=1,
                      pad='same',
                      nonlinearity=linear,
                      W=HeNormal(gain='relu'))
    if do_norm:
        out = BatchNormLayer(out, axes=axes)
    return NonlinearityLayer(out, nonlin)
# Example No. 2
# 0
def Bilinear_3DInterpolation(incoming,
                             upscale_factor,
                             untie_biases=False,
                             nonlinearity=None,
                             pad='same'):
    """Upsample a 5D layer with a dilating unpool followed by a fixed 3D conv.

    The input is first upscaled with ``Upscale3DLayer(mode='dilate')`` and then
    convolved with a non-trainable kernel built by ``__W_5D__`` (presumably a
    (tri)linear interpolation kernel -- TODO confirm). To apply the same
    single-channel kernel to every feature map without extra effort, the
    channel axis is folded into the batch axis before the convolution and
    unfolded again afterwards.

    Args:
        incoming: input layer; assumed 5D (batch, channels, d0, d1, d2).
        upscale_factor: integer upsampling factor for every spatial axis.
        untie_biases: forwarded to Conv3DDNNLayer.
        nonlinearity: forwarded to Conv3DDNNLayer.
        pad: forwarded to Conv3DDNNLayer. The final reshape assumes the
            convolution preserves the unpooled spatial size, which holds for
            the default ``'same'``.

    Returns:
        A ReshapeLayer with the same (non-batch) shape as the unpooled layer.
    """
    unpooledLayer = Upscale3DLayer(
        incoming, upscale_factor, mode='dilate'
    )  # new api from lasagne, Unpool3DLayer(incoming, upscale_factor) # old API
    # Smallest odd kernel size >= upscale_factor. Floor division keeps this an
    # int on Python 3; plain `/` would produce a float kernel size.
    k_size = upscale_factor // 2 * 2 + 1

    # Fold channels into the batch axis so a single-channel kernel suffices.
    unpooled_1channel = ReshapeLayer(
        unpooledLayer,
        shape=(-1, 1) + unpooledLayer.output_shape[-3:])
    deconvedLayer = Conv3DDNNLayer(unpooled_1channel,
                                   1,
                                   (k_size, k_size, k_size),
                                   nonlinearity=nonlinearity,
                                   untie_biases=untie_biases,
                                   pad=pad,
                                   b=None,
                                   W=__W_5D__(k_size))
    # Freeze the interpolation kernel: without the 'trainable' tag it is
    # skipped by get_all_params(..., trainable=True).
    deconvedLayer.params[deconvedLayer.W].remove('trainable')

    # Unfold the channels from the batch axis back to the unpooled layout.
    return ReshapeLayer(deconvedLayer,
                        shape=(-1, ) + unpooledLayer.output_shape[1:])
    # NOTE(review): these statements follow the `return` above and can never
    # execute. They also read `l` (which would be an unbound local at this
    # point) and `li2` (not defined in this function), so this looks like an
    # excerpt pasted from a different network-building function -- consider
    # moving it to its proper home or deleting it.
    # NOTE(review): several layers below share the name "conv"; unique names
    # would make lookups / parameter dumps by layer name unambiguous.
    # Downsample by 2, two 3x3x3 ELU convs (96 then 64 filters), batch norm.
    l = MaxPool3DLayer(l, pool_size=2, name='maxpool')
    l = Conv3DLayer(l,
                    num_filters=96,
                    filter_size=(3, 3, 3),
                    pad='same',
                    name="conv",
                    nonlinearity=elu)
    l = Conv3DLayer(l,
                    num_filters=64,
                    filter_size=(3, 3, 3),
                    pad='same',
                    name='conv',
                    nonlinearity=elu)
    l = batch_norm(l)

    # Upsample by 2, conv, concatenate with the `li2` skip connection,
    # conv again, batch norm.
    l = Upscale3DLayer(l, scale_factor=2, name="upscale")
    l = Conv3DLayer(l,
                    num_filters=32,
                    filter_size=(3, 3, 3),
                    pad='same',
                    name="conv",
                    nonlinearity=elu)
    l = ConcatLayer([l, li2])
    l = Conv3DLayer(l,
                    num_filters=32,
                    filter_size=(3, 3, 3),
                    pad='same',
                    name='conv',
                    nonlinearity=elu)
    l = batch_norm(l)
    def build_network(self):
        """Build a 3D U-Net with deep supervision and store its layers on self.

        Pooling and upscaling use (1, 2, 2) factors, so only the last two
        spatial axes change resolution; the first spatial axis is kept.

        Reads: self.instance_norm, self.batch_size, self.n_input_channels,
        self.input_dim, self.base_n_filters, self.nonlinearity, self.pad,
        self.dropout and self.num_classes.

        Sets: self.input_var (tensor5), self.output_var (matrix),
        self.input_layer, self.seg_layer (final-resolution class scores summed
        with the upsampled deep-supervision branches) and self.output_layer
        (softmax over classes, with all voxels flattened to rows of shape
        (-1, num_classes)).
        """
        self.input_var = tensor5()
        self.output_var = matrix()
        net = OrderedDict()
        # Both settings use batch_norm; "instance norm" only restricts the
        # normalization axes to the spatial ones (per sample and channel).
        norm_fct = batch_norm
        norm_kwargs = {'axes': (2, 3, 4) if self.instance_norm else 'auto'}
        cropping = (None, None, "center", "center", "center")

        def conv_norm(incoming, n_filters):
            """3x3x3 conv with the configured nonlinearity, wrapped in norm_fct."""
            return norm_fct(
                Conv3DLayer(incoming,
                            n_filters,
                            3,
                            nonlinearity=self.nonlinearity,
                            pad=self.pad,
                            W=lasagne.init.HeNormal(gain="relu")),
                **norm_kwargs)

        def maybe_dropout(incoming):
            """Wrap ``incoming`` in a DropoutLayer when self.dropout is set."""
            if self.dropout is None:
                return incoming
            return DropoutLayer(incoming, p=self.dropout)

        def supervision_conv(incoming):
            """Linear 1x1x1 conv projecting features to per-class scores."""
            return Conv3DLayer(incoming,
                               self.num_classes,
                               1,
                               1,
                               'same',
                               nonlinearity=lasagne.nonlinearities.linear,
                               W=lasagne.init.HeNormal(gain='relu'))

        self.input_layer = net['input'] = InputLayer(
            (self.batch_size, self.n_input_channels, self.input_dim[0],
             self.input_dim[1], self.input_dim[2]), self.input_var)

        # Contracting path: four levels of two conv_norm blocks followed by
        # (1, 2, 2) pooling, doubling the filter count at each level. Dropout
        # follows every pooling step except the first.
        l = net['input']
        for level in range(1, 5):
            n_filters = self.base_n_filters * 2 ** (level - 1)
            conv_1 = net['contr_%d_1' % level] = conv_norm(l, n_filters)
            conv_2 = net['contr_%d_2' % level] = conv_norm(conv_1, n_filters)
            l = net['pool%d' % level] = Pool3DLayer(conv_2, (1, 2, 2))
            if level > 1:
                l = maybe_dropout(l)

        # Bottleneck.
        net['encode_1'] = conv_norm(l, self.base_n_filters * 16)
        l = net['encode_2'] = conv_norm(net['encode_1'],
                                        self.base_n_filters * 16)

        # Expanding path: upscale, concatenate with the matching contracting
        # level (center-cropped), then two conv_norm blocks, halving the
        # filter count at each level. Dropout follows every concat but the last.
        for level in range(1, 5):
            n_filters = self.base_n_filters * 2 ** (4 - level)
            upscaled = net['upscale%d' % level] = Upscale3DLayer(l, (1, 2, 2))
            l = net['concat%d' % level] = ConcatLayer(
                [upscaled, net['contr_%d_2' % (5 - level)]], cropping=cropping)
            if level < 4:
                l = maybe_dropout(l)
            expand_1 = net['expand_%d_1' % level] = conv_norm(l, n_filters)
            l = net['expand_%d_2' % level] = conv_norm(expand_1, n_filters)

        net['output_segmentation'] = Conv3DLayer(net['expand_4_2'],
                                                 self.num_classes,
                                                 1,
                                                 nonlinearity=None)

        # Deep supervision: project the two coarser decoder outputs to class
        # scores, then upsample and add them level by level into the output.
        ds2_1x1_conv = supervision_conv(net['expand_2_2'])
        ds1_ds2_sum_upscale = Upscale3DLayer(ds2_1x1_conv, (1, 2, 2))
        ds3_1x1_conv = supervision_conv(net['expand_3_2'])
        ds1_ds2_sum_upscale_ds3_sum = ElemwiseSumLayer(
            (ds1_ds2_sum_upscale, ds3_1x1_conv))
        ds1_ds2_sum_upscale_ds3_sum_upscale = Upscale3DLayer(
            ds1_ds2_sum_upscale_ds3_sum, (1, 2, 2))

        self.seg_layer = l = ElemwiseSumLayer(
            (net['output_segmentation'], ds1_ds2_sum_upscale_ds3_sum_upscale))

        # Move classes to the last axis and flatten all voxels into rows so
        # the softmax is applied per voxel. -1 lets Lasagne infer the row
        # count instead of building a second symbolic graph via get_output()
        # just to read its shape.
        net['dimshuffle'] = DimshuffleLayer(l, (0, 2, 3, 4, 1))
        net['reshapeSeg'] = ReshapeLayer(net['dimshuffle'],
                                         (-1, self.num_classes))
        self.output_layer = net['output_flattened'] = NonlinearityLayer(
            net['reshapeSeg'], nonlinearity=lasagne.nonlinearities.softmax)
def build_net(input_var=None,
              input_shape=(128, 128, 128),
              num_output_classes=4,
              num_input_channels=4,
              base_n_filter=8,
              do_instance_norm=True,
              batch_size=None,
              dropout_p=0.3,
              do_norm=True):
    """Build a 3D residual U-Net with deep supervision in Lasagne.

    Encoder: a full-resolution residual block followed by four stride-2
    residual "context" blocks, doubling the filter count each time.
    Decoder: each level upsamples by 2, concatenates the matching encoder
    output (skip connection) and refines it. Class scores from the two
    coarser decoder levels are upsampled and added to the final prediction
    (deep supervision).

    Args:
        input_var: optional symbolic input passed to the InputLayer.
        input_shape: spatial size of the input volume (3 values).
        num_output_classes: number of segmentation classes.
        num_input_channels: number of input channels (axis 1).
        base_n_filter: number of filters at full resolution.
        do_instance_norm: normalize over the spatial axes (2, 3, 4) only;
            otherwise use BatchNormLayer's axes='auto'.
        batch_size: batch dimension of the InputLayer (None = variable).
            NOTE(review): with do_instance_norm=True the norm layers keep
            parameters for the non-normalized (batch, channel) axes, which
            Lasagne can only build for a fixed batch size -- confirm callers
            pass one.
        dropout_p: dropout probability used inside every residual block.
        do_norm: whether to insert BatchNormLayers at all.

    Returns:
        (output_layer, seg_layer).
        ``seg_layer`` holds the class logits with shape
        (batch, num_output_classes, *spatial).
        ``output_layer`` is a softmax over classes, applied after moving the
        class axis last and flattening to (n_voxels_total, num_output_classes).
    """
    nonlin = lasagne.nonlinearities.leaky_rectify
    axes = (2, 3, 4) if do_instance_norm else 'auto'

    def linear_conv(l_in, feat_out, stride=1, filter_size=3):
        """'same'-padded linear 3D convolution with He-normal init."""
        return Conv3DLayer(l_in,
                           feat_out,
                           filter_size,
                           stride,
                           'same',
                           nonlinearity=linear,
                           W=HeNormal(gain='relu'))

    def norm_lrelu(l_in):
        """Optional normalization followed by leaky ReLU."""
        if do_norm:
            l_in = BatchNormLayer(l_in, axes=axes)
        return NonlinearityLayer(l_in, nonlin)

    def conv_norm_lrelu(l_in, feat_out):
        return norm_lrelu(linear_conv(l_in, feat_out))

    def norm_lrelu_conv(l_in, feat_out, stride=1, filter_size=3):
        return linear_conv(norm_lrelu(l_in), feat_out, stride, filter_size)

    def lrelu_conv(l_in, feat_out, stride=1, filter_size=3):
        return linear_conv(NonlinearityLayer(l_in, nonlin), feat_out, stride,
                           filter_size)

    def norm_lrelu_upscale_conv_norm_lrelu(l_in, feat_out):
        upscaled = Upscale3DLayer(norm_lrelu(l_in), 2)
        return conv_norm_lrelu(upscaled, feat_out)

    def context_block(l_in, feat_out):
        """Stride-2 conv plus two pre-activation convs with dropout, summed
        residually with the strided conv's output."""
        l = r = linear_conv(l_in, feat_out, stride=2)
        l = norm_lrelu_conv(l, feat_out)
        l = DropoutLayer(l, dropout_p)
        l = norm_lrelu_conv(l, feat_out)
        return ElemwiseSumLayer((l, r))

    def concat_skip(skip, l_in):
        """Concatenate along channels, center-cropping all spatial axes."""
        return ConcatLayer((skip, l_in),
                           cropping=[None, None, 'center', 'center', 'center'])

    l_in = InputLayer(shape=(batch_size, num_input_channels, input_shape[0],
                             input_shape[1], input_shape[2]),
                      input_var=input_var)

    # Level 1 (full resolution).
    l = r = linear_conv(l_in, base_n_filter)
    l = NonlinearityLayer(l, nonlin)
    l = linear_conv(l, base_n_filter)
    l = DropoutLayer(l, dropout_p)
    l = lrelu_conv(l, base_n_filter, 1, 3)
    l = ElemwiseSumLayer((l, r))
    # NOTE(review): unlike skip2-4, skip1 is taken from the un-normalized sum;
    # kept as-is to preserve the architecture.
    skip1 = NonlinearityLayer(l, nonlin)
    l = norm_lrelu(l)

    # Levels 2-4: strided residual blocks whose activations feed the decoder.
    l = skip2 = norm_lrelu(context_block(l, base_n_filter * 2))
    l = skip3 = norm_lrelu(context_block(l, base_n_filter * 4))
    l = skip4 = norm_lrelu(context_block(l, base_n_filter * 8))

    # Level 5 (bottleneck), then the first upsampling step.
    l = context_block(l, base_n_filter * 16)
    l = norm_lrelu_upscale_conv_norm_lrelu(l, base_n_filter * 8)
    l = norm_lrelu(linear_conv(l, base_n_filter * 8, filter_size=1))

    # Decoder: concat skip, 3x3x3 conv, 1x1x1 conv, upsample.
    l = concat_skip(skip4, l)
    l = conv_norm_lrelu(l, base_n_filter * 16)
    l = linear_conv(l, base_n_filter * 8, filter_size=1)
    l = norm_lrelu_upscale_conv_norm_lrelu(l, base_n_filter * 4)

    l = concat_skip(skip3, l)
    l = ds2 = conv_norm_lrelu(l, base_n_filter * 8)
    l = linear_conv(l, base_n_filter * 4, filter_size=1)
    l = norm_lrelu_upscale_conv_norm_lrelu(l, base_n_filter * 2)

    l = concat_skip(skip2, l)
    l = ds3 = conv_norm_lrelu(l, base_n_filter * 4)
    l = linear_conv(l, base_n_filter * 2, filter_size=1)
    l = norm_lrelu_upscale_conv_norm_lrelu(l, base_n_filter)

    l = concat_skip(skip1, l)
    l = conv_norm_lrelu(l, base_n_filter * 2)
    l_pred = Conv3DLayer(l,
                         num_output_classes,
                         1,
                         pad='same',
                         nonlinearity=None)

    # Deep supervision: coarse class scores are upsampled and summed into the
    # finer levels, ending at full resolution.
    ds2_1x1_conv = linear_conv(ds2, num_output_classes, filter_size=1)
    ds1_ds2_sum_upscale = Upscale3DLayer(ds2_1x1_conv, 2)
    ds3_1x1_conv = linear_conv(ds3, num_output_classes, filter_size=1)
    ds1_ds2_sum_upscale_ds3_sum = ElemwiseSumLayer(
        (ds1_ds2_sum_upscale, ds3_1x1_conv))
    ds1_ds2_sum_upscale_ds3_sum_upscale = Upscale3DLayer(
        ds1_ds2_sum_upscale_ds3_sum, 2)

    l = seg_layer = ElemwiseSumLayer(
        (l_pred, ds1_ds2_sum_upscale_ds3_sum_upscale))
    # Classes last, all voxels flattened into rows -> per-voxel softmax.
    # -1 lets Lasagne infer the row count instead of building a second
    # symbolic graph via get_output() just to read its shape (which also
    # shadowed the `batch_size` argument).
    l = DimshuffleLayer(l, (0, 2, 3, 4, 1))
    l = ReshapeLayer(l, (-1, num_output_classes))
    l = NonlinearityLayer(l, nonlinearity=lasagne.nonlinearities.softmax)
    return l, seg_layer