def createInceptionSegNet(input_shape,
                          n_labels,
                          pool_size=(2, 2),
                          output_mode="sigmoid"):
    # encoder
    inputs = Input(shape=input_shape)

    conv_1 = inceptionModule(inputs)
    conv_2 = inceptionModule(conv_1)
    pool_1, mask_1 = MaxPoolingWithArgmax2D(pool_size)(conv_2)

    conv_3 = inceptionModule(pool_1)
    conv_4 = inceptionModule(conv_3)
    pool_2, mask_2 = MaxPoolingWithArgmax2D(pool_size)(conv_4)

    ## encoding done, decoding start

    unpool_1 = MaxUnpooling2D(pool_size)([pool_2, mask_2])
    conv_5 = inceptionModule(unpool_1)
    conv_6 = inceptionModule(conv_5)

    unpool_2 = MaxUnpooling2D(pool_size)([conv_6, mask_1])
    conv_7 = inceptionModule(unpool_2)
    conv_8 = inceptionModule(conv_7)

    conv_9 = Convolution2D(n_labels, (1, 1), padding='valid')(conv_8)

    reshape = Reshape((n_labels, input_shape[0] * input_shape[1]))(conv_9)
    permute = Permute((2, 1))(reshape)
    outputs = Activation(output_mode)(permute)

    segnet = Model(inputs=inputs, outputs=outputs)
    return segnet
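
A minimal usage sketch (not part of the original example): it assumes inceptionModule and the custom MaxPoolingWithArgmax2D / MaxUnpooling2D layers are defined or importable in the surrounding project, and that the usual Keras symbols (Input, Model, etc.) are already imported. With the default sigmoid output, a per-pixel binary loss is a natural fit.

# illustrative values; the input size and optimizer are assumptions
model = createInceptionSegNet(input_shape=(256, 256, 3), n_labels=1)
model.compile(optimizer="adam", loss="binary_crossentropy")
model.summary()
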
Example #2
def create2LayerSegNetWithIndexPooling(input_shape, 
                                       n_labels, 
                                       k,
                                       kernel=3, 
                                       pool_size=(2, 2),
                                       output_mode="sigmoid"):
    # encoder
    inputs = Input(shape=input_shape)

    conv_1 = Convolution2D(k, (kernel, kernel), padding="same")(inputs)
    conv_1 = BatchNormalization()(conv_1)
    conv_1 = PReLU()(conv_1)
    conv_2 = Convolution2D(k, (kernel, kernel), padding="same")(conv_1)
    conv_2 = BatchNormalization()(conv_2)
    conv_2 = PReLU()(conv_2)

    pool_1, mask_1 = MaxPoolingWithArgmax2D(pool_size)(conv_2)
    
    conv_3 = Convolution2D(2*k, (kernel, kernel), padding="same")(pool_1)
    conv_3 = BatchNormalization()(conv_3)
    conv_3 = PReLU()(conv_3)
    conv_4 = Convolution2D(2*k, (kernel, kernel), padding="same")(conv_3)
    conv_4 = BatchNormalization()(conv_4)
    conv_4 = PReLU()(conv_4)

    pool_2, mask_2 = MaxPoolingWithArgmax2D(pool_size)(conv_4)
    
    ## encoder done, decoder start ##
         
    unpool_2 = MaxUnpooling2D(pool_size)([pool_2, mask_2])

    conv_9 = Convolution2D(2*k, (kernel, kernel), padding="same")(unpool_2)
    conv_9 = BatchNormalization()(conv_9)
    conv_9 = PReLU()(conv_9)
    conv_10 = Convolution2D(2*k, (kernel, kernel), padding="same")(conv_9)
    conv_10 = BatchNormalization()(conv_10)
    conv_10 = PReLU()(conv_10)
    
    unpool_3 = MaxUnpooling2D(pool_size)([conv_10, mask_1])
    
    conv_11 = Convolution2D(k, (kernel, kernel), padding="same")(unpool_3)
    conv_11 = BatchNormalization()(conv_11)
    conv_11 = PReLU()(conv_11)
    conv_12 = Convolution2D(k, (kernel, kernel), padding="same")(conv_11)
    conv_12 = BatchNormalization()(conv_12)
    conv_12 = PReLU()(conv_12)
  
    conv_13 = Convolution2D(n_labels, (1, 1), padding='valid')(conv_12)
    conv_13 = BatchNormalization()(conv_13)
    
    reshape = Reshape((n_labels, input_shape[0] * input_shape[1]))(conv_13)
    permute = Permute((2, 1))(reshape)
    outputs = Activation(output_mode)(permute)
   

    segnet = Model(inputs=inputs, outputs=outputs)
    return segnet
Example #3
def CreateSegNet(input_shape,
                 n_labels,
                 kernel=3,
                 pool_size=(2, 2),
                 output_mode="softmax"):
    # encoder
    inputs = Input(shape=input_shape)

    conv_1 = Convolution2D(64, (kernel, kernel), padding="same")(inputs)
    conv_1 = BatchNormalization()(conv_1)
    conv_1 = Activation("relu")(conv_1)
    conv_2 = Convolution2D(64, (kernel, kernel), padding="same")(conv_1)
    conv_2 = BatchNormalization()(conv_2)
    conv_2 = Activation("relu")(conv_2)

    pool_1, mask_1 = MaxPoolingWithArgmax2D(pool_size)(conv_2)

    conv_3 = Convolution2D(128, (kernel, kernel), padding="same")(pool_1)
    conv_3 = BatchNormalization()(conv_3)
    conv_3 = Activation("relu")(conv_3)
    conv_4 = Convolution2D(128, (kernel, kernel), padding="same")(conv_3)
    conv_4 = BatchNormalization()(conv_4)
    conv_4 = Activation("relu")(conv_4)

    pool_2, mask_2 = MaxPoolingWithArgmax2D(pool_size)(conv_4)

    conv_5 = Convolution2D(256, (kernel, kernel), padding="same")(pool_2)
    conv_5 = BatchNormalization()(conv_5)
    conv_5 = Activation("relu")(conv_5)
    conv_6 = Convolution2D(256, (kernel, kernel), padding="same")(conv_5)
    conv_6 = BatchNormalization()(conv_6)
    conv_6 = Activation("relu")(conv_6)
    conv_7 = Convolution2D(256, (kernel, kernel), padding="same")(conv_6)
    conv_7 = BatchNormalization()(conv_7)
    conv_7 = Activation("relu")(conv_7)

    pool_3, mask_3 = MaxPoolingWithArgmax2D(pool_size)(conv_7)

    conv_8 = Convolution2D(512, (kernel, kernel), padding="same")(pool_3)
    conv_8 = BatchNormalization()(conv_8)
    conv_8 = Activation("relu")(conv_8)
    conv_9 = Convolution2D(512, (kernel, kernel), padding="same")(conv_8)
    conv_9 = BatchNormalization()(conv_9)
    conv_9 = Activation("relu")(conv_9)
    conv_10 = Convolution2D(512, (kernel, kernel), padding="same")(conv_9)
    conv_10 = BatchNormalization()(conv_10)
    conv_10 = Activation("relu")(conv_10)

    pool_4, mask_4 = MaxPoolingWithArgmax2D(pool_size)(conv_10)

    conv_11 = Convolution2D(512, (kernel, kernel), padding="same")(pool_4)
    conv_11 = BatchNormalization()(conv_11)
    conv_11 = Activation("relu")(conv_11)
    conv_12 = Convolution2D(512, (kernel, kernel), padding="same")(conv_11)
    conv_12 = BatchNormalization()(conv_12)
    conv_12 = Activation("relu")(conv_12)
    conv_13 = Convolution2D(512, (kernel, kernel), padding="same")(conv_12)
    conv_13 = BatchNormalization()(conv_13)
    conv_13 = Activation("relu")(conv_13)

    pool_5, mask_5 = MaxPoolingWithArgmax2D(pool_size)(conv_13)
    print("Build encoder done..")

    # decoder

    unpool_1 = MaxUnpooling2D(pool_size)([pool_5, mask_5])

    conv_14 = Convolution2D(512, (kernel, kernel), padding="same")(unpool_1)
    conv_14 = BatchNormalization()(conv_14)
    conv_14 = Activation("relu")(conv_14)
    conv_15 = Convolution2D(512, (kernel, kernel), padding="same")(conv_14)
    conv_15 = BatchNormalization()(conv_15)
    conv_15 = Activation("relu")(conv_15)
    conv_16 = Convolution2D(512, (kernel, kernel), padding="same")(conv_15)
    conv_16 = BatchNormalization()(conv_16)
    conv_16 = Activation("relu")(conv_16)

    unpool_2 = MaxUnpooling2D(pool_size)([conv_16, mask_4])

    conv_17 = Convolution2D(512, (kernel, kernel), padding="same")(unpool_2)
    conv_17 = BatchNormalization()(conv_17)
    conv_17 = Activation("relu")(conv_17)
    conv_18 = Convolution2D(512, (kernel, kernel), padding="same")(conv_17)
    conv_18 = BatchNormalization()(conv_18)
    conv_18 = Activation("relu")(conv_18)
    conv_19 = Convolution2D(256, (kernel, kernel), padding="same")(conv_18)
    conv_19 = BatchNormalization()(conv_19)
    conv_19 = Activation("relu")(conv_19)

    unpool_3 = MaxUnpooling2D(pool_size)([conv_19, mask_3])

    conv_20 = Convolution2D(256, (kernel, kernel), padding="same")(unpool_3)
    conv_20 = BatchNormalization()(conv_20)
    conv_20 = Activation("relu")(conv_20)
    conv_21 = Convolution2D(256, (kernel, kernel), padding="same")(conv_20)
    conv_21 = BatchNormalization()(conv_21)
    conv_21 = Activation("relu")(conv_21)
    conv_22 = Convolution2D(128, (kernel, kernel), padding="same")(conv_21)
    conv_22 = BatchNormalization()(conv_22)
    conv_22 = Activation("relu")(conv_22)

    unpool_4 = MaxUnpooling2D(pool_size)([conv_22, mask_2])

    conv_23 = Convolution2D(128, (kernel, kernel), padding="same")(unpool_4)
    conv_23 = BatchNormalization()(conv_23)
    conv_23 = Activation("relu")(conv_23)
    conv_24 = Convolution2D(64, (kernel, kernel), padding="same")(conv_23)
    conv_24 = BatchNormalization()(conv_24)
    conv_24 = Activation("relu")(conv_24)

    unpool_5 = MaxUnpooling2D(pool_size)([conv_24, mask_1])

    conv_25 = Convolution2D(64, (kernel, kernel), padding="same")(unpool_5)
    conv_25 = BatchNormalization()(conv_25)
    conv_25 = Activation("relu")(conv_25)

    conv_26 = Convolution2D(n_labels, (1, 1), padding="valid")(conv_25)
    conv_26 = BatchNormalization()(conv_26)
    conv_26 = Reshape(
        (input_shape[0] * input_shape[1], n_labels),
        input_shape=(input_shape[0], input_shape[1], n_labels))(conv_26)

    outputs = Activation(output_mode)(conv_26)
    print("Build decoder done..")

    segnet = Model(inputs=inputs, outputs=outputs, name="SegNet")

    return segnet
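
CreateSegNet flattens the spatial dimensions, so its output has shape (H * W, n_labels) per sample and the training targets must be flattened one-hot label maps. A hedged sketch of preparing such targets from integer class maps (the dimensions and the random masks are purely illustrative):

import numpy as np
from tensorflow.keras.utils import to_categorical

H, W, n_labels = 360, 480, 12                                # example dimensions only
class_maps = np.random.randint(0, n_labels, size=(4, H, W))  # stand-in for real label maps
targets = to_categorical(class_maps, n_labels).reshape(-1, H * W, n_labels)
# targets now matches the (H * W, n_labels) output of CreateSegNet
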
Example #4
def create3LayerSegNetWithIndexPooling(input_shape, 
                                       n_labels, 
                                       k,
                                       kernel=3, 
                                       pool_size=(2, 2), 
                                       output_mode="sigmoid"):
    inputs = Input(shape=input_shape)

    conv_1 = Convolution2D(k, (kernel, kernel), padding="same")(inputs)
    conv_1 = BatchNormalization()(conv_1)
    conv_1 = PReLU()(conv_1)
    conv_2 = Convolution2D(k, (kernel, kernel), padding="same")(conv_1)
    #conv_2 = Dropout(0.5)(conv_2)
    conv_2 = BatchNormalization()(conv_2)
    conv_2 = PReLU()(conv_2)

    pool_1, mask_1 = MaxPoolingWithArgmax2D(pool_size)(conv_2)
    
    conv_3 = Convolution2D(2*k, (kernel, kernel), padding="same")(pool_1)
    conv_3 = BatchNormalization()(conv_3)
    conv_3 = PReLU()(conv_3)
    conv_4 = Convolution2D(2*k, (kernel, kernel), padding="same")(conv_3)
    #conv_4 = Dropout(0.5)(conv_4)
    conv_4 = BatchNormalization()(conv_4)
    conv_4 = PReLU()(conv_4)
    
    pool_2, mask_2 = MaxPoolingWithArgmax2D(pool_size)(conv_4)
    
    conv_5 = Convolution2D(4*k, (kernel, kernel), padding="same")(pool_2)
    conv_5 = BatchNormalization()(conv_5)
    conv_5 = PReLU()(conv_5)
    conv_6 = Convolution2D(4*k, (kernel, kernel), padding="same")(conv_5)
    conv_6 = BatchNormalization()(conv_6)
    conv_6 = PReLU()(conv_6)
    conv_7 = Convolution2D(4*k, (kernel, kernel), padding="same")(conv_6)
    conv_7 = BatchNormalization()(conv_7)
    conv_7 = PReLU()(conv_7)
    
    pool_3, mask_3 = MaxPoolingWithArgmax2D(pool_size)(conv_7)
    
    unpool_1 = MaxUnpooling2D(pool_size)([pool_3, mask_3])
    
    conv_8 = Convolution2D(4*k, (kernel, kernel), padding="same")(unpool_1)
    conv_8 = BatchNormalization()(conv_8)
    conv_8 = PReLU()(conv_8)
    conv_9 = Convolution2D(4*k, (kernel, kernel), padding="same")(conv_8)
    conv_9 = BatchNormalization()(conv_9)
    conv_9 = PReLU()(conv_9)
    conv_10 = Convolution2D(4*k, (kernel, kernel), padding="same")(conv_9)
    conv_10 = BatchNormalization()(conv_10)
    conv_10 = PReLU()(conv_10)
    
    unpool_2 = MaxUnpooling2D(pool_size)([conv_10, mask_2])

    conv_11 = Convolution2D(2*k, (kernel, kernel), padding="same")(unpool_2)
    conv_11 = BatchNormalization()(conv_11)
    conv_11 = PReLU()(conv_11)
    conv_12 = Convolution2D(2*k, (kernel, kernel), padding="same")(conv_11)
    conv_12 = BatchNormalization()(conv_12)
    conv_12 = PReLU()(conv_12)
    
    unpool_3 = MaxUnpooling2D(pool_size)([conv_12, mask_1])
    
    conv_13 = Convolution2D(k, (kernel, kernel), padding="same")(unpool_3)
    conv_13 = BatchNormalization()(conv_13)
    conv_13 = PReLU()(conv_13)
    conv_14 = Convolution2D(k, (kernel, kernel), padding="same")(conv_13)
    conv_14 = BatchNormalization()(conv_14)
    conv_14 = PReLU()(conv_14)
  
    conv_15 = Convolution2D(n_labels, (1, 1), padding='valid')(conv_14)
    conv_15 = BatchNormalization()(conv_15)
    
    reshape = Reshape((n_labels, input_shape[0] * input_shape[1]))(conv_15)
    permute = Permute((2, 1))(reshape)
    outputs = Activation(output_mode)(permute)
    
    segnet = Model(inputs=inputs, outputs=outputs)
    return segnet
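
The three-stage variant above scales its width with k: the stages use k, 2*k, and 4*k filters. An illustrative call (the input size, k, and loss are assumptions, not from the source):

model = create3LayerSegNetWithIndexPooling(input_shape=(128, 128, 1),
                                           n_labels=1, k=16)
model.compile(optimizer="adam", loss="binary_crossentropy")  # matches the sigmoid output
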
def main():
    hardwareHandler = HardwareHandler()
    emailHandler = EmailHandler()
    timer = TimerModule()
    now = datetime.now()
    date_string = now.strftime('%Y-%m-%d_%H_%M')

    dataHandler = SegNetDataHandler("Data/BRATS_2018/HGG",
                                    BasicNMFComputer(block_dim=8),
                                    num_patients=1)
    dataHandler.loadData("flair")
    dataHandler.preprocessForNetwork()
    x_train = dataHandler.X
    x_seg_train = dataHandler.labels
    dataHandler.clear()

    dataHandler.setDataDirectory("Data/BRATS_2018/HGG_Validation")
    dataHandler.setNumPatients(1)
    dataHandler.loadData("flair")
    dataHandler.preprocessForNetwork()
    x_val = dataHandler.X
    x_seg_val = dataHandler.labels
    dataHandler.clear()

    dataHandler.setDataDirectory("Data/BRATS_2018/HGG_Testing")
    dataHandler.setNumPatients(1)
    dataHandler.loadData("flair")
    dataHandler.preprocessForNetwork()
    x_test = dataHandler.X
    x_seg_test = dataHandler.labels
    dataHandler.clear()

    input_shape = (dataHandler.W, dataHandler.H, 1)
    inputs = Input(shape=input_shape)
    kernel = 3
    pool_size = (2, 2)
    output_mode = "softmax"
    n_labels = 2

    conv_1 = Conv2D(64, (kernel, kernel), padding="same")(inputs)
    conv_1 = BatchNormalization()(conv_1)
    conv_1 = Activation("relu")(conv_1)
    conv_2 = Conv2D(64, (kernel, kernel), padding="same")(conv_1)
    conv_2 = BatchNormalization()(conv_2)
    conv_2 = Activation("relu")(conv_2)

    pool_1, mask_1 = MaxPoolingWithArgmax2D(pool_size)(conv_2)

    conv_3 = Conv2D(128, (kernel, kernel), padding="same")(pool_1)
    conv_3 = BatchNormalization()(conv_3)
    conv_3 = Activation("relu")(conv_3)
    conv_4 = Conv2D(128, (kernel, kernel), padding="same")(conv_3)
    conv_4 = BatchNormalization()(conv_4)
    conv_4 = Activation("relu")(conv_4)

    pool_2, mask_2 = MaxPoolingWithArgmax2D(pool_size)(conv_4)

    conv_5 = Conv2D(256, (kernel, kernel), padding="same")(pool_2)
    conv_5 = BatchNormalization()(conv_5)
    conv_5 = Activation("relu")(conv_5)
    conv_6 = Conv2D(256, (kernel, kernel), padding="same")(conv_5)
    conv_6 = BatchNormalization()(conv_6)
    conv_6 = Activation("relu")(conv_6)
    conv_7 = Conv2D(256, (kernel, kernel), padding="same")(conv_6)
    conv_7 = BatchNormalization()(conv_7)
    conv_7 = Activation("relu")(conv_7)

    pool_3, mask_3 = MaxPoolingWithArgmax2D(pool_size)(conv_7)

    conv_8 = Conv2D(512, (kernel, kernel), padding="same")(pool_3)
    conv_8 = BatchNormalization()(conv_8)
    conv_8 = Activation("relu")(conv_8)
    conv_9 = Conv2D(512, (kernel, kernel), padding="same")(conv_8)
    conv_9 = BatchNormalization()(conv_9)
    conv_9 = Activation("relu")(conv_9)
    conv_10 = Conv2D(512, (kernel, kernel), padding="same")(conv_9)
    conv_10 = BatchNormalization()(conv_10)
    conv_10 = Activation("relu")(conv_10)

    pool_4, mask_4 = MaxPoolingWithArgmax2D(pool_size)(conv_10)

    conv_11 = Conv2D(512, (kernel, kernel), padding="same")(pool_4)
    conv_11 = BatchNormalization()(conv_11)
    conv_11 = Activation("relu")(conv_11)
    conv_12 = Conv2D(512, (kernel, kernel), padding="same")(conv_11)
    conv_12 = BatchNormalization()(conv_12)
    conv_12 = Activation("relu")(conv_12)
    conv_13 = Conv2D(512, (kernel, kernel), padding="same")(conv_12)
    conv_13 = BatchNormalization()(conv_13)
    conv_13 = Activation("relu")(conv_13)

    pool_5, mask_5 = MaxPoolingWithArgmax2D(pool_size)(conv_13)
    print("Build encoder done..")

    # decoder

    unpool_1 = MaxUnpooling2D(pool_size)([pool_5, mask_5])

    conv_14 = Conv2D(512, (kernel, kernel), padding="same")(unpool_1)
    conv_14 = BatchNormalization()(conv_14)
    conv_14 = Activation("relu")(conv_14)
    conv_15 = Conv2D(512, (kernel, kernel), padding="same")(conv_14)
    conv_15 = BatchNormalization()(conv_15)
    conv_15 = Activation("relu")(conv_15)
    conv_16 = Conv2D(512, (kernel, kernel), padding="same")(conv_15)
    conv_16 = BatchNormalization()(conv_16)
    conv_16 = Activation("relu")(conv_16)

    unpool_2 = MaxUnpooling2D(pool_size)([conv_16, mask_4])

    conv_17 = Conv2D(512, (kernel, kernel), padding="same")(unpool_2)
    conv_17 = BatchNormalization()(conv_17)
    conv_17 = Activation("relu")(conv_17)
    conv_18 = Conv2D(512, (kernel, kernel), padding="same")(conv_17)
    conv_18 = BatchNormalization()(conv_18)
    conv_18 = Activation("relu")(conv_18)
    conv_19 = Conv2D(256, (kernel, kernel), padding="same")(conv_18)
    conv_19 = BatchNormalization()(conv_19)
    conv_19 = Activation("relu")(conv_19)

    unpool_3 = MaxUnpooling2D(pool_size)([conv_19, mask_3])

    conv_20 = Conv2D(256, (kernel, kernel), padding="same")(unpool_3)
    conv_20 = BatchNormalization()(conv_20)
    conv_20 = Activation("relu")(conv_20)
    conv_21 = Conv2D(256, (kernel, kernel), padding="same")(conv_20)
    conv_21 = BatchNormalization()(conv_21)
    conv_21 = Activation("relu")(conv_21)
    conv_22 = Conv2D(128, (kernel, kernel), padding="same")(conv_21)
    conv_22 = BatchNormalization()(conv_22)
    conv_22 = Activation("relu")(conv_22)

    unpool_4 = MaxUnpooling2D(pool_size)([conv_22, mask_2])

    conv_23 = Conv2D(128, (kernel, kernel), padding="same")(unpool_4)
    conv_23 = BatchNormalization()(conv_23)
    conv_23 = Activation("relu")(conv_23)
    conv_24 = Conv2D(64, (kernel, kernel), padding="same")(conv_23)
    conv_24 = BatchNormalization()(conv_24)
    conv_24 = Activation("relu")(conv_24)

    unpool_5 = MaxUnpooling2D(pool_size)([conv_24, mask_1])

    conv_25 = Conv2D(64, (kernel, kernel), padding="same")(unpool_5)
    conv_25 = BatchNormalization()(conv_25)
    conv_25 = Activation("relu")(conv_25)

    conv_26 = Conv2D(n_labels, (1, 1), padding="valid")(conv_25)
    conv_26 = BatchNormalization()(conv_26)
    conv_26 = Reshape(
        (input_shape[0] * input_shape[1], n_labels),
        input_shape=(input_shape[0], input_shape[1], n_labels))(conv_26)

    outputs = Activation(output_mode)(conv_26)
    print("Build decoder done..")

    segnet = Model(inputs=inputs, outputs=outputs, name="SegNet")
    segnet.compile(optimizer='adadelta', loss='categorical_crossentropy')

    segnet.fit(
        x_train,
        x_seg_train,
        epochs=10,
        batch_size=10,
        shuffle=True,
        validation_data=(x_val, x_seg_val),
    )

    decoded_imgs = segnet.predict(x_test)

    n = 10
    for i in range(n):
        fig = plt.figure()
        plt.gray()

        a = fig.add_subplot(1, 2, 1)
        plt.imshow(x_seg_test[i].reshape(dataHandler.W, dataHandler.W))
        plt.axis('off')
        plt.title('Original')

        a = fig.add_subplot(1, 2, 2)
        plt.imshow(decoded_imgs[i].reshape(dataHandler.W, dataHandler.W))
        plt.gray()
        plt.show()

    for i in range(n):
        # display original
        ax = plt.subplot(2, n, i + 1)  # subplot indices are 1-based
        plt.imshow(x_test[i].reshape(dataHandler.W, dataHandler.W))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction
        ax = plt.subplot(2, n, i + n + 1)
        plt.imshow(decoded_imgs[i].reshape(dataHandler.W, dataHandler.W))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
    """
Example #6
def SegNetEncoderDecoderGenerator(inputs,
                                  layers,
                                  pool_size=(2, 2),
                                  shave_off_decoder_end=0):
    '''
    Creates a convolutional, symmetric encoder-decoder architecture similar to
    that of SegNet. The exact structure is defined by the 'layers' argument,
    which should be a list of tuples, each tuple giving the number of
    convolutional layers, their filter width, and their kernel size for one
    encoder/decoder stage.
    '''

    masks = []

    io = inputs
    total_layers = 0

    # encoder
    with tf.name_scope('encoder'):
        for idx, group in enumerate(layers):
            layers_no = group[0]
            width = group[1]
            kernel = group[2]
            for i in range(layers_no):
                with tf.name_scope('encoder_{0}x{0}_{1}ch'.format(
                        kernel, width)):
                    io = Convolution2D(width, (kernel, kernel),
                                       padding="same")(io)
                    io = BatchNormalization()(io)
                    io = Activation("relu")(io)
                total_layers += 1

            io, mask = MaxPoolingWithArgmax2D(pool_size)(io)
            masks.append(mask)

    print("Building enceder done..")

    # decoder
    if shave_off_decoder_end > 0:
        total_layers -= shave_off_decoder_end
    with tf.name_scope('decoder'):
        for idx, group in enumerate(reversed(layers)):
            layers_no = group[0]
            width = group[1]
            kernel = group[2]

            io = MaxUnpooling2D(pool_size)([io, masks[-1 - idx]])

            for i in range(layers_no):
                # if this is the last convolution in a group and another group
                # with a different width follows (after the next unpooling),
                # reduce the tensor width now
                if i == layers_no - 1 and idx != len(layers) - 1:
                    # last layer before UnPooling
                    width = layers[-1 - idx - 1][1]

                with tf.name_scope('decoder_{0}x{0}_{1}ch'.format(
                        kernel, width)):
                    io = Convolution2D(width, (kernel, kernel),
                                       padding="same")(io)
                    io = BatchNormalization()(io)
                    io = Activation("relu")(io)
                total_layers -= 1
                if total_layers <= 0:
                    return io, masks[:-1 - idx]

    print("Building decoder done..")
Example #7
    conv_11 = Conv2D(512, (kernel, kernel), padding="same")(pool_4)
    conv_11 = BatchNormalization()(conv_11)
    conv_11 = Activation("relu")(conv_11)
    conv_12 = Conv2D(512, (kernel, kernel), padding="same")(conv_11)
    conv_12 = BatchNormalization()(conv_12)
    conv_12 = Activation("relu")(conv_12)
    conv_13 = Conv2D(512, (kernel, kernel), padding="same")(conv_12)
    conv_13 = BatchNormalization()(conv_13)
    conv_13 = Activation("relu")(conv_13)

    pool_5, mask_5 = MaxPoolingWithArgmax2D(pool_size)(conv_13)
    print("Build enceder done..")

    # decoder

    unpool_1 = MaxUnpooling2D(pool_size)([pool_5, mask_5])

    conv_14 = Conv2D(512, (kernel, kernel), padding="same")(unpool_1)
    conv_14 = BatchNormalization()(conv_14)
    conv_14 = Activation("relu")(conv_14)
    conv_15 = Conv2D(512, (kernel, kernel), padding="same")(conv_14)
    conv_15 = BatchNormalization()(conv_15)
    conv_15 = Activation("relu")(conv_15)
    conv_16 = Conv2D(512, (kernel, kernel), padding="same")(conv_15)
    conv_16 = BatchNormalization()(conv_16)
    conv_16 = Activation("relu")(conv_16)

    unpool_2 = MaxUnpooling2D(pool_size)([conv_16, mask_4])

    conv_17 = Conv2D(512, (kernel, kernel), padding="same")(unpool_2)
    conv_17 = BatchNormalization()(conv_17)