Example #1
0
 # NOTE(review): fragment of an enclosing method body -- it closes over
 # `self` for all hyperparameters and is not runnable standalone.
 def conv_pool_block(l, num_filters, filter_length, i_block):
     # One network stage: dropout -> time conv (optionally doubled) ->
     # batch-norm or plain nonlinearity -> pooling -> stride reshape.
     l = DropoutLayer(l, p=self.drop_prob)
     l = Conv2DLayer(l,
                     num_filters=num_filters,
                     filter_size=[filter_length, 1],
                     nonlinearity=identity,
                     name='combined_conv_{:d}'.format(i_block))
     # Optional second convolution with identical settings (and name).
     if self.double_time_convs:
         l = Conv2DLayer(l,
                         num_filters=num_filters,
                         filter_size=[filter_length, 1],
                         nonlinearity=identity,
                         name='combined_conv_{:d}'.format(i_block))
     # Nonlinearity is folded into the batch-norm layer when enabled.
     if self.batch_norm:
         l = BatchNormLayer(l,
                            epsilon=1e-4,
                            alpha=self.batch_norm_alpha,
                            nonlinearity=self.later_nonlin)
     else:
         l = NonlinearityLayer(l, nonlinearity=self.later_nonlin)
     # Pool with stride 1; the real temporal stride is applied by the
     # StrideReshapeLayer below.
     l = Pool2DLayer(l,
                     pool_size=[self.pool_time_length, 1],
                     stride=[1, 1],
                     mode=self.later_pool_mode)
     l = StrideReshapeLayer(l, n_stride=self.pool_time_stride)
     l = NonlinearityLayer(l, self.later_pool_nonlin)
     return l
Example #2
0
def build_pixelcnn_block(incoming, i):
    """Build one residual PixelCNN block: 1x1 squeeze -> masked 3x3 -> 1x1
    expand, with an additive skip connection back to the block input.

    Parameters
    ----------
    incoming : lasagne layer with 2h feature maps (see inline comment).
    i : int, block index used to make the layer-dict keys unique.

    Returns
    -------
    OrderedDict mapping names to the layers created here; the last entry
    is the residual merge layer.
    """
    net = OrderedDict()

    nfilts = incoming.output_shape[1]  # nfilts = 2h

    # Track the most recent layer in a local instead of net.values()[-1]:
    # dict views are not indexable on Python 3, and this avoids rebuilding
    # the values view after every insertion.
    # 1x1 convolution halving the feature maps (2h -> h).
    l = net['full_deconv_A_{}'.format(i)] = Conv2DLayer(incoming,
                                                        num_filters=nfilts // 2,
                                                        filter_size=1,
                                                        name='conv_A')

    # Masked 3x3 convolution; mask type 'B' lets a pixel see itself.
    l = net['full_deconv_B_{}'.format(i)] = Conv2DLayer(l,
                                                        num_filters=nfilts // 2,
                                                        filter_size=3,
                                                        pad='same',
                                                        name='conv_B')
    f_shape = l.W.get_value(borrow=True).shape
    l.W *= get_mask(f_shape, 'B')

    # 1x1 convolution restoring the feature maps (h -> 2h).
    l = net['full_deconv_C_{}'.format(i)] = Conv2DLayer(l,
                                                        num_filters=nfilts,
                                                        filter_size=1,
                                                        name='conv_C')

    # residual skip connection
    net['skip_{}'.format(i)] = ElemwiseMergeLayer(
        [incoming, l], merge_function=T.add, name='add_convs')

    return net
