Example #1
def cslim(network, cropsz, batchsz):
    # 1st
    network = Conv2DLayer(network, 64, (5, 5), stride=2, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (5, 5), stride=2)
    # 2nd
    network = Conv2DLayer(network, 96, (5, 5), stride=1, pad='same', W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (5, 5), stride=2)
    # 3rd
    network = Conv2DLayer(network, 128, (3, 3), stride=1, pad='same', W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)
    # 4th
    network = Conv2DLayer(network, 128, (3, 3), stride=1, pad='same', W=HeUniform('relu'))
    network = prelu(network)
    network = DropoutLayer(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)
    # 5th
    network = lasagne.layers.DenseLayer(network, 512, nonlinearity=None)
    network = DropoutLayer(network)
    network = FeaturePoolLayer(network, 2)

    return network
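The head of this network is a maxout-style block: a linear DenseLayer of 512 units whose features FeaturePoolLayer then max-pools in pairs, leaving 256 units. A minimal, hedged shape check, assuming Lasagne and Theano are installed (the 100-unit input is made up):

# Hedged sketch of the DenseLayer -> FeaturePoolLayer maxout head used above.
from lasagne.layers import InputLayer, DenseLayer, FeaturePoolLayer, get_output_shape

l = InputLayer((None, 100))                  # hypothetical input size
l = DenseLayer(l, 512, nonlinearity=None)    # linear projection, as in cslim()
l = FeaturePoolLayer(l, 2)                   # max over pairs of features (maxout)
print(get_output_shape(l))                   # -> (None, 256)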
Example #2
def gooey_gadget(network_in, conv_add, stride):
    network_c = Conv2DLayer(network_in, conv_add / 2, (1, 1),
        W=HeUniform('relu'))
    network_c = prelu(network_c)
    network_c = BatchNormLayer(network_c)
    network_c = Conv2DLayer(network_c, conv_add, (3, 3), stride=stride,
        W=HeUniform('relu'))
    network_c = prelu(network_c)
    network_c = BatchNormLayer(network_c)
    network_p = MaxPool2DLayer(network_in, (3, 3), stride=stride)
    return ConcatLayer((network_c, network_p))
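gooey_gadget concatenates a 1x1-then-3x3 convolution branch with a max-pooling branch along the channel axis, so its output carries the input channels plus conv_add new ones. A minimal shape-check sketch, assuming Lasagne/Theano are installed and gooey_gadget above is importable together with the Lasagne layers it uses:

from lasagne.layers import InputLayer, get_output_shape

# e.g. the 27x27x96 feature map that feeds the gadget in gooey() below
stage_in = InputLayer((None, 96, 27, 27))
stage_out = gooey_gadget(stage_in, 128, 2)   # conv branch adds 128 channels, stride 2
print(get_output_shape(stage_out))           # -> (None, 224, 13, 13): 96 + 128 channels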
Example #3
def gooey_gadget(network_in, conv_add, stride):
    network_c = Conv2DLayer(network_in,
                            conv_add / 2, (1, 1),
                            W=HeUniform('relu'))
    network_c = prelu(network_c)
    network_c = BatchNormLayer(network_c)
    network_c = Conv2DLayer(network_c,
                            conv_add, (3, 3),
                            stride=stride,
                            W=HeUniform('relu'))
    network_c = prelu(network_c)
    network_c = BatchNormLayer(network_c)
    network_p = MaxPool2DLayer(network_in, (3, 3), stride=stride)
    return ConcatLayer((network_c, network_p))
Example #4
def gooey(network, cropsz, batchsz):
    # 1st. Data size 117 -> 111 -> 55
    # 117*117*32 = 438048
    network = Conv2DLayer(network, 32, (3, 3), stride=1,
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    # 115*115*32 = 423200
    network = Conv2DLayer(network, 32, (3, 3), stride=1,
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    # 55*55*40 = 121000
    network = Conv2DLayer(network, 40, (3, 3), stride=1,
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 2nd. Data size 55 -> 27
    # 27*27*96 = 69984
    network = Conv2DLayer(network, 96, (3, 3), stride=2,
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)

    # 3rd.  Data size 27 -> 13, 96 + 128
    # 13*13*224 = 37856
    network = gooey_gadget(network, 128, 2) # 96 + 128 = 224 channels

    # 4th.  Data size 13 -> 11 -> 5
    # 11*11*192 = 23232
    network = Conv2DLayer(network, 192, (3, 3),
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)

    # 5*5*416 = 10400
    network = gooey_gadget(network, 224, 2) # 192 + 224 = 416 channels

    # 5th. Data size 5 -> 3
    # 3*3*672 = 6048
    network = gooey_gadget(network, 256, 1) # 416 + 256 = 672 channels

    # 6th. Data size 3 -> 1, 672 + 512 channels
    # 1*1*1184 = 1184
    network = gooey_gadget(network, 512, 1) # 672 + 512 = 1184 channels

    return network
Example #5
    def __init__(self,
                 incoming,
                 num_filters,
                 filter_size=3,
                 stride=(1, 1, 1),
                 pad='same',
                 W=lasagne.init.HeNormal(),
                 b=None,
                 **kwargs):
        # Enforce name
        ensure_set_name('conv3d_prelu', kwargs)

        super(Conv3DPrelu, self).__init__(incoming, **kwargs)
        self.conv = Conv3D(incoming,
                           num_filters,
                           filter_size,
                           stride,
                           pad=pad,
                           W=W,
                           b=b,
                           nonlinearity=None,
                           **kwargs)
        self.prelu = prelu(self.conv, **kwargs)

        self.params = self.conv.params.copy()
        self.params.update(self.prelu.params)
Example #6
def gooey(network, cropsz, batchsz):
    # 1st. Data size 117 -> 111 -> 55
    # 117*117*32 = 438048
    network = Conv2DLayer(network, 32, (3, 3), stride=1, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    # 115*115*32 = 423200
    network = Conv2DLayer(network, 32, (3, 3), stride=1, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    # 55*55*40 = 121000
    network = Conv2DLayer(network, 40, (3, 3), stride=1, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 2nd. Data size 55 -> 27
    # 27*27*96 = 69984
    network = Conv2DLayer(network, 96, (3, 3), stride=2, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)

    # 3rd.  Data size 27 -> 13, 96 + 128
    # 13*13*224 = 37856
    network = gooey_gadget(network, 128, 2)  # 96 + 128 = 224 channels

    # 4th.  Data size 13 -> 11 -> 5
    # 11*11*192 = 23232
    network = Conv2DLayer(network, 192, (3, 3), W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)

    # 5*5*416 = 10400
    network = gooey_gadget(network, 224, 2)  # 192 + 224 = 416 channels

    # 5th. Data size 5 -> 3
    # 3*3*672 = 6048
    network = gooey_gadget(network, 256, 1)  # 416 + 256 = 672 channels

    # 6th. Data size 3 -> 1, 672 + 512 channels
    # 1*1*1184 = 1184
    network = gooey_gadget(network, 512, 1)  # 672 + 512 = 1184 channels

    return network
Example #7
def choosy(network, cropsz, batchsz):
    # 1st. Data size 117 -> 111 -> 55
    network = Conv2DLayer(network, 64, (7, 7), stride=1, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 2nd. Data size 55 -> 27
    network = Conv2DLayer(network,
                          112, (5, 5),
                          stride=1,
                          pad='same',
                          W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 3rd.  Data size 27 -> 13
    network = Conv2DLayer(network,
                          192, (3, 3),
                          stride=1,
                          pad='same',
                          W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 4th.  Data size 13 -> 11 -> 5
    network = Conv2DLayer(network, 320, (3, 3), stride=1, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 5th. Data size 5 -> 3
    network = Conv2DLayer(network, 512, (3, 3), nonlinearity=None)
    network = prelu(network)
    network = BatchNormLayer(network)

    # 6th. Data size 3 -> 1
    network = lasagne.layers.DenseLayer(network, 512, nonlinearity=None)
    network = DropoutLayer(network)
    network = FeaturePoolLayer(network, 2)

    return network
Example #8
    def __init__(self, incoming, num_filters, filter_size=3, stride=(1, 1, 1),
                 pad='same', W=lasagne.init.HeNormal(), b=None, **kwargs):
        # Enforce name
        ensure_set_name('conv3d_prelu', kwargs)

        super(Conv3DPrelu, self).__init__(incoming, **kwargs)
        self.conv = Conv3D(incoming, num_filters, filter_size, stride,
                           pad=pad, W=W, b=b, nonlinearity=None, **kwargs)
        self.prelu = prelu(self.conv, **kwargs)

        self.params = self.conv.params.copy()
        self.params.update(self.prelu.params)
Example #9
def choosy(network, cropsz, batchsz):
    # 1st. Data size 117 -> 111 -> 55
    network = Conv2DLayer(network, 64, (7, 7), stride=1,
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 2nd. Data size 55 -> 27
    network = Conv2DLayer(network, 112, (5, 5), stride=1, pad='same',
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 3rd.  Data size 27 -> 13
    network = Conv2DLayer(network, 192, (3, 3), stride=1, pad='same',
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 4th.  Data size 13 -> 11 -> 5
    network = Conv2DLayer(network, 320, (3, 3), stride=1,
        W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)

    # 5th. Data size 5 -> 3
    network = Conv2DLayer(network, 512, (3, 3), nonlinearity=None)
    network = prelu(network)
    network = BatchNormLayer(network)

    # 6th. Data size 3 -> 1
    network = lasagne.layers.DenseLayer(network, 512, nonlinearity=None)
    network = DropoutLayer(network)
    network = FeaturePoolLayer(network, 2)

    return network
Example #10
def cslim(network, cropsz, batchsz):
    # 1st
    network = Conv2DLayer(network, 64, (5, 5), stride=2, W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (5, 5), stride=2)
    # 2nd
    network = Conv2DLayer(network,
                          96, (5, 5),
                          stride=1,
                          pad='same',
                          W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (5, 5), stride=2)
    # 3rd
    network = Conv2DLayer(network,
                          128, (3, 3),
                          stride=1,
                          pad='same',
                          W=HeUniform('relu'))
    network = prelu(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)
    # 4th
    network = Conv2DLayer(network,
                          128, (3, 3),
                          stride=1,
                          pad='same',
                          W=HeUniform('relu'))
    network = prelu(network)
    network = DropoutLayer(network)
    network = BatchNormLayer(network)
    network = MaxPool2DLayer(network, (3, 3), stride=2)
    # 5th
    network = lasagne.layers.DenseLayer(network, 512, nonlinearity=None)
    network = DropoutLayer(network)
    network = FeaturePoolLayer(network, 2)

    return network
Example #11
    def buildDilated3DNet(self, inputShape = (None, 4, 25, 25, 25)):

        summaryRowList = [['-', '-', '-', '-', '-', '-']]
        summaryRowList.append(['Numbering', 'Layer', 'Input Shape', '', 'W Shape', 'Output Shape'])
        summaryRowList.append(['-', '-', '-', '-', '-', '-'])
        dilated3DNet = InputLayer(self.inputShape, self.inputVar, name = 'InputLayer')
        # ........................................................................................
        # For summary
        num = 1
        layerName = 'Input'
        inputS1 = inputShape
        inputS2 = ''
        WShape = ''
        outputS = get_output_shape(dilated3DNet, input_shapes = inputShape)
        summaryRowList.append([num, layerName, inputS1, inputS2, WShape, outputS])
        # ........................................................................................
        layerBlockNum = len(self.kernelNumList) - 1

        for idx in xrange(layerBlockNum):

            dilatedLayer = DilatedConv3DLayer(dilated3DNet, 
                                              self.kernelNumList[idx], 
                                              self.kernelShapeList[idx], 
                                              self.dilatedFactorList[idx], 
                                              W = HeNormal(gain = 'relu'),
                                              nonlinearity = linear)
            # ....................................................................................
            # For summary
            num = idx + 2
            layerName = 'Dilated'
            inputS1 = get_output_shape(dilated3DNet, input_shapes = inputShape)
            inputS2 = ''
            WShape = dilatedLayer.W.get_value().shape
            outputS = get_output_shape(dilatedLayer, input_shapes = inputShape)
            summaryRowList.append([num, layerName, inputS1, inputS2, WShape, outputS])
            # ....................................................................................

            batchNormLayer = BatchNormLayer(dilatedLayer)
            preluLayer = prelu(batchNormLayer)
            concatLayer = ConcatLayer([preluLayer, dilatedLayer], 1,
                                      cropping = ['center', None, 'center',
                                                  'center', 'center'])
            # ....................................................................................
            # For summary
            num = ''
            layerName = 'Concat'
            inputS1 = get_output_shape(dilatedLayer, input_shapes = inputShape)
            inputS2 = get_output_shape(dilated3DNet, input_shapes = inputShape)
            WShape = ''
            outputS = get_output_shape(concatLayer, input_shapes = inputShape)
            summaryRowList.append([num, layerName, inputS1, inputS2, WShape, outputS])
            # ....................................................................................

            dilated3DNet = DropoutLayer(concatLayer, self.dropoutRates)


        dilatedLayer = DilatedConv3DLayer(dilated3DNet, 
                                          self.kernelNumList[-1], 
                                          self.kernelShapeList[-1], 
                                          self.dilatedFactorList[-1], 
                                          W = HeNormal(gain = 'relu'),
                                          nonlinearity = linear)
        # ....................................................................................
        # For summary
        num = layerBlockNum + 1
        layerName = 'Dilated'
        inputS1 = get_output_shape(dilated3DNet, input_shapes = inputShape)
        inputS2 = ''
        WShape = dilatedLayer.W.get_value().shape
        outputS = get_output_shape(dilatedLayer, input_shapes = inputShape)
        summaryRowList.append([num, layerName, inputS1, inputS2, WShape, outputS])
        # ....................................................................................

        # For the receptive field (see the arithmetic sketch after this method)
        receptiveFieldArray = np.asarray(inputShape)[2:] - np.asarray(outputS)[2:] + 1
        assert not np.any(receptiveFieldArray - np.mean(receptiveFieldArray))
        self.receptiveField = int(np.mean(receptiveFieldArray))

        dimshuffleLayer = DimshuffleLayer(dilatedLayer, (0, 2, 3, 4, 1))
        # ....................................................................................
        # For summary
        num = ''
        layerName = 'Dimshuffle'
        inputS1 = get_output_shape(dilatedLayer, input_shapes = inputShape)
        inputS2 = ''
        WShape = ''
        outputS = get_output_shape(dimshuffleLayer, input_shapes = inputShape)
        summaryRowList.append([num, layerName, inputS1, inputS2, WShape, outputS])
        # ....................................................................................

        batchSize, zSize, xSize, ySize, kernelNum = get_output(dimshuffleLayer).shape
        print get_output(dimshuffleLayer).shape, kernelNum
        reshapeLayer = ReshapeLayer(dimshuffleLayer, (batchSize * zSize * xSize * ySize, kernelNum))
        # ....................................................................................
        # For summary
        num = ''
        layerName = 'Reshape'
        inputS1 = get_output_shape(dimshuffleLayer, input_shapes = inputShape)
        inputS2 = ''
        WShape = ''
        outputS = get_output_shape(reshapeLayer, input_shapes = inputShape)
        summaryRowList.append([num, layerName, inputS1, inputS2, WShape, outputS])
        # ....................................................................................

        dilated3DNet = NonlinearityLayer(reshapeLayer, softmax)
        # ....................................................................................
        # For summary
        num = ''
        layerName = 'Nonlinearity'
        inputS1 = get_output_shape(reshapeLayer, input_shapes = inputShape)
        inputS2 = ''
        WShape = ''
        outputS = get_output_shape(dilated3DNet, input_shapes = inputShape)
        summaryRowList.append([num, layerName, inputS1, inputS2, WShape, outputS])
        summaryRowList.append(['-', '-', '-', '-', '-', '-'])
        # ....................................................................................
        self._summary = summaryRowList

        return dilated3DNet
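The receptive-field bookkeeping near the end of this method just measures how much the spatial extent shrinks from input to output and checks that all three axes agree. A self-contained sketch of that arithmetic with made-up sizes (the real output size depends on kernelShapeList and dilatedFactorList):

import numpy as np

input_spatial = np.asarray([25, 25, 25])     # hypothetical input patch (z, x, y)
output_spatial = np.asarray([17, 17, 17])    # hypothetical network output (z, x, y)
rf = input_spatial - output_spatial + 1
assert not np.any(rf - np.mean(rf))          # every spatial axis must shrink equally
print(int(np.mean(rf)))                      # receptive field = 25 - 17 + 1 = 9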
Example #12
    def buildBaseNet(self, inputShape=(None, 4, 25, 25, 25), forSummary=False):

        if not forSummary:
            message = 'Building the Architecture of BaseNet'
            self.logger.info(logMessage('+', message))

        baseNet = InputLayer(self.inputShape, self.inputVar)

        if not forSummary:
            message = 'Building the convolution layers'
            self.logger.info(logMessage('-', message))

        kernelShapeListLen = len(self.kernelNumList)

        summary = '\n' + '.' * 130 + '\n'
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Layer', 'Input shape', 'W shape', 'Output shape')
        summary += '.' * 130 + '\n'

        summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
            1, 'Input', inputShape, '',
            get_output_shape(baseNet, input_shapes=inputShape))

        for i in xrange(kernelShapeListLen - 1):

            kernelShape = self.kernelShapeList[i]
            kernelNum = self.kernelNumList[i]

            conv3D = Conv3DLayer(incoming=baseNet,
                                 num_filters=kernelNum,
                                 filter_size=kernelShape,
                                 W=HeNormal(gain='relu'),
                                 nonlinearity=linear,
                                 name='Conv3D{}'.format(i))

            # Just for the summary: the filter shape.
            WShape = conv3D.W.get_value().shape

            summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
                i + 2, 'Conv3D',
                get_output_shape(baseNet, input_shapes=inputShape), WShape,
                get_output_shape(conv3D, input_shapes=inputShape))

            batchNormLayer = BatchNormLayer(conv3D)
            preluLayer = prelu(batchNormLayer)

            concatLayerInputShape = '{:<25}{:<25}'.format(
                get_output_shape(conv3D, input_shapes=inputShape),
                get_output_shape(baseNet, input_shapes=inputShape))

            baseNet = ConcatLayer(
                [preluLayer, baseNet],
                1,
                cropping=['center', None, 'center', 'center', 'center'])

            summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
                'Concat', concatLayerInputShape, '',
                get_output_shape(baseNet, input_shapes=inputShape))
        if not forSummary:
            message = 'Finished building the convolution layers'
            self.logger.info(logMessage('-', message))

            message = 'Building the last classification layers'
            self.logger.info(logMessage('-', message))

        assert self.kernelShapeList[-1] == [1, 1, 1]

        kernelShape = self.kernelShapeList[-1]
        kernelNum = self.kernelNumList[-1]

        conv3D = Conv3DLayer(incoming=baseNet,
                             num_filters=kernelNum,
                             filter_size=kernelShape,
                             W=HeNormal(gain='relu'),
                             nonlinearity=linear,
                             name='Classfication Layer')

        receptiveFieldList = [
            inputShape[idx] -
            get_output_shape(conv3D, input_shapes=inputShape)[idx] + 1
            for idx in xrange(-3, 0)
        ]
        assert receptiveFieldList != []
        receptiveFieldSet = set(receptiveFieldList)
        assert len(receptiveFieldSet) == 1, (receptiveFieldSet, inputShape,
                                             get_output_shape(
                                                 conv3D,
                                                 input_shapes=inputShape))
        self.receptiveField = list(receptiveFieldSet)[0]

        # Just for the summary: the filter shape.
        WShape = conv3D.W.get_value().shape

        summary += '{:<3} {:<15} {:<50} {:<29} {:<29}\n'.format(
            kernelShapeListLen + 1, 'Conv3D',
            get_output_shape(baseNet, input_shapes=inputShape), WShape,
            get_output_shape(conv3D, input_shapes=inputShape))

        # The output shape should be (batchSize, numOfClasses, zSize, xSize, ySize).
        # We will reshape it to (batchSize * zSize * xSize * ySize, numOfClasses),
        # because the softmax in Theano can only receive a matrix
        # (see the small numpy sketch after this function).

        baseNet = DimshuffleLayer(conv3D, (0, 2, 3, 4, 1))
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Dimshuffle', get_output_shape(conv3D, input_shapes=inputShape),
            '', get_output_shape(baseNet, input_shapes=inputShape))

        batchSize, zSize, xSize, ySize, _ = get_output(baseNet).shape
        reshapeLayerInputShape = get_output_shape(baseNet,
                                                  input_shapes=inputShape)
        baseNet = ReshapeLayer(baseNet,
                               (batchSize * zSize * xSize * ySize, kernelNum))
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Reshape', reshapeLayerInputShape, '',
            get_output_shape(baseNet, input_shapes=inputShape))

        nonlinearityLayerInputShape = get_output_shape(baseNet,
                                                       input_shapes=inputShape)
        baseNet = NonlinearityLayer(baseNet, softmax)
        summary += '    {:<15} {:<50} {:<29} {:<29}\n'.format(
            'Nonlinearity', nonlinearityLayerInputShape, '',
            get_output_shape(baseNet, input_shapes=inputShape))

        if not forSummary:
            message = 'Finished building the last classification layers'
            self.logger.info(logMessage('-', message))

            message = 'The Receptive Field of BaseNet equals {}'.format(
                self.receptiveField)
            self.logger.info(logMessage('*', message))

            message = 'Finished building the Architecture of BaseNet'
            self.logger.info(logMessage('+', message))

        summary += '.' * 130 + '\n'
        self._summary = summary

        return baseNet
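The dimshuffle/reshape pair that precedes the softmax in both builders above rearranges the per-voxel class scores into a 2-D matrix, since Theano's softmax operates on matrices. A tiny numpy sketch of the same rearrangement with made-up sizes:

import numpy as np

batch, classes, z, x, y = 2, 5, 9, 9, 9                # hypothetical sizes
scores = np.zeros((batch, classes, z, x, y))           # (N, C, Z, X, Y) conv output
scores = scores.transpose(0, 2, 3, 4, 1)               # dimshuffle -> (N, Z, X, Y, C)
scores = scores.reshape(batch * z * x * y, classes)    # reshape -> (N*Z*X*Y, C)
print(scores.shape)                                    # -> (1458, 5), ready for softmax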
Example #13
def build_model(weights_path, options):
    """
    Build the CNN model. Create the Neural Net object and return it back. 
    Inputs: 
    - subject name: used to save the net weights accordingly.
    - options: several hyper-parameters used to configure the net.
    
    Output:
    - net: a NeuralNet object 
    """

    net_model_name = options['experiment']

    try:
        os.mkdir(os.path.join(weights_path, net_model_name))
    except OSError:
        pass

    net_weights = os.path.join(weights_path, net_model_name,
                               net_model_name + '.pkl')
    net_history = os.path.join(weights_path, net_model_name,
                               net_model_name + '_history.pkl')

    # select hyper-parameters
    t_verbose = options['net_verbose']
    train_split_perc = options['train_split']
    num_epochs = options['max_epochs']
    max_epochs_patience = options['patience']
    early_stopping = EarlyStopping(patience=max_epochs_patience)
    save_weights = SaveWeights(net_weights, only_best=True, pickle=False)
    save_training_history = SaveTrainingHistory(net_history)

    # build the architecture
    ps = options['patch_size'][0]
    num_channels = 1
    fc_conv = 180
    fc_fc = 180
    dropout_conv = 0.5
    dropout_fc = 0.5

    # --------------------------------------------------
    # channel_1: axial
    # --------------------------------------------------

    axial_ch = InputLayer(name='in1', shape=(None, num_channels, ps, ps))
    axial_ch = prelu(batch_norm(
        Conv2DLayer(axial_ch,
                    name='axial_ch_conv1',
                    num_filters=20,
                    filter_size=3)),
                     name='axial_ch_prelu1')
    axial_ch = prelu(batch_norm(
        Conv2DLayer(axial_ch,
                    name='axial_ch_conv2',
                    num_filters=20,
                    filter_size=3)),
                     name='axial_ch_prelu2')
    axial_ch = MaxPool2DLayer(axial_ch, name='axial_max_pool_1', pool_size=2)
    axial_ch = prelu(batch_norm(
        Conv2DLayer(axial_ch,
                    name='axial_ch_conv3',
                    num_filters=40,
                    filter_size=3)),
                     name='axial_ch_prelu3')
    axial_ch = prelu(batch_norm(
        Conv2DLayer(axial_ch,
                    name='axial_ch_conv4',
                    num_filters=40,
                    filter_size=3)),
                     name='axial_ch_prelu4')
    axial_ch = MaxPool2DLayer(axial_ch, name='axial_max_pool_2', pool_size=2)
    axial_ch = prelu(batch_norm(
        Conv2DLayer(axial_ch,
                    name='axial_ch_conv5',
                    num_filters=60,
                    filter_size=3)),
                     name='axial_ch_prelu5')
    axial_ch = DropoutLayer(axial_ch, name='axial_l1drop', p=dropout_conv)
    axial_ch = DenseLayer(axial_ch, name='axial_d1', num_units=fc_conv)
    axial_ch = prelu(axial_ch, name='axial_prelu_d1')

    # --------------------------------------------------
    # channel_1: coronal
    # --------------------------------------------------

    coronal_ch = InputLayer(name='in2', shape=(None, num_channels, ps, ps))
    coronal_ch = prelu(batch_norm(
        Conv2DLayer(coronal_ch,
                    name='coronal_ch_conv1',
                    num_filters=20,
                    filter_size=3)),
                       name='coronal_ch_prelu1')
    coronal_ch = prelu(batch_norm(
        Conv2DLayer(coronal_ch,
                    name='coronal_ch_conv2',
                    num_filters=20,
                    filter_size=3)),
                       name='coronal_ch_prelu2')
    coronal_ch = MaxPool2DLayer(coronal_ch,
                                name='coronal_max_pool_1',
                                pool_size=2)
    coronal_ch = prelu(batch_norm(
        Conv2DLayer(coronal_ch,
                    name='coronal_ch_conv3',
                    num_filters=40,
                    filter_size=3)),
                       name='coronal_ch_prelu3')
    coronal_ch = prelu(batch_norm(
        Conv2DLayer(coronal_ch,
                    name='coronal_ch_conv4',
                    num_filters=40,
                    filter_size=3)),
                       name='coronal_ch_prelu4')
    coronal_ch = MaxPool2DLayer(coronal_ch,
                                name='coronal_max_pool_2',
                                pool_size=2)
    coronal_ch = prelu(batch_norm(
        Conv2DLayer(coronal_ch,
                    name='coronal_ch_conv5',
                    num_filters=60,
                    filter_size=3)),
                       name='coronal_ch_prelu5')
    coronal_ch = DropoutLayer(coronal_ch,
                              name='coronal_l1drop',
                              p=dropout_conv)
    coronal_ch = DenseLayer(coronal_ch, name='coronal_d1', num_units=fc_conv)
    coronal_ch = prelu(coronal_ch, name='coronal_prelu_d1')

    # --------------------------------------------------
    # channel_1: saggital
    # --------------------------------------------------

    saggital_ch = InputLayer(name='in3', shape=(None, num_channels, ps, ps))
    saggital_ch = prelu(batch_norm(
        Conv2DLayer(saggital_ch,
                    name='saggital_ch_conv1',
                    num_filters=20,
                    filter_size=3)),
                        name='saggital_ch_prelu1')
    saggital_ch = prelu(batch_norm(
        Conv2DLayer(saggital_ch,
                    name='saggital_ch_conv2',
                    num_filters=20,
                    filter_size=3)),
                        name='saggital_ch_prelu2')
    saggital_ch = MaxPool2DLayer(saggital_ch,
                                 name='saggital_max_pool_1',
                                 pool_size=2)
    saggital_ch = prelu(batch_norm(
        Conv2DLayer(saggital_ch,
                    name='saggital_ch_conv3',
                    num_filters=40,
                    filter_size=3)),
                        name='saggital_ch_prelu3')
    saggital_ch = prelu(batch_norm(
        Conv2DLayer(saggital_ch,
                    name='saggital_ch_conv4',
                    num_filters=40,
                    filter_size=3)),
                        name='saggital_ch_prelu4')
    saggital_ch = MaxPool2DLayer(saggital_ch,
                                 name='saggital_max_pool_2',
                                 pool_size=2)
    saggital_ch = prelu(batch_norm(
        Conv2DLayer(saggital_ch,
                    name='saggital_ch_conv5',
                    num_filters=60,
                    filter_size=3)),
                        name='saggital_ch_prelu5')
    saggital_ch = DropoutLayer(saggital_ch,
                               name='saggital_l1drop',
                               p=dropout_conv)
    saggital_ch = DenseLayer(saggital_ch,
                             name='saggital_d1',
                             num_units=fc_conv)
    saggital_ch = prelu(saggital_ch, name='saggital_prelu_d1')

    # FC layer 540
    layer = ConcatLayer(name='elem_channels',
                        incomings=[axial_ch, coronal_ch, saggital_ch])
    layer = DropoutLayer(layer, name='f1_drop', p=dropout_fc)
    layer = DenseLayer(layer, name='FC1', num_units=540)
    layer = prelu(layer, name='prelu_f1')

    # concatenate channels 540 + 15
    layer = DropoutLayer(layer, name='f2_drop', p=dropout_fc)
    atlas_layer = DropoutLayer(InputLayer(name='in4', shape=(None, 15)),
                               name='Dropout_atlas',
                               p=.2)
    atlas_layer = InputLayer(name='in4', shape=(None, 15))
    layer = ConcatLayer(name='elem_channels2', incomings=[layer, atlas_layer])

    # FC layer 270
    layer = DenseLayer(layer, name='fc_2', num_units=270)
    layer = prelu(layer, name='prelu_f2')

    # FC output 15 (softmax)
    net_layer = DenseLayer(layer,
                           name='out_layer',
                           num_units=15,
                           nonlinearity=softmax)

    net = NeuralNet(
        layers=net_layer,
        objective_loss_function=objectives.categorical_crossentropy,
        update=updates.adam,
        update_learning_rate=0.001,
        on_epoch_finished=[
            save_weights,
            save_training_history,
            early_stopping,
        ],
        verbose=t_verbose,
        max_epochs=num_epochs,
        train_split=TrainSplit(eval_size=train_split_perc),
    )

    if options['load_weights'] == 'True':
        try:
            print "    --> loading weights from ", net_weights
            net.load_params_from(net_weights)
        except:
            pass

    return net
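For orientation, a hedged sketch of how build_model might be invoked; the keys below are the ones the function reads, while every value and the weights directory are placeholders:

# Hypothetical invocation of build_model; all values below are placeholders.
options = {
    'experiment': 'demo_run',        # subfolder / file prefix for the saved weights
    'net_verbose': 1,                # nolearn verbosity level
    'train_split': 0.25,             # eval_size handed to TrainSplit
    'max_epochs': 200,
    'patience': 25,                  # early-stopping patience
    'patch_size': (15, 15),          # only patch_size[0] is used above
    'load_weights': 'False',         # 'True' would reload previously saved weights
}

net = build_model('/path/to/weights', options)   # hypothetical weights_path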