Example #3
0
def replace_dense_softmax_by_dense_linear(all_layers, n_features,
        nonlin_before_merge, batch_norm_before_merge):
    """Replace dense/conv (n_classes) -> reshape -> softmax
    by         dense/conv (n_features) -> reshape.

    Parameters
    ----------
    all_layers : list of lasagne layers (full network in topological order).
    n_features : int, number of filters for the recreated final conv.
    nonlin_before_merge : nonlinearity for the recreated final conv.
    batch_norm_before_merge : bool, whether to batch-normalize the new conv.

    Returns
    -------
    list of all layers of the rebuilt network (softmax head removed).
    """
    reshape_layers = [l for l in all_layers
                      if l.__class__.__name__ == 'FinalReshapeLayer']

    assert len(reshape_layers) == 1
    reshape_layer = reshape_layers[0]

    input_to_reshape = reshape_layer.input_layer
    # We expect a linear conv2d as "final dense" before the reshape...
    assert input_to_reshape.__class__.__name__ == 'Conv2DLayer', (
        "expect conv before reshape")
    # `__name__` exists on both Python 2 and 3; the original `func_name`
    # attribute is Python-2-only and raises AttributeError on Python 3.
    assert input_to_reshape.nonlinearity.__name__ == 'linear'

    # recreate with different number of filters; stride must be 1 so the
    # output length is unchanged
    assert input_to_reshape.stride == (1, 1)
    new_input_to_reshape = Conv2DLayer(input_to_reshape.input_layer,
            num_filters=n_features,
            filter_size=input_to_reshape.filter_size,
            nonlinearity=nonlin_before_merge,
            name='final_dense')
    if batch_norm_before_merge:
        new_input_to_reshape = batch_norm(new_input_to_reshape,
            alpha=0.1, epsilon=0.01)

    new_reshape_l = FinalReshapeLayer(new_input_to_reshape)
    return lasagne.layers.get_all_layers(new_reshape_l)
    def get_layers(self):
        """Build the ConvNet layer stack and return all layers in order.

        Input is (batch, channels, time, 1); with ``split_first_layer`` the
        first convolution is factorized into a temporal and a spatial part.
        Returns ``lasagne.layers.get_all_layers`` of the final softmax layer.
        """
        l = InputLayer([None, self.in_chans, self.input_time_length, 1])
        if self.split_first_layer:
            # Move the channel axis to the last (width) position so the time
            # conv runs over time only; the spatial conv then spans all
            # electrode columns at once.
            l = DimshuffleLayer(l, pattern=[0, 3, 2, 1])
            # NOTE(review): this branch reads self.n_filters_time /
            # self.n_filters_spat while the else-branch reads
            # self.num_filters_time -- confirm both attribute spellings exist.
            l = Conv2DLayer(l,
                            num_filters=self.n_filters_time,
                            filter_size=[self.filter_time_length, 1],
                            nonlinearity=identity,
                            name='time_conv')
            l = Conv2DAllColsLayer(l,
                                   num_filters=self.n_filters_spat,
                                   filter_size=[1, -1],
                                   nonlinearity=identity,
                                   name='spat_conv')
        else:  #keep channel dim in first dim, so it will also be convolved over
            l = Conv2DLayer(l,
                            num_filters=self.num_filters_time,
                            filter_size=[self.filter_time_length, 1],
                            nonlinearity=identity,
                            name='time_conv')
        # Nonlinearity is folded into the batch-norm layer when enabled.
        if self.batch_norm:
            l = BatchNormLayer(l,
                               epsilon=1e-4,
                               alpha=self.batch_norm_alpha,
                               nonlinearity=self.conv_nonlin)
        else:
            l = NonlinearityLayer(l, nonlinearity=self.conv_nonlin)

        # Pool with stride 1; the temporal stride is applied afterwards by
        # the StrideReshapeLayer.
        l = Pool2DLayer(l,
                        pool_size=[self.pool_time_length, 1],
                        stride=[1, 1],
                        mode=self.pool_mode)
        l = NonlinearityLayer(l, self.pool_nonlin)
        l = StrideReshapeLayer(l, n_stride=self.pool_time_stride)
        l = DropoutLayer(l, p=self.drop_prob)

        # Final "dense" layer implemented as a convolution over the
        # remaining time axis, then reshape and softmax over classes.
        l = Conv2DLayer(l,
                        num_filters=self.n_classes,
                        filter_size=[self.final_dense_length, 1],
                        nonlinearity=identity,
                        name='final_dense')
        l = FinalReshapeLayer(l)
        l = NonlinearityLayer(l, softmax)
        return lasagne.layers.get_all_layers(l)
def conv_bn(in_layer,
            num_filters,
            filter_size,
            nonlinearity=rectify,
            pad='same',
            name='conv'):
    """Convolution followed by batch normalization.

    Applies a 2D convolution with the given filter count, size, padding,
    nonlinearity and name, then wraps the result in batch normalization
    and returns the normalized layer.
    """
    conv = Conv2DLayer(in_layer,
                       num_filters=num_filters,
                       filter_size=filter_size,
                       nonlinearity=nonlinearity,
                       pad=pad,
                       name=name)
    return batch_norm(conv)
Example #6
0
   )))
# NOTE(review): scraped fragment -- the ')))' above closes a generator-stack
# expression whose beginning is not shown; this script is not runnable as-is.




# Reshape the generator output to image tensors and compile a function that
# samples images from noise.
gen.append(ll.ReshapeLayer(gen[-1],shape=(batch_size,1,28,28)))
gen_out=ll.get_output(gen[-1],noisevar,deterministic=True)
gen_params=ll.get_all_params(gen)

get_image=theano.function(inputs=[noisevar],outputs=gen_out,
        allow_input_downcast=True)
#show_shapes(gen)
#define discriminator network
disc=[]
disc.append(ll.InputLayer(shape=(None,1,28,28)))
# Strided 5x5 conv + dropout; very leaky rectify keeps gradients flowing.
disc.append(ll.dropout(Conv2DLayer(disc[-1],num_filters=gen_filters,stride=(2,2),filter_size=(5,5),
	nonlinearity=nonlin.very_leaky_rectify),p=0.25))

disc.append(ll.dropout(Conv2DLayer(disc[-1],num_filters=64,
	filter_size=(5,5),
	nonlinearity=nonlin.very_leaky_rectify),p=0.35))


#disc.append(ll.dropout(ll.DenseLayer(disc[-1],num_units=1000),p=0.3))
#disc.append(ll.dropout(ll.DenseLayer(disc[-1],num_units=500),p=0.2))
#disc.append(ll.dropout(ll.DenseLayer(disc[-1],num_units=250),p=0.3))
#disc.append(ll.GaussianNoiseLayer(disc[-1], sigma=0.01))
disc.append(ll.dropout(ll.DenseLayer(disc[-1], num_units=512,nonlinearity=nonlin.very_leaky_rectify),p=0.0))


# Sigmoid head: probability that the input image is real.
disc.append(ll.DenseLayer(disc[-1],num_units=1,nonlinearity=nonlin.sigmoid))
disc_data=ll.get_output(disc[-1],inputs=Xvar)
Example #7
0
def build_pixelcnn_block(incoming_vert,
                         incoming_hor,
                         fsize,
                         i,
                         masked=None,
                         latent=None):
    """Build one gated PixelCNN block with a vertical and a horizontal stack.

    Parameters
    ----------
    incoming_vert, incoming_hor : lasagne layers feeding the two stacks;
        must have the same number of feature maps.
    fsize : int, spatial filter size n (vertical stack uses nxn,
        horizontal stack uses 1xn).
    i : int, block index used to build layer-dict keys.
    masked : if truthy, use 'A'-masked full-size convolutions; otherwise
        use shifted convolutions with padding and cropping.
    latent : optional conditioning input forwarded to ``get_gated``.

    Returns
    -------
    (net, vertical output layer, horizontal output layer) where ``net`` is
    an OrderedDict of the layers created here.

    NOTE(review): ``net.values()[-1]`` indexing is Python-2-only (dict
    views are not subscriptable on Python 3).
    """
    net = OrderedDict()
    # input (batch_size, n_features, n_rows, n_columns), n_features = p
    assert incoming_vert.output_shape[1] == incoming_hor.output_shape[1]
    nfilts = incoming_hor.output_shape[1]  # 2p

    # vertical nxn convolution part, fsize = (n,n)
    if masked:
        # either masked
        net['conv_vert_{}'.format(i)] = Conv2DLayer(incoming_vert,
                                                    num_filters=2 * nfilts,
                                                    filter_size=fsize,
                                                    pad='same',
                                                    nonlinearity=linear,
                                                    name='conv_vert')  # 2p
        f_shape = net.values()[-1].W.get_value(borrow=True).shape
        net.values()[-1].W *= get_mask(f_shape, 'A')
    else:
        # or (n//2+1, n) convolution with padding and croppding
        net['conv_vert_{}'.format(i)] = Conv2DLayer(
            incoming_vert,
            num_filters=2 * nfilts,
            filter_size=(fsize // 2 + 1, fsize),
            pad=(fsize // 2 + 1, fsize // 2),
            nonlinearity=linear,
            name='conv_vert')  # 2p

        # crop
        # NOTE(review): key lacks the _{i} suffix, so stacking several
        # blocks into one dict via update() overwrites earlier entries.
        net['slice_vert'] = SliceLayer(net.values()[-1],
                                       indices=slice(0, -fsize // 2 - 1),
                                       axis=2,
                                       name='slice_vert')

    # vertical gated processing
    l_out_vert, gated_vert = get_gated(net.values()[-1], 2 * nfilts, i, 'vert',
                                       latent)
    net.update(gated_vert)  # p

    # vertical skip connection to horizontal stack
    net['full_conv_vert_{}'.format(i)] = Conv2DLayer(l_out_vert,
                                                     num_filters=2 * nfilts,
                                                     filter_size=1,
                                                     pad='same',
                                                     nonlinearity=linear,
                                                     name='full_conv_vert')
    skip_vert2hor = net.values()[-1]

    # horizontal 1xn convolution part, fsize = (1,n)
    if masked:
        net['conv_hor_{}'.format(i)] = Conv2DLayer(incoming_hor,
                                                   num_filters=2 * nfilts,
                                                   filter_size=(1, fsize),
                                                   pad='same',
                                                   nonlinearity=linear,
                                                   name='conv_hor')  # 2p
        f_shape = net.values()[-1].W.get_value(borrow=True).shape
        net.values()[-1].W *= get_mask(f_shape, 'A')
    else:
        net['conv_hor_{}'.format(i)] = Conv2DLayer(incoming_hor,
                                                   num_filters=2 * nfilts,
                                                   filter_size=(1, fsize // 2 +
                                                                1),
                                                   pad=(0, fsize // 2 + 1),
                                                   nonlinearity=linear,
                                                   name='conv_hor')  # 2p

        # crop
        # NOTE(review): unindexed key, same collision risk as 'slice_vert'.
        net['slice_hor'] = SliceLayer(net.values()[-1],
                                      indices=slice(0, -fsize // 2 - 1),
                                      axis=3,
                                      name='slice_hor')

    # merge results of vertical and horizontal convolutions
    net['add_vert2hor_{}'.format(i)] = ElemwiseMergeLayer(
        [skip_vert2hor, net.values()[-1]], T.add, name='add_vert2hor')  # 2p

    # horizontal gated processing
    l_gated_hor, gated_hor = get_gated(net.values()[-1], 2 * nfilts, i, 'hor',
                                       latent)
    net.update(gated_hor)  # p

    # horizontal full convolution
    # NOTE(review): reuses the 'conv_hor_{i}' key from above, replacing that
    # dict entry; the underlying layer graph is unaffected.
    net['conv_hor_{}'.format(i)] = Conv2DLayer(l_gated_hor,
                                               num_filters=nfilts,
                                               filter_size=1,
                                               pad='same',
                                               nonlinearity=linear,
                                               name='conv_hor')

    # add horizontal skip connection
    net['add_skip2hor_{}'.format(i)] = ElemwiseMergeLayer(
        [net.values()[-1], incoming_hor], T.add, name='add_skip2hor')

    return net, l_out_vert, net.values()[-1]  # net, vert output, hor output
Example #8
0
def build_pixel_cnn(input_shape,
                    nfilts=384,
                    fsize=5,
                    n_layers=20,
                    masked=True,
                    latent=None):
    """Assemble a full gated PixelCNN: masked input conv, ``n_layers``
    gated blocks, then a 1x1 head.

    Parameters
    ----------
    input_shape : (channels, rows, cols); multi-channel input gets a
        256-way softmax per subpixel, single-channel gets a sigmoid.
    nfilts : int, feature maps throughout the stack.
    fsize : int, spatial filter size of the gated blocks.
    n_layers : int, number of gated blocks.
    masked : passed through to ``build_pixelcnn_block``.
    latent : optional conditioning input.

    Returns
    -------
    (net, output layer) where ``net`` is an OrderedDict of all layers.

    NOTE(review): ``net.values()[-1]`` indexing is Python-2-only.
    """
    net = OrderedDict()
    if input_shape[0] > 1:
        out_dim = 3 * 256
        out_fn = linear
    else:
        out_dim = 1
        out_fn = sigmoid

    if latent:
        net['latent'] = latent

    net['input'] = InputLayer((None, ) + input_shape, name='input')
    # Mask type 'A' on the input conv: a pixel must not see itself.
    net['input_conv'] = Conv2DLayer(net.values()[-1],
                                    num_filters=nfilts,
                                    filter_size=7,
                                    pad='same',
                                    name='input_conv')
    f_shape = net.values()[-1].W.get_value(borrow=True).shape
    net.values()[-1].W *= get_mask(f_shape, 'A')
    l_vert = l_hor = net.values()[-1]

    # Both stacks start from the masked input conv.
    for i in range(n_layers):
        block, l_vert, l_hor = build_pixelcnn_block(l_vert,
                                                    l_hor,
                                                    fsize,
                                                    i,
                                                    masked,
                                                    latent=latent)
        net.update(block)

    net['pre_relu'] = NonlinearityLayer(net.values()[-1], name='pre_relu')

    net['pre_output'] = Conv2DLayer(net.values()[-1],
                                    num_filters=nfilts,
                                    filter_size=1,
                                    name='pre_output')  # contains relu
    f_shape = net.values()[-1].W.get_value(borrow=True).shape
    net.values()[-1].W *= get_mask(f_shape, 'A')

    net['output'] = Conv2DLayer(net.values()[-1],
                                num_filters=out_dim,
                                filter_size=1,
                                nonlinearity=out_fn,
                                name='output')
    f_shape = net.values()[-1].W.get_value(borrow=True).shape
    net.values()[-1].W *= get_mask(f_shape, 'A')

    # Multi-channel: reshape to (batch, 256, 3, rows, cols) and apply a
    # softmax over the 256 intensity levels. Re-assigning the 'output' key
    # replaces the dict entry while keeping its (last) position.
    if input_shape[0] > 1:
        output_shape = (input_shape[0], 256, 3, input_shape[2], input_shape[3])
        net['output'] = reshape(net.values()[-1],
                                shape=output_shape,
                                name='output')

        net['output'] = NonlinearityLayer(net.values()[-1],
                                          nonlinearity=softmax,
                                          name='output')

    return net, net['output']
Example #9
0
def build_pixelrnn_block(incoming, i, connected=False, learn_init=True):
    """Build one diagonal-BiLSTM PixelRNN block with residual connection.

    Parameters
    ----------
    incoming : lasagne layer; its feature-map count is restored by the
        final 1x1 convolution.
    i : int, block index used to build layer-dict keys.
    connected : if True, run the backward LSTM on top of the forward one
        (igul-style); otherwise run both on the skewed input and merge
        with a one-row shift.
    learn_init : passed to the LSTM layers.

    Returns
    -------
    OrderedDict of the layers created here (last entry is the residual
    merge).

    NOTE(review): ``net.values()[-1]`` indexing is Python-2-only.
    """
    net = OrderedDict()
    num_units = incoming.output_shape[1] // 2

    # Skew the feature maps so diagonals become columns for the LSTMs.
    net['skew_{}'.format(i)] = SkewLayer(incoming, name='skew')
    if connected:
        # igul implementation
        net['rnn_{}'.format(i)] = PixelLSTMLayer(net.values()[-1],
                                                 num_units=num_units,
                                                 learn_init=learn_init,
                                                 mask_type='B',
                                                 name='rnn_conn')
        net['bi_rnn_{}'.format(i)] = PixelLSTMLayer(net.values()[-1],
                                                    num_units=num_units,
                                                    learn_init=learn_init,
                                                    mask_type='B',
                                                    backwards=True,
                                                    name='birnn_conn')
    else:
        # original paper says:
        # Given the two output maps, to prevent the layer from seeing future
        # pixels, the right output map is then shifted down by one row and
        # added to the left output map
        skew_l = net.values()[-1]
        rnn_l = net['rnn_{}'.format(i)] = PixelLSTMLayer(skew_l,
                                                         num_units=num_units,
                                                         precompute_input=True,
                                                         learn_init=learn_init,
                                                         mask_type='B',
                                                         name='rnn')
        # W = net.values()[-1].W_in_to_ingate
        # f_shape = np.array(W.get_value(borrow=True).shape)
        # f_shape[1] *= 4
        # W *= get_mask(tuple(f_shape), 'B')

        net['bi_rnn_{}'.format(i)] = PixelLSTMLayer(skew_l,
                                                    num_units=num_units,
                                                    precompute_input=True,
                                                    learn_init=learn_init,
                                                    mask_type='B',
                                                    name='birnn')
        # W = net.values()[-1].W_in_to_ingate
        # f_shape = np.array(W.get_value(borrow=True).shape)
        # f_shape[1] *= 4
        # W *= get_mask(tuple(f_shape), 'B')

        # slice the last row
        # NOTE(review): 'slice_last_row' and 'pad' keys lack the _{i}
        # suffix -- merging several blocks into one dict via update()
        # overwrites earlier entries.
        net['slice_last_row'] = SliceLayer(net.values()[-1],
                                           indices=slice(0, -1),
                                           axis=2,
                                           name='slice_birnn')

        # pad first row with zeros
        net['pad'] = pad(net.values()[-1],
                         width=[(1, 0)],
                         val=0,
                         batch_ndim=2,
                         name='pad_birnn')

        # add together
        net['rnn_out'] = ElemwiseMergeLayer([rnn_l, net.values()[-1]],
                                            merge_function=T.add,
                                            name='add_rnns')

    # Undo the skew applied at the top of the block.
    net['unskew_{}'.format(i)] = UnSkewLayer(net.values()[-1], name='skew')

    # 1x1 upsampling by full convolution
    nfilts = incoming.output_shape[1]
    net['full_deconv_{}'.format(i)] = Conv2DLayer(net.values()[-1],
                                                  num_filters=nfilts,
                                                  filter_size=1,
                                                  name='full_conv')

    # residual skip connection
    net['skip_{}'.format(i)] = ElemwiseMergeLayer(
        [incoming, net.values()[-1]], merge_function=T.add, name='add_rnns')

    return net
Example #10
0
def build_pixel_nn(dataset='mnist', type='rnn'):
    """Build a full PixelRNN or PixelCNN for MNIST or CIFAR-10.

    Parameters
    ----------
    dataset : 'mnist' or 'cifar10' -- selects input shape, depth, widths
        and the output head (per-pixel sigmoid vs. 256-way softmax).
    type : 'rnn' (default) or 'cnn' -- selects the block builder; 'cnn'
        also overrides depth/width.

    Returns
    -------
    (net, output layer) where ``net`` is an OrderedDict of all layers.

    Raises
    ------
    ValueError : if ``dataset`` is not recognized.

    NOTE(review): ``net.values()[-1]`` indexing is Python-2-only, kept
    here for consistency with the rest of this file.
    """
    if dataset == 'mnist':
        input_shape = (1, 28, 28)
        n_layers = 7
        n_units = 16
        top_units = 32
        out_dim = 1
        out_fn = sigmoid
    elif dataset == 'cifar10':
        input_shape = (3, 28, 28)
        n_layers = 12
        n_units = 128
        top_units = 1024
        out_dim = 3 * 256
        out_fn = linear
    else:
        # Previously an unknown dataset fell through silently and crashed
        # later with a NameError on `input_shape`; fail fast instead.
        raise ValueError("unknown dataset: {!r}".format(dataset))

    if type == 'cnn':
        n_units = 128
        n_layers = 15

    net = OrderedDict()
    net['input'] = InputLayer((None, ) + input_shape, name='input')
    # Mask type 'A' on the input conv: a pixel must not see itself.
    net['input_conv'] = Conv2DLayer(net.values()[-1],
                                    num_filters=2 * n_units,
                                    filter_size=7,
                                    pad='same',
                                    name='input_conv')
    f_shape = net.values()[-1].W.get_value(borrow=True).shape
    net.values()[-1].W *= get_mask(f_shape, 'A')

    for i in range(n_layers):
        if type == 'cnn':
            block = build_pixelcnn_block(net.values()[-1], i)
        else:
            block = build_pixelrnn_block(net.values()[-1], i)
        net.update(block)

    net['pre_relu'] = NonlinearityLayer(net.values()[-1], name='pre_relu')

    # Mask type 'B' from here on: a pixel may see itself.
    net['pre_output'] = Conv2DLayer(net.values()[-1],
                                    num_filters=top_units,
                                    filter_size=1,
                                    name='pre_output')  # contains relu
    f_shape = net.values()[-1].W.get_value(borrow=True).shape
    net.values()[-1].W *= get_mask(f_shape, 'B')

    net['output'] = Conv2DLayer(net.values()[-1],
                                num_filters=out_dim,
                                filter_size=1,
                                nonlinearity=out_fn,
                                name='output')
    f_shape = net.values()[-1].W.get_value(borrow=True).shape
    net.values()[-1].W *= get_mask(f_shape, 'B')

    # CIFAR-10: reshape to (batch, 256, 3, rows, cols) and softmax over
    # the 256 intensity levels per subpixel.
    if dataset == 'cifar10':
        output_shape = (input_shape[0], 256, 3, input_shape[2], input_shape[3])
        net['output'] = reshape(net.values()[-1],
                                shape=output_shape,
                                name='output')

        net['output'] = NonlinearityLayer(net.values()[-1],
                                          nonlinearity=softmax,
                                          name='output')

    return net, net['output']
Example #11
0
    def get_layers(self):
        """Build the deep ConvNet layer stack and return all layers in order.

        Input is (batch, channels, time, 1). The first stage optionally
        factorizes into a temporal + spatial convolution; three further
        conv-pool stages follow, then a final "dense" convolution, reshape
        and output nonlinearity.
        """
        l = InputLayer([None, self.in_chans, self.input_time_length, 1])
        if self.split_first_layer:
            # Move the channel axis last so the time conv runs over time
            # only; the spatial conv then spans all electrode columns.
            l = DimshuffleLayer(l, pattern=[0, 3, 2, 1])
            l = DropoutLayer(l, p=self.drop_in_prob)
            l = Conv2DLayer(l,
                            num_filters=self.num_filters_time,
                            filter_size=[self.filter_time_length, 1],
                            nonlinearity=identity,
                            name='time_conv')
            if self.double_time_convs:
                l = Conv2DLayer(l,
                                num_filters=self.num_filters_time,
                                filter_size=[self.filter_time_length, 1],
                                nonlinearity=identity,
                                name='time_conv')
            l = Conv2DAllColsLayer(l,
                                   num_filters=self.num_filters_spat,
                                   filter_size=[1, -1],
                                   nonlinearity=identity,
                                   name='spat_conv')
        else:  #keep channel dim in first dim, so it will also be convolved over
            l = DropoutLayer(l, p=self.drop_in_prob)
            l = Conv2DLayer(l,
                            num_filters=self.num_filters_time,
                            filter_size=[self.filter_time_length, 1],
                            nonlinearity=identity,
                            name='time_conv')
            if self.double_time_convs:
                l = Conv2DLayer(l,
                                num_filters=self.num_filters_time,
                                filter_size=[self.filter_time_length, 1],
                                nonlinearity=identity,
                                name='time_conv')
        # Nonlinearity is folded into the batch-norm layer when enabled.
        if self.batch_norm:
            l = BatchNormLayer(l,
                               epsilon=1e-4,
                               alpha=self.batch_norm_alpha,
                               nonlinearity=self.first_nonlin)
        else:
            l = NonlinearityLayer(l, nonlinearity=self.first_nonlin)
        # Pool with stride 1; the real temporal stride is applied by the
        # StrideReshapeLayer.
        l = Pool2DLayer(l,
                        pool_size=[self.pool_time_length, 1],
                        stride=[1, 1],
                        mode=self.first_pool_mode)
        l = StrideReshapeLayer(l, n_stride=self.pool_time_stride)
        l = NonlinearityLayer(l, self.first_pool_nonlin)

        def conv_pool_block(l, num_filters, filter_length, i_block):
            # One later stage: dropout -> conv (optionally doubled) ->
            # batch-norm/nonlinearity -> pool -> stride reshape.
            l = DropoutLayer(l, p=self.drop_prob)
            l = Conv2DLayer(l,
                            num_filters=num_filters,
                            filter_size=[filter_length, 1],
                            nonlinearity=identity,
                            name='combined_conv_{:d}'.format(i_block))
            if self.double_time_convs:
                l = Conv2DLayer(l,
                                num_filters=num_filters,
                                filter_size=[filter_length, 1],
                                nonlinearity=identity,
                                name='combined_conv_{:d}'.format(i_block))
            if self.batch_norm:
                l = BatchNormLayer(l,
                                   epsilon=1e-4,
                                   alpha=self.batch_norm_alpha,
                                   nonlinearity=self.later_nonlin)
            else:
                l = NonlinearityLayer(l, nonlinearity=self.later_nonlin)
            l = Pool2DLayer(l,
                            pool_size=[self.pool_time_length, 1],
                            stride=[1, 1],
                            mode=self.later_pool_mode)
            l = StrideReshapeLayer(l, n_stride=self.pool_time_stride)
            l = NonlinearityLayer(l, self.later_pool_nonlin)
            return l

        l = conv_pool_block(l, self.num_filters_2, self.filter_length_2, 2)
        l = conv_pool_block(l, self.num_filters_3, self.filter_length_3, 3)
        l = conv_pool_block(l, self.num_filters_4, self.filter_length_4, 4)
        # Final part, transformed dense layer
        l = DropoutLayer(l, p=self.drop_prob)
        l = Conv2DLayer(l,
                        num_filters=self.n_classes,
                        filter_size=[self.final_dense_length, 1],
                        nonlinearity=identity,
                        name='final_dense')
        l = FinalReshapeLayer(l)
        l = NonlinearityLayer(l, self.final_nonlin)
        return lasagne.layers.get_all_layers(l)
Example #12
0
 def conv_layer(incoming, num_filters):
     """3x3 valid convolution -> batch norm -> optional dropout -> nonlinearity.

     Closes over `dropout` from the enclosing scope to decide whether a
     DropoutLayer (p=0.3) is inserted before the nonlinearity.
     """
     layer = Conv2DLayer(incoming, num_filters, 3, pad='valid')
     layer = BatchNormLayer(layer)
     if dropout:
         layer = DropoutLayer(layer, 0.3)
     return NonlinearityLayer(layer)
def build_model():
    """ Compile net architecture.

    Builds a U-Net-style encoder/decoder: a batch-normalized input and a
    1x1 "color deconvolution" preprocessing stage, a 4-level encoder with
    skip connections (p1..p3), a mirrored transposed-conv decoder with
    concatenated skips, and a final 1x1 sigmoid segmentation layer.
    Returns the output layer.
    """

    l_in = lasagne.layers.InputLayer(shape=(None, INPUT_SHAPE[0],
                                            INPUT_SHAPE[1], INPUT_SHAPE[2]),
                                     name='Input')
    net1 = batch_norm(l_in)

    # --- preprocessing ---
    # 1x1 convolutions acting per pixel across the color channels.
    net1 = conv_bn(net1,
                   num_filters=10,
                   filter_size=1,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=1,
                   filter_size=1,
                   nonlinearity=nonlin,
                   pad='same',
                   name='color_deconv_preproc')

    # number of filters in first layer
    # decreased by factor 2 in each block
    nf0 = 16

    # --- encoder ---
    # Each level: two 3x3 conv+bn blocks, save the result for the skip
    # connection, then 2x2 max-pool.
    net1 = conv_bn(net1,
                   num_filters=nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    p1 = net1
    net1 = MaxPool2DLayer(net1, pool_size=2, stride=2, name='pool1')

    net1 = conv_bn(net1,
                   num_filters=2 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=2 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    p2 = net1
    net1 = MaxPool2DLayer(net1, pool_size=2, stride=2, name='pool2')

    net1 = conv_bn(net1,
                   num_filters=4 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=4 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    p3 = net1
    net1 = MaxPool2DLayer(net1, pool_size=2, stride=2, name='pool3')

    # Bottleneck: deepest level, no pooling afterwards.
    net1 = conv_bn(net1,
                   num_filters=8 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=8 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')

    # --- decoder ---
    # Each level: 2x2 stride-2 transposed conv to upsample, concatenate
    # the matching encoder output, then two 3x3 conv+bn blocks.
    net1 = TransposedConv2DLayer(net1,
                                 num_filters=4 * nf0,
                                 filter_size=2,
                                 stride=2,
                                 name='upconv')
    net1 = ConcatLayer((p3, net1), name='concat')
    net1 = conv_bn(net1,
                   num_filters=4 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=4 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')

    net1 = TransposedConv2DLayer(net1,
                                 num_filters=2 * nf0,
                                 filter_size=2,
                                 stride=2,
                                 name='upconv')
    net1 = ConcatLayer((p2, net1), name='concat')
    net1 = conv_bn(net1,
                   num_filters=2 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=2 * nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')

    net1 = TransposedConv2DLayer(net1,
                                 num_filters=nf0,
                                 filter_size=2,
                                 stride=2,
                                 name='upconv')
    net1 = ConcatLayer((p1, net1), name='concat')
    net1 = conv_bn(net1,
                   num_filters=nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')
    net1 = conv_bn(net1,
                   num_filters=nf0,
                   filter_size=3,
                   nonlinearity=nonlin,
                   pad='same')

    # 1x1 sigmoid head: per-pixel foreground probability.
    net1 = Conv2DLayer(net1,
                       num_filters=1,
                       filter_size=1,
                       nonlinearity=sigmoid,
                       pad='same',
                       name='segmentation')

    return net1