Example #1
def _pad_or_crop_to_shape_3D(x, in_shape, tgt_shape):
    '''
    in_shape, tgt_shape are length-3 array-likes of spatial sizes; x is
    zero-padded or cropped along each axis until it matches tgt_shape
    '''
    im_diff = np.asarray(in_shape[:3]) - np.asarray(tgt_shape[:3])

    if im_diff[0] < 0:
        pad_amt = (int(np.ceil(abs(im_diff[0]) / 2.0)),
                   int(np.floor(abs(im_diff[0]) / 2.0)))
        x = ZeroPadding3D((pad_amt, (0, 0), (0, 0)))(x)
    if im_diff[1] < 0:
        pad_amt = (int(np.ceil(abs(im_diff[1]) / 2.0)),
                   int(np.floor(abs(im_diff[1]) / 2.0)))
        x = ZeroPadding3D(((0, 0), pad_amt, (0, 0)))(x)
    if im_diff[2] < 0:
        pad_amt = (int(np.ceil(abs(im_diff[2]) / 2.0)),
                   int(np.floor(abs(im_diff[2]) / 2.0)))
        x = ZeroPadding3D(((0, 0), (0, 0), pad_amt))(x)

    if im_diff[0] > 0:
        crop_amt = (int(np.ceil(im_diff[0] / 2.0)),
                    int(np.floor(im_diff[0] / 2.0)))
        x = Cropping3D((crop_amt, (0, 0), (0, 0)))(x)
    if im_diff[1] > 0:
        crop_amt = (int(np.ceil(im_diff[1] / 2.0)),
                    int(np.floor(im_diff[1] / 2.0)))
        x = Cropping3D(((0, 0), crop_amt, (0, 0)))(x)
    if im_diff[2] > 0:
        crop_amt = (int(np.ceil(im_diff[2] / 2.0)),
                    int(np.floor(im_diff[2] / 2.0)))
        x = Cropping3D(((0, 0), (0, 0), crop_amt))(x)
    return x
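A minimal usage sketch for the helper above; the imports and the channels-last 5D layout are assumptions, not part of the original snippet:

import numpy as np
from tensorflow.keras.layers import Input, ZeroPadding3D, Cropping3D

inp = Input(shape=(30, 64, 64, 1))   # spatial shape (30, 64, 64), channels last
out = _pad_or_crop_to_shape_3D(inp, (30, 64, 64), (32, 60, 64))
# dim 0 is zero-padded 30 -> 32, dim 1 is cropped 64 -> 60, dim 2 is untouched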
Example #2
def make_flood_fill_network(input_fov_shape, output_fov_shape, network_config):
    """Construct a stacked convolution module flood filling network.
    """
    if network_config.convolution_padding != 'same':
        raise ValueError('ResNet implementation only supports same padding.')

    image_input = Input(shape=tuple(input_fov_shape) + (1, ),
                        dtype='float32',
                        name='image_input')
    if network_config.rescale_image:
        ffn = Lambda(lambda x: (x - 0.5) * 2.0)(image_input)
    else:
        ffn = image_input
    mask_input = Input(shape=tuple(input_fov_shape) + (1, ),
                       dtype='float32',
                       name='mask_input')
    ffn = concatenate([ffn, mask_input])

    # Convolve and activate before beginning the skip connection modules,
    # as discussed in the Appendix of He et al 2016.
    ffn = Conv3D(network_config.convolution_filters,
                 tuple(network_config.convolution_dim),
                 kernel_initializer=network_config.initialization,
                 activation=network_config.convolution_activation,
                 padding='same')(ffn)
    if network_config.batch_normalization:
        ffn = BatchNormalization()(ffn)

    contraction = (input_fov_shape - output_fov_shape) // 2
    if np.any(np.less(contraction, 0)):
        raise ValueError(
            'Output FOV shape can not be larger than input FOV shape.')
    contraction_cumu = np.zeros(3, dtype=np.int32)
    contraction_step = np.divide(contraction,
                                 float(network_config.num_modules))

    for i in range(0, network_config.num_modules):
        ffn = add_convolution_module(ffn, network_config)
        contraction_dims = np.floor(i * contraction_step -
                                    contraction_cumu).astype(np.int32)
        if np.count_nonzero(contraction_dims):
            # list(...) is required in Python 3: Cropping3D cannot consume a zip iterator
            ffn = Cropping3D(
                list(zip(list(contraction_dims), list(contraction_dims))))(ffn)
            contraction_cumu += contraction_dims

    if np.any(np.less(contraction_cumu, contraction)):
        remainder = contraction - contraction_cumu
        ffn = Cropping3D(list(zip(list(remainder), list(remainder))))(ffn)

    mask_output = Conv3D(1,
                         tuple(network_config.convolution_dim),
                         kernel_initializer=network_config.initialization,
                         padding='same',
                         name='mask_output',
                         activation=network_config.output_activation)(ffn)
    ffn = Model(inputs=[image_input, mask_input], outputs=[mask_output])

    return ffn
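A hedged sketch of how this builder might be invoked. The config attributes mirror the ones the function reads, but the values are illustrative assumptions, and add_convolution_module must come from the surrounding project:

from types import SimpleNamespace
import numpy as np

cfg = SimpleNamespace(convolution_padding='same',   # required by the check above
                      rescale_image=True,
                      convolution_filters=32,
                      convolution_dim=np.array([3, 3, 3]),
                      initialization='glorot_uniform',
                      convolution_activation='relu',
                      batch_normalization=True,
                      num_modules=8,
                      output_activation='sigmoid')
ffn = make_flood_fill_network(np.array([33, 33, 33]), np.array([17, 17, 17]), cfg)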
Example #3
    def func(input_tensor, skip_tensor):

        # x = Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2))(input_tensor)
        x = _bn_act_deconv_3D(filters=filters, kernel_size=(2, 2, 2), strides=(2, 2, 2))(input_tensor)

        # computation of cropping-amount needed for skip_tensor
        _, x_height, x_width, x_depth, _ = K.int_shape(x)
        _, s_height, s_width, s_depth, _ = K.int_shape(skip_tensor)

        h_crop = s_height - x_height
        w_crop = s_width - x_width
        d_crop = s_depth - x_depth

        assert h_crop >= 0
        assert w_crop >= 0
        assert d_crop >= 0
        if h_crop == 0 and w_crop == 0 and d_crop == 0:
            y = skip_tensor
        else:
            cropping = ((h_crop // 2, h_crop - h_crop // 2), (w_crop // 2, w_crop - w_crop // 2),
                        (d_crop // 2, d_crop - d_crop // 2))
            y = Cropping3D(cropping=cropping)(skip_tensor)

        x = Concatenate()([x, y])

        x_list = []
        for dilation_rate in dilation_rates:
            x = _bn_act_conv_3D(filters=filters, kernel_size=kernel_size, strides=strides, dilation_rate=dilation_rate,
                                activation=activation, batch_norm=batch_norm, dropout=dropout, train=train)(x)
            x_list.append(x)

        return x, x_list
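The centre-crop arithmetic above pushes any odd remainder to the far side of each axis; a quick worked check:

s_height, x_height = 19, 16
h_crop = s_height - x_height                 # 3
print((h_crop // 2, h_crop - h_crop // 2))   # (1, 2): 1 voxel front, 2 back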
Example #4
        def deconv3d(layer_input,
                     skip_input,
                     filters,
                     f_size=4,
                     dropout_rate='self'):
            """Layers used during upsampling"""
            u = UpSampling3D(size=2)(layer_input)
            u = Conv3D(filters, kernel_size=f_size, strides=1,
                       padding='same')(u)
            if dropout_rate == 'self':
                dropout_rate = 0.5 if self.dropout else False
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            if self.adain:
                g = Dense(filters, bias_initializer='ones')(w)
                b = Dense(filters)(w)
                u = Lambda(adain)([u, g, b])
            else:
                u = BatchNormalization(momentum=0.8)(u)
            u = ReLU()(u)

            # u = Concatenate()([u, skip_input])
            ch, cw, cd = get_crop_shape(u, skip_input)
            crop_conv4 = Cropping3D(cropping=(ch, cw, cd),
                                    data_format="channels_last")(u)
            u = Concatenate()([crop_conv4, skip_input])
            return u
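get_crop_shape is a project helper not shown here; a hedged sketch of what it plausibly computes, given how its result is passed to Cropping3D (per-axis (before, after) crops, channels-last):

from tensorflow.keras import backend as K

def get_crop_shape(target, refer):
    # Sketch under assumptions: shrink `target` to `refer`'s spatial shape.
    crops = []
    for t, r in zip(K.int_shape(target)[1:4], K.int_shape(refer)[1:4]):
        d = t - r
        assert d >= 0
        crops.append((d // 2, d - d // 2))
    return tuple(crops)   # unpacked above as ch, cw, cd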
Example #5
def crop_to_fit(inputs, outputs, n_layers=4):
    width = int_shape(inputs)[1]
    height = int_shape(inputs)[2]
    depth = int_shape(inputs)[3]

    w_pad, h_pad, d_pad = calc_fit_pad(width, height, depth, n_layers)

    x = Cropping3D((w_pad, h_pad, d_pad))(outputs)
    return x
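calc_fit_pad is likewise an assumed helper; a hedged sketch consistent with its use here would return, per axis, the (before, after) amounts that make the size divisible by 2**n_layers, so the same amounts can be cropped off the output:

def calc_fit_pad(width, height, depth, n_layers):
    # Sketch under assumptions: symmetric pad/crop to a multiple of 2**n_layers.
    pads = []
    for dim in (width, height, depth):
        total = (-dim) % (2 ** n_layers)
        pads.append((total // 2, total - total // 2))
    return pads   # [w_pad, h_pad, d_pad]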
Example #6
def SAAM(x, up_scale, N_svd):
    bsize, ang, hei, wid, chn = x.get_shape().as_list()
    # \phi'_q
    h1 = Conv3D(chn // 8, (1, 1, 1), use_bias=False, padding='SAME', name='attn_epi_1')(x)
    h1 = tf.transpose(h1, [0, 2, 1, 3, 4])
    h1 = tf.reshape(h1, [-1, ang * wid, chn // 8])
    if N_svd > 0:
        s1, u1, v1 = tf.svd(h1)
        s1 = tf.slice(s1, [0, 0], [-1, N_svd])
        u1 = tf.slice(u1, [0, 0, 0], [-1, -1, N_svd])
        v1 = tf.slice(v1, [0, 0, 0], [-1, -1, N_svd])

    # \phi'_k
    h2 = Conv3D(chn // 8, (1, 1, 1), use_bias=False, padding='SAME', name='attn_epi_2')(x)
    h2 = tf.transpose(h2, [0, 2, 1, 3, 4])
    h2 = tf.reshape(h2, [-1, ang * wid, chn // 8])
    if N_svd > 0:
        s2, u2, v2 = tf.svd(h2)
        s2 = tf.slice(s2, [0, 0], [-1, N_svd])
        u2 = tf.slice(u2, [0, 0, 0], [-1, -1, N_svd])
        v2 = tf.slice(v2, [0, 0, 0], [-1, -1, N_svd])

    # \phi'_v
    h3 = Conv3D(chn // 8 * up_scale, (1, 1, 1), use_bias=False, padding='SAME', name='attn_epi_3')(x)
    h3 = tf.transpose(h3, [0, 2, 1, 3, 4])
    h3 = tf.reshape(h3, [-1, ang * wid, chn // 8 * up_scale])

    # Map
    if N_svd > 0:
        attn_EPI = tf.matmul(tf.matmul(u1, tf.matmul(tf.matmul(tf.matrix_diag(s1), tf.matmul(v1, v2, transpose_a=True)),
                                                     tf.matrix_diag(s2), transpose_b=True)), u2, transpose_b=True)
    else:
        attn_EPI = tf.matmul(h1, h2, transpose_b=True)
    attn_EPI = tf.nn.softmax(attn_EPI)

    # \phi'_a
    h = tf.matmul(attn_EPI, h3)

    # \phi'_b
    h = tf.reshape(h, [-1, hei, ang, wid, chn // 8 * up_scale])
    h = tf.transpose(h, [0, 2, 1, 3, 4])
    
    x = Conv3D(chn // 8 * up_scale, (1, 1, 1), use_bias=False, padding='SAME', name='attn_epi_4')(x)
    sigma = beta_variable([1], name='attn_sigma')
    h = x + sigma * h

    # \phi_c
    h = pixel_shuffle(h, up_scale)
    h = Cropping3D(cropping=((0, up_scale - 1), (0, 0), (0, 0)))(h)
    h = Conv3D(filters=chn, kernel_size=(1, 1, 7), activation='relu', padding='SAME', name='epi_5')(h)
    h = Conv3D(filters=chn, kernel_size=(7, 1, 1), activation='relu', padding='SAME', name='epi_6')(h)
    return h
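The N_svd branch evaluates the same attention product through truncated SVDs, using the identity h1 @ h2^T = u1 diag(s1) (v1^T v2) diag(s2) u2^T; a quick numpy check of that identity at full rank (where it is exact):

import numpy as np

rng = np.random.default_rng(0)
h1, h2 = rng.standard_normal((12, 4)), rng.standard_normal((12, 4))
u1, s1, v1t = np.linalg.svd(h1, full_matrices=False)
u2, s2, v2t = np.linalg.svd(h2, full_matrices=False)
lhs = h1 @ h2.T
rhs = u1 @ np.diag(s1) @ (v1t @ v2t.T) @ np.diag(s2) @ u2.T
print(np.allclose(lhs, rhs))   # True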
Example #7
    def crop_nodes_to_match(self, node1, node2):
        """
        If necessary, applies Cropping3D layer to node1 to match shape of node2
        """
        s1 = np.array(node1.get_shape().as_list())[1:-1]
        s2 = np.array(node2.get_shape().as_list())[1:-1]

        if np.any(s1 != s2):
            c = (s1 - s2).astype(int)  # np.int was removed in NumPy 1.24
            cr = np.array([c // 2, c // 2]).T
            cr[:, 1] += c % 2
            cropped_node1 = Cropping3D(cr)(node1)
            self.label_crop += cr
        else:
            cropped_node1 = node1

        return cropped_node1
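A worked check of the crop matrix built above: a shape difference of 5 splits as (2, 3), with the odd voxel taken from the far side:

import numpy as np

c = np.array([5, 4, 0])
cr = np.array([c // 2, c // 2]).T
cr[:, 1] += c % 2
print(cr)   # [[2 3] [2 2] [0 0]]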
Example #8
def create_simplified_deepmedic_model(**kwargs):
    from tensorflow.keras.layers import Input, AveragePooling3D, Cropping3D
    from tensorflow.keras import backend as K
    from tensorflow.keras import Model

    assert np.all([
        n_i_f_p_p == kwargs["number_input_features_per_pathway"][0]
        for n_i_f_p_p in kwargs["number_input_features_per_pathway"]
    ])
    if "input_interpolation" in kwargs:
        assert kwargs["input_interpolation"] == "mean"

    else:
        kwargs["input_interpolation"] = "mean"

    model = create_generalized_deepmedic_model(**kwargs)
    input_size = np.max(
        [[s_f_p_p[i] * K.int_shape(input_path)[i + 1] for i in range(3)]
         for s_f_p_p, input_path in zip(
             kwargs["subsample_factors_per_pathway"], model.inputs)],
        axis=0)
    input = Input(
        tuple(input_size) + (kwargs["number_input_features_per_pathway"][0], ))
    paths = []
    for s_f_p_p, input_path in zip(kwargs["subsample_factors_per_pathway"],
                                   model.inputs):
        crop_size = [
            input_size[i] - s_f_p_p[i] * K.int_shape(input_path)[i + 1]
            for i in range(3)
        ]
        path = Cropping3D(
            tuple([(c_z // 2, c_z - c_z // 2) for c_z in crop_size]))(input)
        path = AveragePooling3D(tuple(s_f_p_p))(path)
        paths.append(path)

    inputs = [input
              ] + model.inputs[len(kwargs["subsample_factors_per_pathway"]):]
    outputs = model(
        paths + model.inputs[len(kwargs["subsample_factors_per_pathway"]):])
    model = Model(inputs=inputs,
                  outputs=outputs if isinstance(outputs, list) else [outputs])
    print("\nNetwork summary of simplified DeepMedic model:")
    model.summary()  # summary() prints directly; wrapping it in print() also emits "None"
    return model
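The wrapper feeds every pathway from a single high-resolution input by centre-cropping and then average-pooling; a worked size check with assumed numbers:

# Assumption: subsample factor 3 and a pathway input of 19 voxels need
# 3 * 19 = 57 high-res voxels; with a 63-voxel global input, crop 6.
input_size, s_f, path_in = 63, 3, 19
crop = input_size - s_f * path_in        # 6
print((crop // 2, crop - crop // 2))     # (3, 3), then AveragePooling3D(3)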
Example #9
    def func(input_tensor, skip_tensor):

        # x = Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2))(input_tensor)
        x = _deconv_bn_act_3D(filters=filters, kernel_size=(2, 2, 2), strides=(2, 2, 2))(input_tensor)


        # computation of cropping-amount needed for skip_tensor
        _, x_height, x_width, x_depth, _ = K.int_shape(x)
        _, s_height, s_width, s_depth, _ = K.int_shape(skip_tensor)

        h_crop = s_height - x_height
        w_crop = s_width - x_width
        d_crop = s_depth - x_depth

        assert h_crop >= 0
        assert w_crop >= 0
        assert d_crop >= 0
        if h_crop == 0 and w_crop == 0 and d_crop == 0:
            y = skip_tensor
        else:
            cropping = ((h_crop // 2, h_crop - h_crop // 2), (w_crop // 2, w_crop - w_crop // 2),
                        (d_crop // 2, d_crop - d_crop // 2))
            y = Cropping3D(cropping=cropping)(skip_tensor)

            # cropping = ((h_crop // 2, h_crop - h_crop // 2), (w_crop // 2, w_crop - w_crop // 2))
            # y = Cropping2D(cropping=cropping)(skip_tensor)


        # commented out at the moment because of error message:
        # could not create a view primitive descriptor, in file tensorflow/core/kernels/mkl_slice_op.cc:300

        x = Concatenate()([x, y])

        x_list = []
        for dilation_rate in dilation_rates:
            x = _conv_bn_act_3D(filters=filters, kernel_size=kernel_size, strides=strides, dilation_rate=dilation_rate,
                                activation=activation, batch_norm=batch_norm, dropout=dropout)(x)
            x_list.append(x)

        return x, x_list
Example #10
def test_delete_channels_cropping3d(channel_index, data_format):
    layer = Cropping3D([2, 3, 2], data_format=data_format)
    layer_test_helper_flatten_3d(layer, channel_index, data_format=data_format)
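The test above reads like a pytest-parametrized case (likely from a keras-surgeon-style test suite, which provides Cropping3D and layer_test_helper_flatten_3d); a hedged sketch of the decorators it implies, with parameter values assumed:

import pytest

@pytest.mark.parametrize('channel_index', [0, [0, 1]])          # assumed values
@pytest.mark.parametrize('data_format', ['channels_first', 'channels_last'])
def test_delete_channels_cropping3d(channel_index, data_format):
    layer = Cropping3D([2, 3, 2], data_format=data_format)
    layer_test_helper_flatten_3d(layer, channel_index, data_format=data_format)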
Example #11
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True,
                      temporal=None,
                      residual=False,
                      temporal_kernel_size=3):
    """Creates a 3D featurenet.

    Args:
        receptive_field (int): the receptive field of the neural network.
        n_frames (int): Number of frames.
        input_shape (tuple): If no input tensor, create one with this shape.
        n_features (int): Number of output features
        n_channels (int): number of input channels
        reg (float): regularization value
        n_conv_filters (int): number of convolutional filters
        n_dense_filters (int): number of dense filters
        VGG_mode (bool): If ``True``, doubles the number of convolutional
            filters after each pooling step (used with ``multires``)
        init (str): Method for initializing weights.
        norm_method (str): Normalization method to use with the
            :mod:`deepcell.layers.normalization.ImageNormalization3D` layer.
        location (bool): Whether to include a
            :mod:`deepcell.layers.location.Location3D` layer.
        dilated (bool): Whether to use dilated pooling.
        padding (bool): Whether to use padding.
        padding_mode (str): Type of padding, one of 'reflect' or 'zero'
        multires (bool): Enables multi-resolution mode
        include_top (bool): Whether to include the final layer of the model
        temporal (str): Type of temporal operation
        residual (bool): Whether to use temporal information as a residual
        temporal_kernel_size (int): size of 2D kernel used in temporal convolutions

    Returns:
        tensorflow.keras.Model: 3D FeatureNet
    """
    # Create layers list (x) to store all of the layers.
    # We need to use the functional API to enable the multiresolution mode
    x = []

    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    if dilated:
        padding = True

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        time_axis = 2
        row_axis = 3
        col_axis = 4
        if not dilated:
            input_shape = (n_channels, n_frames, receptive_field,
                           receptive_field)
    else:
        channel_axis = -1
        time_axis = 1
        row_axis = 2
        col_axis = 3
        if not dilated:
            input_shape = (n_frames, receptive_field, receptive_field,
                           n_channels)

    x.append(Input(shape=input_shape))
    x.append(
        ImageNormalization3D(norm_method=norm_method,
                             filter_size=receptive_field)(x[-1]))

    if padding:
        if padding_mode == 'reflect':
            x.append(ReflectionPadding3D(padding=(win_z, win, win))(x[-1]))
        elif padding_mode == 'zero':
            x.append(ZeroPadding3D(padding=(win_z, win, win))(x[-1]))

    if location:
        x.append(Location3D()(x[-1]))
        x.append(Concatenate(axis=channel_axis)([x[-2], x[-1]]))

    layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        filter_size = 3 if rf_counter % 2 == 0 else 4
        x.append(
            Conv3D(n_conv_filters, (1, filter_size, filter_size),
                   dilation_rate=(1, d, d),
                   kernel_initializer=init,
                   padding='valid',
                   kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 == 0:
            if dilated:
                x.append(
                    DilatedMaxPool3D(dilation_rate=(1, d, d),
                                     pool_size=(1, 2, 2))(x[-1]))
                d *= 2
            else:
                x.append(MaxPool3D(pool_size=(1, 2, 2))(x[-1]))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                layers_to_concat.append(len(x) - 1)

    if multires:
        c = []
        for l in layers_to_concat:
            output_shape = x[l].get_shape().as_list()
            target_shape = x[-1].get_shape().as_list()
            time_crop = (0, 0)

            row_crop = int(output_shape[row_axis] - target_shape[row_axis])

            if row_crop % 2 == 0:
                row_crop = (row_crop // 2, row_crop // 2)
            else:
                row_crop = (row_crop // 2, row_crop // 2 + 1)

            col_crop = int(output_shape[col_axis] - target_shape[col_axis])

            if col_crop % 2 == 0:
                col_crop = (col_crop // 2, col_crop // 2)
            else:
                col_crop = (col_crop // 2, col_crop // 2 + 1)

            cropping = (time_crop, row_crop, col_crop)

            c.append(Cropping3D(cropping=cropping)(x[l]))
        x.append(Concatenate(axis=channel_axis)(c))

    x.append(
        Conv3D(n_dense_filters, (1, rf_counter, rf_counter),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        Conv3D(n_dense_filters, (n_frames, 1, 1),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    feature = Activation('relu')(x[-1])

    def __merge_temporal_features(feature,
                                  mode='conv',
                                  residual=False,
                                  n_filters=256,
                                  n_frames=3,
                                  padding=True,
                                  temporal_kernel_size=3):
        if mode is None:
            return feature

        mode = str(mode).lower()
        if mode == 'conv':
            x = Conv3D(n_filters,
                       (n_frames, temporal_kernel_size, temporal_kernel_size),
                       kernel_initializer=init,
                       padding='same',
                       activation='relu',
                       kernel_regularizer=l2(reg))(feature)
        elif mode == 'lstm':
            x = ConvLSTM2D(filters=n_filters,
                           kernel_size=temporal_kernel_size,
                           padding='same',
                           kernel_initializer=init,
                           activation='relu',
                           kernel_regularizer=l2(reg),
                           return_sequences=True)(feature)
        elif mode == 'gru':
            x = ConvGRU2D(filters=n_filters,
                          kernel_size=temporal_kernel_size,
                          padding='same',
                          kernel_initializer=init,
                          activation='relu',
                          kernel_regularizer=l2(reg),
                          return_sequences=True)(feature)
        else:
            raise ValueError(
                '`temporal` must be one of "conv", "lstm", "gru" or None')

        if residual is True:
            temporal_feature = Add()([feature, x])
        else:
            temporal_feature = x
        temporal_feature_normed = BatchNormalization(
            axis=channel_axis)(temporal_feature)
        return temporal_feature_normed

    temporal_feature = __merge_temporal_features(
        feature,
        mode=temporal,
        residual=residual,
        n_filters=n_dense_filters,
        n_frames=n_frames,
        padding=padding,
        temporal_kernel_size=temporal_kernel_size)
    x.append(temporal_feature)

    x.append(
        TensorProduct(n_dense_filters,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(
        TensorProduct(n_features,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(x[-1]))

    if not dilated:
        x.append(Flatten()(x[-1]))

    if include_top:
        x.append(Softmax(axis=channel_axis, dtype=K.floatx())(x[-1]))

    model = Model(inputs=x[0], outputs=x[-1])

    return model
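A hedged invocation sketch for the featurenet builder above (deepcell-style usage; the custom layers it relies on come from the surrounding module, and the arguments shown are the function's own defaults):

model = bn_feature_net_3D(receptive_field=61, n_frames=5,
                          n_features=3, n_channels=1, temporal=None)
model.summary()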
Example #12
def create_generalized_deepmedic_model(
        number_input_features_per_pathway=(1, 1),
        subsample_factors_per_pathway=((1, 1, 1), (3, 3, 3)),
        kernel_sizes_per_pathway=(((3, 3, 1), ) * 5 + ((3, 3, 3), ) * 5,
                                  ((3, 3, 1), ) * 5 + ((3, 3, 3), ) * 5),
        number_features_per_pathway=((32, ) * 5 + (48, ) * 5,
                                     (32, ) * 5 + (48, ) * 5),
        kernel_sizes_common_pathway=((1, 1, 1), ) * 3,
        number_features_common_pathway=(150, 150, 1),
        dropout_common_pathway=(0, 0.5, 0.5),
        output_size=(22, 15, 9),
        metadata_sizes=None,
        metadata_number_features=None,
        metadata_dropout=None,
        metadata_at_common_pathway_layer=None,
        padding='valid',
        pooling='avg',  # Not used yet; set to 'avg' to allow use with dense_connection=<int>
        upsampling='copy',
        activation='prelu',
        activation_final_layer='sigmoid',
        kernel_initializer='he_normal',
        batch_normalization=False,
        batch_normalization_on_input=False,
        instance_normalization=False,
        instance_normalization_on_input=False,
        relaxed_normalization_scheme=False,  # Every <int> layers we request normalization
        mask_output=False,
        residual_connections=False,  # We group <int> layers into a residual block
        dense_connections=False,  # We group <int> layers into a densely connected block (<int> - 1 layers have dense connections and there is always 1 transition layer in between)
        add_extra_dimension=False,
        l1_reg=0.0,
        l2_reg=0.0,
        verbose=True,
        input_interpolation="nearest"):
    from tensorflow.keras.layers import Input, Dropout, MaxPooling3D, Concatenate, Multiply, Add, Reshape, Conv3DTranspose, AveragePooling3D, Conv3D, UpSampling3D, Cropping3D, LeakyReLU, PReLU, BatchNormalization
    from tensorflow.keras import regularizers
    from tensorflow.keras import backend as K
    from tensorflow_addons.layers import InstanceNormalization
    from tensorflow.keras import Model

    # Define some in-house functions
    def normalization_function():
        if batch_normalization_on_input or batch_normalization:
            normalization_function_ = BatchNormalization()

        elif instance_normalization_on_input or instance_normalization:
            normalization_function_ = InstanceNormalization()

        else:
            raise NotImplementedError

        return normalization_function_

    def introduce_metadata(path, i):
        metadata_input_ = metadata_path = Input(shape=(1, 1, 1,
                                                       metadata_sizes[i]))
        inputs.append(metadata_input_)
        for j, (m_n_f, m_d) in enumerate(
                zip(metadata_number_features[i], metadata_dropout[i])):
            if m_d:
                metadata_path = Dropout(m_d)(metadata_path)
            metadata_path = Conv3D(filters=m_n_f,
                                   kernel_size=(1, 1, 1),
                                   padding=padding,
                                   kernel_initializer=kernel_initializer,
                                   kernel_regularizer=regularizers.l1_l2(
                                       l1_reg, l2_reg))(metadata_path)
            metadata_path = activation_function("m{}_activation{}".format(
                i, j))(metadata_path)

        # When the metadata is inserted, every voxel most likely corresponds to an output.
        # Currently I cannot think of a good reason it wouldn't be the case, hence the hard assert.
        path_size = K.int_shape(path)[1:4]
        assert tuple(output_size) == tuple(path_size)
        metadata_path = UpSampling3D(tuple([int(s) for s in path_size
                                            ]))(metadata_path)
        return Concatenate(axis=-1)([path, metadata_path])

    def activation_function(name):
        if activation == "relu":
            activation_function_ = LeakyReLU(alpha=0, name=name)

        elif activation == "lrelu":
            activation_function_ = LeakyReLU(alpha=0.01, name=name)

        elif activation == "prelu":
            activation_function_ = PReLU(shared_axes=[1, 2, 3], name=name)

        elif activation == "linear":

            def activation_function_(path):
                return path

        else:
            raise NotImplementedError

        return activation_function_

    def pooling_function(pool_size):
        if pooling == "max":
            pooling_function_ = MaxPooling3D(pool_size, pool_size)

        elif pooling == "avg":
            pooling_function_ = AveragePooling3D(pool_size, pool_size)

        else:
            raise NotImplementedError

        return pooling_function_

    def upsampling_function(upsample_size):
        if upsampling == "copy":
            upsampling_function_ = UpSampling3D(upsample_size)

        elif upsampling == "linear":

            def upsampling_function_(path):
                path = UpSampling3D(upsample_size)(path)
                path = AveragePooling3D(upsample_size,
                                        strides=(1, 1, 1),
                                        padding='valid')(path)
                return path

        elif upsampling == "conv":

            def upsampling_function_(path):
                path = Conv3DTranspose(
                    K.int_shape(path)[-1], upsample_size, upsample_size)(path)
                return path

        else:
            raise NotImplementedError

        return upsampling_function_

    # Define some in-house variables
    nb_pathways = len(subsample_factors_per_pathway)
    supported_activations = ["relu", "lrelu", "prelu", "linear"]
    supported_poolings = ["max", "avg"]
    supported_upsamplings = ["copy", "linear", "conv"]
    supported_paddings = ["valid", "same"]

    # Do some sanity checks
    if not len(number_input_features_per_pathway) == len(
            kernel_sizes_per_pathway) == len(
                number_features_per_pathway) == nb_pathways:
        raise ValueError("Inconsistent number of pathways.")

    for p in range(nb_pathways):
        if not len(kernel_sizes_per_pathway[p]) == len(
                number_features_per_pathway[p]):
            raise ValueError("Inconsistent depth of pathway #{}.".format(p))

        for ssf in subsample_factors_per_pathway[p]:
            if ssf % 2 != 1:
                raise ValueError("Subsample factors must be odd.")

    for k_s_p_p, n_f_p_p in zip(kernel_sizes_per_pathway,
                                number_features_per_pathway):
        if not len(k_s_p_p) == len(n_f_p_p):
            raise ValueError(
                "Each kernel size in each element from kernel_sizes_per_pathway must correspond with a number of features in each element of number_features_per_pathway."
            )

    if not len(kernel_sizes_common_pathway) == len(
            dropout_common_pathway) == len(number_features_common_pathway):
        raise ValueError("Inconsistent depth of common pathway.")

    if metadata_sizes in [None, []]:
        metadata_sizes = []
        if metadata_number_features in [None, []]:
            metadata_number_features = []

        else:
            raise ValueError(
                "Invalid value for metadata_number_features when there is no metadata"
            )

        if metadata_dropout in [None, []]:
            metadata_dropout = []

        else:
            raise ValueError(
                "Invalid value for metadata_dropout when there is no metadata")

        if metadata_at_common_pathway_layer in [None, []]:
            metadata_at_common_pathway_layer = []

        else:
            raise ValueError(
                "Invalid value for metadata_at_common_pathway_layer when there is no metadata"
            )

    else:
        if not len(metadata_sizes) == len(metadata_dropout) == len(
                metadata_number_features) == len(
                    metadata_at_common_pathway_layer):
            raise ValueError("Inconsistent depth of metadata pathway.")

    if residual_connections and dense_connections:
        raise ValueError(
            "Residual connections and Dense connections should not be used together."
        )

    if dense_connections and not pooling == "avg":
        raise ValueError(
            "According to Huang et al. a densely connected network should have average pooling."
        )

    if activation not in supported_activations:
        raise ValueError("The chosen activation is not supported.")

    if pooling not in supported_poolings:
        raise ValueError("The chosen pooling is not supported.")

    if upsampling not in supported_upsamplings:
        raise ValueError("The chosen upsampling is not supported.")

    if padding not in supported_paddings:
        raise ValueError("The chosen padding is not supported.")

    if (batch_normalization_on_input
            or batch_normalization) and (instance_normalization_on_input
                                         or instance_normalization):
        raise ValueError(
            "You have to choose between batch or instance normalization.")

    if relaxed_normalization_scheme and not (batch_normalization
                                             or instance_normalization):
        raise ValueError(
            "The relaxed normalization scheme can only be used if you also do (batch or instance) normalization."
        )

    # Calculate the field of view
    field_of_views = []
    input_sizes = []
    for p, (s_f_p_p, k_s_p_p) in enumerate(
            zip(subsample_factors_per_pathway, kernel_sizes_per_pathway)):
        field_of_view = np.ones(3, dtype=int)
        if input_interpolation == "mean":
            field_of_view *= np.array(s_f_p_p)

        for k_s in k_s_p_p:
            field_of_view += (np.array(k_s) - 1) * s_f_p_p

        for k_s in kernel_sizes_common_pathway:
            field_of_view += np.array(k_s) - 1

        field_of_views.append(list(field_of_view))
        input_sizes.append(list(field_of_view - 1 + output_size))

    output_size = list(output_size)
    if verbose:
        for p in range(nb_pathways):
            print(
                "\nfield of view for pathway {}:\t{}\t(theoretical (less meaningful if padding='same' and output size is small))"
                .format(p, field_of_views[p]))

        print("output size:\t{}\t(user defined)".format(output_size))
        for p in range(nb_pathways):
            print(
                "input size for pathway {}:\t{}\t(inferred with theoretical field of view (less meaningful if padding='same'))"
                .format(p, input_sizes[p]))

    # What are the possible input and output sizes?
    input_sizes = output_sizes = np.stack([np.arange(150)] * 3, axis=-1)
    for k_s in reversed(kernel_sizes_common_pathway):
        input_sizes = input_sizes + (np.array(k_s) -
                                     1 if padding == 'valid' else 0)

    input_sizes_per_pathway = []
    sizes_per_pathway_after_upsampling = []
    for p, (s_f_p_p, k_s_p_p) in enumerate(
            zip(subsample_factors_per_pathway, kernel_sizes_per_pathway)):
        sizes_p_before_upsampling = np.ceil(
            (input_sizes - (1 if upsampling == "linear" else 0)) /
            s_f_p_p) + (1 if upsampling == "linear" else 0)
        sizes_p_after_upsampling = (
            sizes_p_before_upsampling -
            (1 if upsampling == "linear" else 0)) * s_f_p_p + (
                1 if upsampling == "linear" else 0)
        input_sizes_p = sizes_p_before_upsampling
        for k_s in k_s_p_p:
            input_sizes_p = input_sizes_p + (np.array(k_s) -
                                             1 if padding == 'valid' else 0)

        input_sizes_per_pathway.append(input_sizes_p)
        sizes_per_pathway_after_upsampling.append(sizes_p_after_upsampling)
        if p > 0:
            output_sizes[(sizes_p_after_upsampling -
                          sizes_per_pathway_after_upsampling[0]) % 2 > 0] = 0

    possible_sizes_per_pathway_after_upsampling = [[
        list(sizes_per_pathway_after_upsampling[p][output_sizes[:, i] > 0,
                                                   i].astype(int))
        for i in range(3)
    ] for p in range(nb_pathways)]
    possible_input_sizes_per_pathway = [[
        list(input_sizes_per_pathway[p][output_sizes[:, i] > 0, i].astype(int))
        for i in range(3)
    ] for p in range(nb_pathways)]
    possible_output_sizes = [
        list(output_sizes[output_sizes[:, i] > 0, i].astype('int'))
        for i in range(3)
    ]
    if verbose and not all([
            o_s in p_o_s
            for o_s, p_o_s in zip(output_size, possible_output_sizes)
    ]):
        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        for p in range(nb_pathways):
            print(
                "\npossible input sizes for pathway {} (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
                .format(p, *possible_input_sizes_per_pathway[p]))

        raise ValueError(
            "The user defined output_size is not possible. Please choose from list above."
        )

    else:
        input_sizes = [[
            int(possible_input_sizes_per_pathway[p][i][
                possible_output_sizes[i].index(o_s)])
            for i, o_s in enumerate(output_size)
        ] for p in range(nb_pathways)]
        sizes_after_upsampling = [[
            int(possible_sizes_per_pathway_after_upsampling[p][i][
                possible_output_sizes[i].index(o_s)])
            for i, o_s in enumerate(output_size)
        ] for p in range(nb_pathways)]
        for p in range(nb_pathways):
            print(
                "input size for pathway {}:\t{}\t(true input size of the network)"
                .format(p, input_sizes[p]))

        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        for p in range(nb_pathways):
            print(
                "\npossible input sizes for pathway {} (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
                .format(p, *possible_input_sizes_per_pathway[p]))

    # Construct model
    inputs = []
    paths = []

    # 1. Construct parallel pathways
    for p in range(nb_pathways):
        input_ = path = Input(shape=tuple(input_sizes[p]) +
                              (number_input_features_per_pathway[p], ),
                              name="p{}_input".format(p))
        inputs.append(input_)
        for i, (k_s_p_p, n_f_p_p) in enumerate(
                zip(((), ) + kernel_sizes_per_pathway[p],
                    ((), ) + number_features_per_pathway[p])):
            if i != 0:
                path = Conv3D(filters=n_f_p_p,
                              kernel_size=k_s_p_p,
                              padding=padding,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=regularizers.l1_l2(
                                  l1_reg, l2_reg))(path)

            if dense_connections:
                if i % dense_connections != 0:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    path = Concatenate(axis=-1)([path, shortcut])

                if i + 1 != len(kernel_sizes_per_pathway[p]):
                    shortcut = path

            if residual_connections and i % residual_connections == 0:
                if i != 0:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    if K.int_shape(path)[-1] != K.int_shape(shortcut)[-1]:
                        shortcut = Conv3D(
                            filters=K.int_shape(path)[-1],
                            kernel_size=(1, 1, 1),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=regularizers.l1_l2(
                                l1_reg, l2_reg))(shortcut)

                    path = Add()([path, shortcut])

                if i + 1 != len(kernel_sizes_per_pathway[p]):
                    shortcut = path

            if not relaxed_normalization_scheme or i % relaxed_normalization_scheme == 0:
                if (i == 0 and
                    (batch_normalization_on_input
                     or instance_normalization_on_input)) or (
                         i != 0 and
                         (batch_normalization or instance_normalization)):
                    path = normalization_function()(path)

            if i != 0:
                path = activation_function("p{}_activation{}".format(p, i -
                                                                     1))(path)

        path = upsampling_function(subsample_factors_per_pathway[p])(path)
        if p > 0:
            path = Cropping3D([
                int((l - r) / 2) for l, r in zip(sizes_after_upsampling[p],
                                                 sizes_after_upsampling[0])
            ])(path)

        paths.append(path)

    # 2. Construct common pathway
    path = Concatenate(axis=-1)(paths) if len(paths) > 1 else paths[0]
    for i, (n_f_c_p, k_s_c_p, d_c_p) in enumerate(
            zip(number_features_common_pathway, kernel_sizes_common_pathway,
                dropout_common_pathway)):
        for j, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
            if m_a_c_p_l == i:
                path = introduce_metadata(path, j)

        if d_c_p:
            path = Dropout(d_c_p)(path)

        path = Conv3D(filters=n_f_c_p,
                      kernel_size=k_s_c_p,
                      activation=activation_final_layer if i +
                      1 == len(number_features_common_pathway) else None,
                      padding=padding,
                      kernel_initializer=kernel_initializer,
                      kernel_regularizer=regularizers.l1_l2(l1_reg,
                                                            l2_reg))(path)
        if i + 1 < len(number_features_common_pathway):
            if batch_normalization or instance_normalization:
                path = normalization_function()(path)

            path = activation_function("c_activation{}".format(i))(path)

    # 3. Mask the output (optionally)
    if mask_output:
        mask_input_ = mask_path = Input(shape=tuple(output_size) +
                                        (K.int_shape(path)[-1], ))
        inputs.append(mask_input_)
        path = Multiply()([path, mask_path])

    # 4. Optionally add an extra dimension, e.g. to correct for segment
    #    sampling changes to P(X|Y): the correction happens inside the loss
    #    function, with weights supplied by y_creator in the extra dimension
    #    (only possible for binary outputs this way).
    if add_extra_dimension:
        path = Reshape(K.int_shape(path)[1:] + (1, ))(path)

    model = Model(inputs=inputs, outputs=[path])

    # Final sanity check: were our calculations correct?
    if verbose:
        print("\nNetwork summary:")
        model.summary()
        model_input_shape = model.input_shape
        if not isinstance(model_input_shape, list):
            model_input_shape = [model_input_shape]

        for p in range(nb_pathways):
            assert list(model_input_shape[p][1:-1]) == input_sizes[p]

        print('With a batch size of {} this model needs {} GB on the GPU.'.
              format(1, get_model_memory_usage(1, model)))

    return model
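A worked instance of the theoretical field-of-view recursion used inside the builder above, assuming one pathway with subsample factor (3, 3, 3), five 3x3x3 pathway kernels, and three 1x1x1 common-pathway kernels:

import numpy as np

s_f_p_p = (3, 3, 3)
field_of_view = np.ones(3, dtype=int)
for k_s in ((3, 3, 3),) * 5:            # pathway kernels
    field_of_view += (np.array(k_s) - 1) * s_f_p_p
for k_s in ((1, 1, 1),) * 3:            # common-pathway kernels
    field_of_view += np.array(k_s) - 1
print(field_of_view)                    # [31 31 31]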
Example #13
def add_unet_layer(model,
                   network_config,
                   remaining_layers,
                   output_shape,
                   n_channels=None):
    if n_channels is None:
        n_channels = model.get_shape().as_list()[-1]

    downsample = np.array([
        x != 0 and remaining_layers % x == 0
        for x in network_config.unet_downsample_rate
    ])

    if network_config.convolution_padding == 'same':
        conv_contract = np.zeros(3, dtype=np.int32)
    else:
        conv_contract = network_config.convolution_dim - 1

    # First U convolution module.
    for i in range(network_config.num_layers_per_module):
        if i == network_config.num_layers_per_module - 1:
            # Increase the number of channels before downsampling to avoid
            # bottleneck (identical to 3D U-Net paper).
            n_channels = 2 * n_channels
        model = Conv3D(n_channels,
                       tuple(network_config.convolution_dim),
                       kernel_initializer=network_config.initialization,
                       activation=network_config.convolution_activation,
                       padding=network_config.convolution_padding)(model)
        if network_config.batch_normalization:
            model = BatchNormalization()(model)

    # Crop and pass forward to upsampling.
    if remaining_layers > 0:
        forward_link_shape = output_shape + network_config.num_layers_per_module * conv_contract
    else:
        forward_link_shape = output_shape
    contraction = (np.array(model.get_shape().as_list()[1:4]) -
                   forward_link_shape) // 2
    forward = Cropping3D(list(zip(list(contraction),
                                  list(contraction))))(model)
    if network_config.dropout_probability > 0.0:
        forward = Dropout(network_config.dropout_probability)(forward)

    # Terminal layer of the U.
    if remaining_layers <= 0:
        return forward

    # Downsample and recurse.
    model = Conv3D(n_channels,
                   tuple(network_config.convolution_dim),
                   strides=list(downsample + 1),
                   kernel_initializer=network_config.initialization,
                   activation=network_config.convolution_activation,
                   padding='same')(model)
    if network_config.batch_normalization:
        model = BatchNormalization()(model)
    next_output_shape = np.ceil(
        np.divide(forward_link_shape,
                  downsample.astype(np.float32) + 1.0)).astype(np.int32)
    model = add_unet_layer(model, network_config, remaining_layers - 1,
                           next_output_shape.astype(np.int32))

    # Upsample output of previous layer and merge with forward link.
    model = Conv3DTranspose(n_channels * 2,
                            tuple(network_config.convolution_dim),
                            strides=list(downsample + 1),
                            kernel_initializer=network_config.initialization,
                            activation=network_config.convolution_activation,
                            padding='same')(model)
    if network_config.batch_normalization:
        model = BatchNormalization()(model)
    # Must crop output because Keras wrongly pads the output shape for odd array sizes.
    stride_pad = (network_config.convolution_dim //
                  2) * np.array(downsample) + (1 -
                                               np.mod(forward_link_shape, 2))
    tf_pad_start = stride_pad // 2  # Tensorflow puts odd padding at end.
    model = Cropping3D(
        list(zip(list(tf_pad_start), list(stride_pad - tf_pad_start))))(model)

    model = concatenate([forward, model])

    # Second U convolution module.
    for _ in range(network_config.num_layers_per_module):
        model = Conv3D(n_channels,
                       tuple(network_config.convolution_dim),
                       kernel_initializer=network_config.initialization,
                       activation=network_config.convolution_activation,
                       padding=network_config.convolution_padding)(model)
        if network_config.batch_normalization:
            model = BatchNormalization()(model)

    return model
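A worked instance of the crop-after-transpose arithmetic above (assuming convolution_dim = (3, 3, 3), downsampling in every axis, and a mixed odd/even forward link shape):

import numpy as np

convolution_dim = np.array([3, 3, 3])
downsample = np.array([1, 1, 1])
forward_link_shape = np.array([17, 17, 16])
stride_pad = (convolution_dim // 2) * downsample + (1 - np.mod(forward_link_shape, 2))
tf_pad_start = stride_pad // 2           # TensorFlow puts the odd padding at the end
print(stride_pad, tf_pad_start)          # [1 1 2] [0 0 1]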
Example #14
def get_net():
    # Level 1
    input = Input((input_dim, input_dim, input_dim, 1))
    conv1 = Conv3D(32, (3, 3, 3), activation="relu", padding="same")(input)
    batch1 = BatchNormalization()(conv1)
    conv1 = Conv3D(64, (3, 3, 3), activation="relu", padding="same")(batch1)
    batch1 = BatchNormalization()(conv1)

    # Level 2
    pool2 = MaxPooling3D((2, 2, 2))(batch1)
    conv2 = Conv3D(64, (3, 3, 3), activation="relu", padding="same")(pool2)
    batch2 = BatchNormalization()(conv2)
    conv2 = Conv3D(128, (3, 3, 3), activation="relu", padding="same")(batch2)
    batch2 = BatchNormalization()(conv2)

    # Level 3
    pool3 = MaxPooling3D((2, 2, 2))(batch2)
    conv3 = Conv3D(128, (3, 3, 3), activation="relu", padding="same")(pool3)
    batch3 = BatchNormalization()(conv3)
    conv3 = Conv3D(256, (3, 3, 3), activation="relu", padding="same")(batch3)
    batch3 = BatchNormalization()(conv3)

    # Level 4
    pool4 = MaxPooling3D((2, 2, 2))(batch3)
    conv4 = Conv3D(256, (3, 3, 3), activation="relu", padding="same")(pool4)
    batch4 = BatchNormalization()(conv4)
    conv4 = Conv3D(512, (3, 3, 3), activation="relu", padding="same")(batch4)
    batch4 = BatchNormalization()(conv4)

    # Level 3
    up5 = Conv3DTranspose(512, (2, 2, 2), strides=(2, 2, 2), padding="same", activation="relu")(batch4)
    merge5 = concatenate([up5, batch3])
    conv5 = Conv3D(256, (3, 3, 3), activation="relu")(merge5)
    batch5 = BatchNormalization()(conv5)
    conv5 = Conv3D(256, (3, 3, 3), activation="relu")(batch5)
    batch5 = BatchNormalization()(conv5)

    # Level 2
    up6 = Conv3DTranspose(256, (2, 2, 2), strides=(2, 2, 2), activation="relu")(batch5)
    merge6 = concatenate([up6, Cropping3D(cropping=((4, 4), (4, 4), (4, 4)))(batch2)])
    conv6 = Conv3D(128, (3, 3, 3), activation="relu")(merge6)
    batch6 = BatchNormalization()(conv6)
    conv6 = Conv3D(128, (3, 3, 3), activation="relu")(batch6)
    batch6 = BatchNormalization()(conv6)

    # Level 1
    up7 = Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding="same", activation="relu")(batch6)
    merge7 = concatenate([up7, Cropping3D(cropping=((12, 12), (12, 12), (12, 12)))(batch1)])
    conv7 = Conv3D(64, (3, 3, 3), activation="relu")(merge7)
    batch7 = BatchNormalization()(conv7)
    conv7 = Conv3D(64, (3, 3, 3), activation="relu")(batch7)
    batch7 = BatchNormalization()(conv7)

    # Output dim is (36, 36, 36)
    preds = Conv3D(1, (1, 1, 1), activation="sigmoid")(batch7)
    model = Model(inputs=input, outputs=preds)

    # learning_rate replaces the deprecated lr argument; decay=0.0 was a no-op
    model.compile(optimizer=Adam(learning_rate=0.001), loss=weighted_binary_crossentropy,
                  metrics=[axon_precision, axon_recall, f1_score, artifact_precision,
                           edge_axon_precision, adjusted_accuracy])

    return model
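The fixed crops of (4, 4) and (12, 12) follow from the valid 3x3x3 convolutions on the decoder path; a worked per-axis trace for an input size of 64, ending at the (36, 36, 36) output noted above:

d = 64
l3 = d // 4        # 16: batch3 after two poolings
c5 = l3 - 4        # 12: up5 matches batch3, then two valid 3x3x3 convs
up6 = 2 * c5       # 24 = d // 2 - 8, hence the (4, 4) crop of batch2
c6 = up6 - 4       # 20
up7 = 2 * c6       # 40 = d - 24, hence the (12, 12) crop of batch1
print(up7 - 4)     # 36: two more valid convs give the final output size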
Example #15
hidden.append(Activation('relu')(hidden[-1]))

# hidden.append(tf.nn.batch_normalization(hidden[-1], bnm[len(bna)], bns[len(bna)], 0.0, 1.0, 0.000001))
# hidden.append((hidden[-1]-bnm[len(bna)])/bns[len(bna)])
# bna.append(hidden[-1])
# print('len bna',len(bna))
print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after conv conv bn')
print('...')

# up
for i in range(len(nFeatMapsList) - 1):
    nFeatMaps = nFeatMapsList[-i - 2]
    hidden.append(Conv3DTranspose(nFeatMaps, 3, strides=2, padding='same', activation='relu')(hidden[-1]))
    print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after upconv')
    toCrop = int((hidden[ccidx[-1 - i]].shape[1] - hidden[-1].shape[1]) // 2)
    hidden.append(concatenate([hidden[-1], Cropping3D(toCrop)(hidden[ccidx[-1 - i]])]))
    print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after concat with cropped layer %d' % ccidx[-1 - i])

#     hidden.append(Dropout(0.5)(hidden[-1], training=t))

    hidden.append(Conv3D(nFeatMaps, 3, padding='valid', activation=None)(hidden[-1]))
    hidden.append(tf.layers.batch_normalization(hidden[-1], training=t))
    hidden.append(Conv3D(nFeatMaps, 3, padding='valid', activation=None)(hidden[-1]))
    hidden.append(tf.layers.batch_normalization(hidden[-1], training=t))
    hidden.append(Activation('relu')(hidden[-1]))
    # hidden.append(tf.nn.batch_normalization(hidden[-1], bnm[len(bna)], bns[len(bna)], 0.0, 1.0, 0.000001))
    # hidden.append((hidden[-1]-bnm[len(bna)])/bns[len(bna)])
    # bna.append(hidden[-1])
    # print('len bna',len(bna))
    print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after conv conv bn')
    print('...')
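Cropping3D called with a single int, as above, removes that many voxels from both sides of every spatial axis; a quick check of the symmetric crop:

skip_dim, up_dim = 24, 20
toCrop = int((skip_dim - up_dim) // 2)   # 2: Cropping3D(2) removes 2 per side
print(skip_dim - 2 * toCrop)             # 20, matching the upconv output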
Example #16
def create_generalized_unet_v2_model(
        number_input_features=4,
        subsample_factors_per_pathway=((1, 1, 1), (2, 2, 2), (4, 4, 4),
                                       (8, 8, 8), (16, 16, 16)),
        kernel_sizes_per_pathway=(
            (((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
            (((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
            (((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
            (((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
            (((3, 3, 3), (3, 3, 3)), ()),
        ),
        number_features_per_pathway=(
            ((30, 30), (30, 30)),
            ((60, 60), (60, 30)),
            ((120, 120), (120, 60)),
            ((240, 240), (240, 120)),
            ((480, 240), ()),
        ),
        kernel_sizes_common_pathway=((1, 1, 1), ) * 1,
        number_features_common_pathway=(1, ),
        dropout_common_pathway=(0, ),
        output_size=(128, 128, 128),
        metadata_sizes=None,
        metadata_number_features=None,
        metadata_dropout=None,
        metadata_at_common_pathway_layer=None,
        padding='same',
        pooling='max',
        upsampling='copy',
        activation='lrelu',
        activation_final_layer='sigmoid',
        kernel_initializer='he_normal',
        batch_normalization=False,
        batch_normalization_on_input=False,
        instance_normalization=True,
        instance_normalization_on_input=False,
        relaxed_normalization_scheme=False,
        mask_output=False,
        residual_connections=False,
        dense_connections=False,
        add_extra_dimension=False,
        l1_reg=0.0,
        l2_reg=1e-5,
        verbose=True,
        input_interpolation="nearest",
        number_siam_pathways=1,
        extra_output_kernel_sizes=None,
        extra_output_number_features=None,
        extra_output_dropout=None,
        extra_output_at_common_pathway_layer=None,
        extra_output_activation_final_layer=None,
        dynamic_input_shapes=False):
    from tensorflow.keras.layers import Input, Dropout, MaxPooling3D, Concatenate, Multiply, Add, Reshape, AveragePooling3D, Conv3D, UpSampling3D, Cropping3D, LeakyReLU, PReLU, BatchNormalization, Conv3DTranspose
    from tensorflow.keras import regularizers
    from tensorflow.keras import backend as K
    from tensorflow_addons.layers import InstanceNormalization
    from tensorflow.keras import Model

    # Define some in-house functions
    def normalization_function():
        if batch_normalization_on_input or batch_normalization:
            normalization_function_ = BatchNormalization()

        elif instance_normalization_on_input or instance_normalization:
            normalization_function_ = InstanceNormalization()

        else:
            raise NotImplementedError

        return normalization_function_

    def introduce_metadata(path, i):
        if dynamic_input_shapes:
            metadata_input_ = metadata_path = Input(
                shape=(None, None, None, metadata_sizes[i][-1] if isinstance(
                    metadata_sizes[i], tuple) else metadata_sizes[i]))

        else:
            metadata_input_ = metadata_path = Input(
                shape=metadata_sizes[i] if isinstance(metadata_sizes[i], tuple)
                else (1, 1, 1, metadata_sizes[i]))

        metadata_inputs[i] = metadata_input_
        for j, (m_n_f, m_d) in enumerate(
                zip(metadata_number_features[i], metadata_dropout[i])):
            metadata_path = Dropout(m_d)(metadata_path)
            metadata_path = Conv3D(filters=m_n_f,
                                   kernel_size=(1, 1, 1),
                                   padding=padding,
                                   kernel_initializer=kernel_initializer,
                                   kernel_regularizer=regularizers.l1_l2(
                                       l1_reg, l2_reg))(metadata_path)
            metadata_path = activation_function("m{}_activation{}".format(
                i, j))(metadata_path)

        if not isinstance(metadata_sizes[i], tuple):
            broadcast_shape = K.concatenate([
                K.constant([1], dtype="int32"),
                K.shape(path)[1:-1],
                K.constant([1], dtype="int32")
            ])
            metadata_path = K.tile(metadata_path, broadcast_shape)

        return Concatenate(axis=-1)([path, metadata_path])

    def retrieve_extra_output(path, i):
        extra_output_path = path
        for j, (e_o_k_s, e_o_n_f, e_o_d) in enumerate(
                zip(extra_output_kernel_sizes[i],
                    extra_output_number_features[i], extra_output_dropout[i])):
            extra_output_path = Dropout(e_o_d)(extra_output_path)
            extra_output_path = Conv3D(
                filters=e_o_n_f,
                kernel_size=e_o_k_s,
                activation=extra_output_activation_final_layer[i] if j +
                1 == len(extra_output_number_features[i]) else None,
                padding=padding,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg),
                name=f"s{i + 1}" if j +
                1 == len(extra_output_number_features[i]) else
                None)(extra_output_path)
            if j + 1 < len(extra_output_number_features[i]):
                if (batch_normalization or instance_normalization
                    ) and not relaxed_normalization_scheme:
                    extra_output_path = normalization_function()(
                        extra_output_path)

                extra_output_path = activation_function(
                    "eo{}_activation{}".format(i, j))(extra_output_path)

        return extra_output_path

    def activation_function(name):
        if activation == "relu":
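            # LeakyReLU with alpha=0 is exactly ReLU, but as a layer object it
            # can carry the requested name.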
            activation_function_ = LeakyReLU(alpha=0, name=name)

        elif activation == "lrelu":
            activation_function_ = LeakyReLU(alpha=0.01, name=name)

        elif activation == "prelu":
            activation_function_ = PReLU(shared_axes=[1, 2, 3], name=name)

        elif activation == "linear":

            def activation_function_(path):
                return path

        else:
            raise NotImplementedError

        return activation_function_

    def pooling_function(pool_size):
        if pooling == "max":
            pooling_function_ = MaxPooling3D(pool_size, pool_size)

        elif pooling == "avg":
            pooling_function_ = AveragePooling3D(pool_size, pool_size)

        else:
            raise NotImplementedError

        return pooling_function_

    def upsampling_function(upsample_size):
        if upsampling == "copy":
            upsampling_function_ = UpSampling3D(upsample_size)

        elif upsampling == "linear":

            def upsampling_function_(path):
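                # Nearest-neighbour upsampling followed by a stride-1 average
                # pooling of the same window approximates linear interpolation
                # (and shrinks each spatial dimension by upsample_size - 1).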
                path = UpSampling3D(upsample_size)(path)
                path = AveragePooling3D(upsample_size,
                                        strides=(1, 1, 1),
                                        padding='valid')(path)
                return path

        elif upsampling == "conv":

            def upsampling_function_(path):
                path = Conv3DTranspose(
                    K.int_shape(path)[-1], upsample_size, upsample_size)(path)
                return path

        else:
            raise NotImplementedError

        return upsampling_function_

    # Define some in-house variables
    nb_pathways = len(subsample_factors_per_pathway)
    supported_activations = ["relu", "lrelu", "prelu", "linear"]
    supported_poolings = ["max", "avg"]
    supported_upsamplings = ["copy", "linear", "conv"]
    supported_paddings = ["valid", "same"]

    # Do some sanity checks
    if not len(kernel_sizes_per_pathway) == len(
            number_features_per_pathway) == nb_pathways:
        raise ValueError("Inconsistent number of pathways.")

    for p in range(nb_pathways - 1):
        if any(f_next % f
               for f, f_next in zip(subsample_factors_per_pathway[p],
                                    subsample_factors_per_pathway[p + 1])):
            raise ValueError(
                "Each subsample factor must be an integer multiple of the corresponding subsample factor of the previous pathway."
            )

    for k_s_p_p, n_f_p_p in zip(kernel_sizes_per_pathway,
                                number_features_per_pathway):
        if not len(k_s_p_p) == len(n_f_p_p) == 2:
            raise ValueError(
                "Each element in kernel_sizes_per_pathway and number_features_per_pathway must be a list of two elements, giving information about the left (downwards) and right (upwards) paths of the U-Net."
            )

        for k_s, n_f in zip(k_s_p_p, n_f_p_p):
            if not len(k_s) == len(n_f):
                raise ValueError(
                    "Each kernel size in each element from kernel_sizes_per_pathway must correspond with a number of features in each element of number_features_per_pathway."
                )

    if not len(kernel_sizes_common_pathway) == len(
            dropout_common_pathway) == len(number_features_common_pathway):
        raise ValueError("Inconsistent depth of common pathway.")

    if metadata_sizes in [None, []]:
        metadata_sizes = []
        if metadata_number_features in [None, []]:
            metadata_number_features = []

        else:
            raise ValueError(
                "Invalid value for metadata_number_features when there is no metadata"
            )

        if metadata_dropout in [None, []]:
            metadata_dropout = []

        else:
            raise ValueError(
                "Invalid value for metadata_dropout when there is no metadata")

        if metadata_at_common_pathway_layer in [None, []]:
            metadata_at_common_pathway_layer = []

        else:
            raise ValueError(
                "Invalid value for metadata_at_common_pathway_layer when there is no metadata"
            )

    else:
        if not len(metadata_sizes) == len(metadata_dropout) == len(
                metadata_number_features) == len(
                    metadata_at_common_pathway_layer):
            raise ValueError("Inconsistent depth of metadata pathway.")

    if extra_output_number_features in [None, [], ()]:
        extra_output_number_features = ()
        if extra_output_kernel_sizes in [None, [], ()]:
            extra_output_kernel_sizes = ()

        else:
            raise ValueError(
                "Invalid value for extra_output_kernel_sizes when there is no extra_output"
            )

        if extra_output_dropout in [None, [], ()]:
            extra_output_dropout = ()

        else:
            raise ValueError(
                "Invalid value for extra_output_dropout when there is no extra_output"
            )

        if extra_output_at_common_pathway_layer in [None, [], ()]:
            extra_output_at_common_pathway_layer = ()

        else:
            raise ValueError(
                "Invalid value for extra_output_at_common_pathway_layer when there is no extra_output"
            )

        if extra_output_activation_final_layer in [None, [], ()]:
            extra_output_activation_final_layer = ()

        else:
            raise ValueError(
                "Invalid value for extra_output_activation_final_layer when there is no extra_output"
            )

    if not len(extra_output_activation_final_layer) == len(
            extra_output_dropout) == len(extra_output_number_features) == len(
                extra_output_kernel_sizes) == len(
                    extra_output_at_common_pathway_layer):
        raise ValueError("Inconsistent depth of extra_output pathway.")

    if residual_connections and dense_connections:
        raise ValueError(
            "Residual and dense connections should not be used together.")

    if dynamic_input_shapes and (residual_connections or dense_connections):
        raise NotImplementedError(
            "Residual or dense connections are currently not supported when using dynamic input shapes!"
        )

    if dense_connections and pooling != "avg":
        raise ValueError(
            "According to Huang et al. a densely connected network should have average pooling."
        )

    if activation not in supported_activations:
        raise ValueError("The chosen activation is not supported.")

    if pooling not in supported_poolings:
        raise ValueError("The chosen pooling is not supported.")

    if upsampling not in supported_upsamplings:
        raise ValueError("The chosen upsampling is not supported.")

    if padding not in supported_paddings:
        raise ValueError("The chosen padding is not supported.")

    if (batch_normalization_on_input
            or batch_normalization) and (instance_normalization_on_input
                                         or instance_normalization):
        raise ValueError(
            "You have to choose between batch or instance normalization.")

    if relaxed_normalization_scheme and not (batch_normalization
                                             or instance_normalization):
        raise ValueError(
            "The relaxed normalization scheme can only be used if you also do (batch or instance) normalization."
        )

    # Calculate the field of view
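    # Each 'valid' convolution with kernel size k applied at subsample factor
    # s widens the theoretical field of view by (k - 1) * s, accumulated over
    # the downward path, the upward path and the common pathway.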
    field_of_view = np.ones(3, dtype=int)
    if input_interpolation == "mean":
        field_of_view *= np.array(subsample_factors_per_pathway[0])

    for s_f_p_p, k_s_p_p in zip(subsample_factors_per_pathway,
                                kernel_sizes_per_pathway):
        for k_s in k_s_p_p[0]:
            field_of_view += (np.array(k_s) - 1) * s_f_p_p

    for s_f_p_p, k_s_p_p in reversed(
            list(zip(subsample_factors_per_pathway,
                     kernel_sizes_per_pathway))):
        for k_s in k_s_p_p[1]:
            field_of_view += (np.array(k_s) - 1) * s_f_p_p

    for k_s in kernel_sizes_common_pathway:
        field_of_view += (np.array(k_s) - 1) * subsample_factors_per_pathway[0]

    input_size = list(field_of_view - 1 + output_size)
    field_of_view = list(field_of_view)
    output_size = list(output_size)
    if verbose:
        print("\nfield of view:\t{}\t(theoretical)".format(field_of_view))
        print("output size:\t{}\t(user defined)".format(output_size))
        print(
            "input size:\t{}\t(inferred with theoretical field of view (less meaningful if padding='same'))"
            .format(input_size))

    # What are the possible input and output sizes?
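    # Vectorized search: push candidate sizes 0..499 for each axis through the
    # pooling/convolution/upsampling arithmetic below; candidates that survive
    # with a positive output size are the valid input/output size pairs.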
    path_left_output_positions = []
    path_right_input_positions = []
    path_left_output_sizes = []
    path_right_input_sizes = []
    input_sizes = output_sizes = np.stack([np.arange(500)] * 3, axis=-1)
    prev_s_f_p_p = np.ones(3)
    prev_positions = np.zeros(3)
    for i, (s_f_p_p, k_s_p_p) in enumerate(
            zip(((1, 1, 1), ) + subsample_factors_per_pathway,
                (((), ()), ) + kernel_sizes_per_pathway)):
        s_f = s_f_p_p // prev_s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        prev_positions = np.abs(prev_positions - (1 - (s_f_p_p *
                                                       (s_f - 1) + 1) % 2))
        output_sizes = (output_sizes // s_f) * ((output_sizes % s_f) == 0)
        for k_s in k_s_p_p[0]:
            output_sizes -= np.array(k_s) - 1 if padding == 'valid' else 0
            prev_positions = np.abs(prev_positions -
                                    (1 - (s_f_p_p *
                                          (np.array(k_s) - 1) + 1) % 2))

        path_left_output_positions.append(prev_positions)
        path_left_output_sizes.append(output_sizes)

    prev_s_f_p_p = np.array(subsample_factors_per_pathway[-1])
    for i, (s_f_p_p, k_s_p_p) in reversed(
            list(
                enumerate(
                    zip(((1, 1, 1),
                         (1, 1, 1)) + subsample_factors_per_pathway[:-1],
                        (((), ()), ) + kernel_sizes_per_pathway)))):
        path_right_input_positions.append(prev_positions)
        path_right_input_sizes.append(output_sizes)
        output_sizes = output_sizes - np.abs(path_left_output_positions[i] -
                                             prev_positions)
        output_sizes *= ((path_left_output_sizes[i] - output_sizes) % 2) == 0
        prev_positions = path_left_output_positions[i]
        for k_s in k_s_p_p[1]:
            output_sizes -= (np.array(k_s) - 1) if padding == 'valid' else 0
            prev_positions = np.abs(prev_positions -
                                    (1 - (s_f_p_p *
                                          (np.array(k_s) - 1) + 1) % 2))

        s_f = prev_s_f_p_p // s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        output_sizes *= s_f
        output_sizes = output_sizes - s_f + 1 if upsampling == "linear" else output_sizes

    for k_s in kernel_sizes_common_pathway:
        output_sizes -= (np.array(k_s) - 1) if padding == 'valid' else 0

    path_right_input_positions.reverse()
    possible_input_sizes = [
        list(input_sizes[output_sizes[:, i] > 0, i]) for i in range(3)
    ]
    possible_output_sizes = [
        list(output_sizes[output_sizes[:, i] > 0, i].astype('int'))
        for i in range(3)
    ]
    possible_path_left_output_sizes, possible_path_right_input_sizes, possible_crops = [], [], []
    for i in range(3):
        possible_path_left_output_sizes_, possible_path_right_input_sizes_, possible_crops_ = [], [], []
        for j, o_s in enumerate(output_sizes[:, i]):
            if o_s > 0:
                possible_path_left_output_sizes_.append([
                    path_left_output_sizes_[j, i]
                    for path_left_output_sizes_ in path_left_output_sizes
                ])
                possible_path_right_input_sizes_.append([
                    path_right_input_sizes_[j, i] for path_right_input_sizes_
                    in reversed(path_right_input_sizes)
                ])
                possible_crops_.append([
                    int((pplos - ppris) / 2) for pplos, ppris in zip(
                        possible_path_left_output_sizes_[-1],
                        possible_path_right_input_sizes_[-1])
                ])

        possible_path_left_output_sizes.append(
            possible_path_left_output_sizes_)
        possible_path_right_input_sizes.append(
            possible_path_right_input_sizes_)
        possible_crops.append(possible_crops_)

    if not all(o_s in p_o_s
               for o_s, p_o_s in zip(output_size, possible_output_sizes)):
        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        print(
            "\npossible input sizes (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
            .format(*possible_input_sizes))
        raise ValueError(
            "The user defined output_size is not possible. Please choose from the list above."
        )

    # input_size and crops are needed below to build the model, so compute
    # them regardless of verbosity.
    input_size = [
        possible_input_sizes[i][possible_output_sizes[i].index(o_s)]
        for i, o_s in enumerate(output_size)
    ]
    crops = [
        possible_crops[i][possible_output_sizes[i].index(output_size[i])]
        for i in range(3)
    ]
    if verbose:
        print("input size:\t{}\t(true input size of the network)".format(
            input_size))
        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        print(
            "\npossible input sizes (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
            .format(*possible_input_sizes))
        indices_with_identical_cropping = []
        for i in range(3):
            indices_with_identical_cropping_ = []
            for j, crops_ in enumerate(possible_crops[i]):
                if all(crop == crop_
                       for crop, crop_ in zip(crops[i], crops_)):
                    indices_with_identical_cropping_.append(j)

            indices_with_identical_cropping.append(
                indices_with_identical_cropping_)

        print(
            "\npossible output sizes when using dynamic shapes (based on output size {}):\nx: {}\ny: {}\nz: {}"
            .format(
                output_size, *[[
                    possible_output_sizes[i][j]
                    for j in indices_with_identical_cropping[i]
                ] for i in range(3)]))
        print(
            "\npossible input sizes when using dynamic shapes (based on output size {}):\nx: {}\ny: {}\nz: {}"
            .format(
                output_size, *[[
                    possible_input_sizes[i][j]
                    for j in indices_with_identical_cropping[i]
                ] for i in range(3)]))

    # Construct model
    inputs = []
    paths = []
    metadata_inputs = [None] * len(metadata_at_common_pathway_layer)

    # 1. Construct U part
    if dynamic_input_shapes:
        input_ = path = Input(shape=(None, None, None, number_input_features),
                              name="siam0_input")

    else:
        input_ = path = Input(shape=list(input_size) + [number_input_features],
                              name="siam0_input")

    inputs.append(input_)
    path_left_output_paths = []
    path_right_output_paths = []
    # Downwards
    prev_s_f_p_p = np.ones(3)
    for i, (s_f_p_p, k_s_p_p, n_f_p_p) in enumerate(
            zip(((1, 1, 1), ) + subsample_factors_per_pathway,
                (((), ()), ) + kernel_sizes_per_pathway,
                (((), ()), ) + number_features_per_pathway)):
        s_f = s_f_p_p // prev_s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        path = pooling_function(s_f)(path)
        if i > 0:
            if i == 1 and (batch_normalization_on_input
                           or instance_normalization_on_input):
                path = normalization_function()(path)

            for j, (k_s, n_f) in enumerate(zip(k_s_p_p[0], n_f_p_p[0])):
                if j == 0 and (residual_connections or dense_connections):
                    shortcut = path

                elif (0 < j < len(k_s_p_p[0]) - 1) and dense_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    shortcut = Concatenate(axis=-1)([path, shortcut])
                    path = shortcut

                if ((i > 1 or j > 0) and
                    (batch_normalization or instance_normalization)) and (
                        not relaxed_normalization_scheme
                        or j == len(k_s_p_p[0]) - 1):
                    path = normalization_function()(path)

                if i > 1 or j > 0:
                    path = activation_function("down_p{}_activation{}".format(
                        i, j))(path)

                path = Conv3D(filters=n_f,
                              kernel_size=k_s,
                              padding=padding,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=regularizers.l1_l2(
                                  l1_reg, l2_reg),
                              name="{}_{}".format(i, j))(path)
                if j == len(k_s_p_p[0]) - 1 and residual_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
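                    # He et al.-style projection shortcut: match the number of
                    # feature maps with a 1x1x1 convolution before the Add.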
                    if K.int_shape(path)[-1] != K.int_shape(shortcut)[-1]:
                        shortcut = Conv3D(
                            filters=K.int_shape(path)[-1],
                            kernel_size=(1, 1, 1),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=regularizers.l1_l2(
                                l1_reg, l2_reg))(shortcut)

                    path = Add()([path, shortcut])
        path_left_output_paths.append(path)
        paths.append(path)

    # Metadata introduced at position "x" (directly after the U part)
    for i, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
        if m_a_c_p_l == "x":
            assert number_siam_pathways == 1, "Insertion of metadata at 'x' is currently not supported when using siamese pathways."
            path = introduce_metadata(path, i)

    # Upwards
    prev_s_f_p_p = np.array(subsample_factors_per_pathway[-1])
    for i, (s_f_p_p, k_s_p_p, n_f_p_p) in reversed(
            list(
                enumerate(
                    zip(((1, 1, 1),
                         (1, 1, 1)) + subsample_factors_per_pathway[:-1],
                        (((), ()), ) + kernel_sizes_per_pathway,
                        (((), ()), ) + number_features_per_pathway)))):
        path = AveragePooling3D(np.abs(path_left_output_positions[i] -
                                       path_right_input_positions[i]) + 1,
                                strides=(1, 1, 1))(path)
        if i > 0:
            if i < nb_pathways:
                path_left = Cropping3D([crops_[i] for crops_ in crops
                                        ])(path_left_output_paths[i])
                path = Concatenate(axis=-1)([path_left, path])

            for j, (k_s, n_f) in enumerate(zip(k_s_p_p[1], n_f_p_p[1])):
                if j == 0 and (residual_connections or dense_connections):
                    shortcut = path

                elif (0 < j < len(k_s_p_p[1]) - 1) and dense_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    shortcut = Concatenate(axis=-1)([path, shortcut])
                    path = shortcut

                if (batch_normalization or instance_normalization) and (
                        not relaxed_normalization_scheme
                        or j == len(k_s_p_p[1]) - 1):
                    path = normalization_function()(path)

                path = activation_function("up_p{}_activation{}".format(
                    i, j))(path)
                path = Conv3D(filters=n_f,
                              kernel_size=k_s,
                              padding=padding,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=regularizers.l1_l2(
                                  l1_reg, l2_reg))(path)
                if j == len(k_s_p_p[1]) - 1 and residual_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    if K.int_shape(path)[-1] != K.int_shape(shortcut)[-1]:
                        shortcut = Conv3D(
                            filters=K.int_shape(path)[-1],
                            kernel_size=(1, 1, 1),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=regularizers.l1_l2(
                                l1_reg, l2_reg))(shortcut)

                    path = Add()([path, shortcut])

        path_right_output_paths.append(path)
        s_f = prev_s_f_p_p // s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        path = upsampling_function(s_f)(path)
        paths.append(path)

    # SIAMESE networks
    if number_siam_pathways > 1:
        outputs = [path]
        siam_model = Model(inputs=inputs, outputs=path)
        for i in range(number_siam_pathways - 1):
            inputs_ = []
            if dynamic_input_shapes:
                input_ = Input(shape=(None, None, None,
                                      number_input_features),
                               name=f"siam{i + 1}_input")

            else:
                input_ = Input(shape=list(input_size) +
                               [number_input_features],
                               name=f"siam{i + 1}_input")

            inputs_.append(input_)
            inputs.append(input_)
            for j, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
                if m_a_c_p_l == "x":
                    meta_input_ = Input(shape=(1, 1, 1, metadata_sizes[j]),
                                        name=f"siam{i + 1}_meta{j + 1}")
                    inputs_.append(meta_input_)
                    inputs.append(meta_input_)

            outputs.append(
                siam_model(inputs_ if len(inputs_) > 1 else inputs_[0]))

        path = Concatenate(axis=-1)(outputs)

    outputs = [None] * len(extra_output_at_common_pathway_layer)
    # 2. Construct common pathway
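    # In the relaxed normalization scheme only the first layer of the common
    # pathway is preceded by a normalization layer; otherwise every layer is.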
    for i, (n_f_c_p, k_s_c_p, d_c_p) in enumerate(
            zip(number_features_common_pathway, kernel_sizes_common_pathway,
                dropout_common_pathway)):
        if (batch_normalization or instance_normalization) and (
                i == 0 or not relaxed_normalization_scheme):
            path = normalization_function()(path)

        path = activation_function("c_activation{}".format(i))(path)
        for j, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
            if m_a_c_p_l == i:
                path = introduce_metadata(path, j)

        for j, e_o_a_c_p_l in enumerate(extra_output_at_common_pathway_layer):
            if e_o_a_c_p_l == i:
                outputs[j] = retrieve_extra_output(path, j)

        if d_c_p:
            path = Dropout(d_c_p)(path)

        path = Conv3D(filters=n_f_c_p,
                      kernel_size=k_s_c_p,
                      activation=activation_final_layer if i +
                      1 == len(number_features_common_pathway) else None,
                      padding=padding,
                      kernel_initializer=kernel_initializer,
                      kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg),
                      name="s0" if i +
                      1 == len(number_features_common_pathway) else None)(path)

    inputs = inputs + metadata_inputs
    outputs.insert(0, path)
    # 3. Mask the output (optionally)
    if mask_output:
        if dynamic_input_shapes:
            mask_input_ = mask_path = Input(shape=(None, None, None,
                                                   K.int_shape(path)[-1]))

        else:
            mask_input_ = mask_path = Input(shape=tuple(output_size) +
                                            (K.int_shape(path)[-1], ))

        inputs.append(mask_input_)
        for i, output in enumerate(outputs):
            outputs[i] = Multiply()([output, mask_path])

    # 4. Optionally correct for segment sampling changes to P(X|Y). This adds
    #    an extra dimension, because the correction is done inside the loss
    #    function and the weights are given by y_creator in the extra
    #    dimension (this only works for binary outputs).
    if add_extra_dimension:
        for i, output in enumerate(outputs):
            outputs[i] = K.expand_dims(output, -1)

    model = Model(inputs=inputs, outputs=outputs)

    # Final sanity check: were our calculations correct?
    if verbose:
        print("\nNetwork summary:")
        model.summary()
        model_input_shape = model.input_shape
        if not isinstance(model_input_shape, list):
            model_input_shape = [model_input_shape]

        if not dynamic_input_shapes:
            assert list(model_input_shape[0][1:-1]) == input_size
            print('With a batch size of {} this model needs {} GB on the GPU.'.
                  format(1, get_model_memory_usage(1, model)))

        else:
            print(
                "Since you are using dynamic_input_shapes=True we cannot calculate the memory usage, nor can we truly check the correctness of the shapes..."
            )

    return model
Example no. 17

# TF1-style snippet (graph mode); SAAM is a project-specific attention module
# assumed to be defined elsewhere.
import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import Conv3D, Conv3DTranspose, Cropping3D
def model(x, N_svd=0):
    up_scale = 4
    with tf.variable_scope('ASR'):
        input_shape = x.get_shape().as_list()
        chn_in = input_shape[4]
        chn_base = 6 * up_scale
        # shape is [batch, 6, 24, 64, 1]

        # Group 1
        h = Conv3D(filters=chn_base,
                   kernel_size=(3, 1, 3),
                   activation='relu',
                   padding='SAME',
                   name='conv1_1')(x)
        h = Conv3D(filters=chn_base,
                   kernel_size=(3, 3, 1),
                   activation='relu',
                   padding='SAME',
                   name='conv1_2')(h)
        h1 = Conv3DTranspose(chn_base, (7, 1, 3), (up_scale, 1, 1),
                             'SAME',
                             activation='relu',
                             name='deconv1')(h)
        h1 = Cropping3D(cropping=((0, 3), (0, 0), (0, 0)))(h1)
        h1 = Conv3D(filters=chn_base,
                    kernel_size=(1, 1, 1),
                    activation='relu',
                    padding='SAME',
                    name='conv1_3')(h1)
        # shape is [batch, 6, 24, 64, chn_base]
        h = Conv3D(filters=chn_base * 2,
                   kernel_size=(1, 3, 3),
                   strides=(1, 2, 2),
                   activation='relu',
                   padding='SAME',
                   name='conv1_4')(h)
        # shape is [batch, 6, 12, 32, chn_base * 2]

        # Group 2
        h = Conv3D(filters=chn_base * 2,
                   kernel_size=(3, 1, 3),
                   activation='relu',
                   padding='SAME',
                   name='conv2_1')(h)
        h = Conv3D(filters=chn_base * 2,
                   kernel_size=(3, 3, 1),
                   activation='relu',
                   padding='SAME',
                   name='conv2_2')(h)
        h2 = Conv3DTranspose(chn_base, (7, 1, 3), (up_scale, 1, 1),
                             'SAME',
                             activation='relu',
                             name='deconv2')(h)
        h2 = Cropping3D(cropping=((0, 3), (0, 0), (0, 0)))(h2)
        h2 = Conv3D(filters=chn_base * 2,
                    kernel_size=(1, 1, 1),
                    activation='relu',
                    padding='SAME',
                    name='conv2_3')(h2)
        # shape is [batch, 6, 12, 32, chn_base * 2]
        h = Conv3D(filters=chn_base * 4,
                   kernel_size=(1, 1, 3),
                   strides=(1, 1, 2),
                   activation='relu',
                   padding='SAME',
                   name='conv2_4')(h)
        # shape is [batch, 6, 12, 16, chn_base * 4]

        # Layer 3, shrinking
        h = Conv3D(filters=chn_base * 2,
                   kernel_size=(1, 1, 1),
                   activation='relu',
                   padding='SAME',
                   name='conv3')(h)
        # shape is [batch, 6, 12, 16, chn_base * 2]

        # Group 4, Mapping
        for i in range(2):
            h = Conv3D(filters=chn_base * 2,
                       kernel_size=(3, 1, 3),
                       activation='relu',
                       padding='SAME',
                       name='conv4_1_' + str(i))(h)
            h = Conv3D(filters=chn_base * 2,
                       kernel_size=(3, 3, 1),
                       activation='relu',
                       padding='SAME',
                       name='conv4_2_' + str(i))(h)
        # shape is [batch, 6, 12, 16, chn_base * 2]

        # Group 5, Attention
        h = SAAM(h, up_scale=up_scale, N_svd=N_svd)

        # Layer 6, Expanding
        h = Conv3D(filters=chn_base * 4,
                   kernel_size=(1, 1, 1),
                   activation='relu',
                   padding='SAME',
                   name='conv6')(h)
        # shape is [batch, 6, 12, 16, chn_base * 4]

        # Group 7
        h = Conv3DTranspose(chn_base * 2, (1, 1, 4), (1, 1, 2),
                            activation='relu',
                            padding='SAME',
                            name='conv7_1')(h)
        # shape is [batch, 16, 12, 32, chn_base * 2]
        h = tf.concat([h, h2], axis=-1)
        h = Conv3D(filters=chn_base * 2,
                   kernel_size=(3, 1, 3),
                   activation='relu',
                   padding='SAME',
                   name='conv7_2')(h)
        h = Conv3D(filters=chn_base * 2,
                   kernel_size=(3, 3, 1),
                   activation='relu',
                   padding='SAME',
                   name='conv7_3')(h)
        # shape is [batch, 16, 12, 32, chn_base * 2]

        # Group 8
        h = Conv3DTranspose(chn_base, (1, 4, 4), (1, 2, 2),
                            activation='relu',
                            padding='SAME',
                            name='conv8_1')(h)
        # shape is [batch, 16, 24, 64, chn_base]
        h = tf.concat([h, h1], axis=-1)
        h = Conv3D(filters=chn_base,
                   kernel_size=(3, 1, 3),
                   activation='relu',
                   padding='SAME',
                   name='conv8_2')(h)
        h = Conv3D(filters=chn_base,
                   kernel_size=(3, 3, 1),
                   activation='relu',
                   padding='SAME',
                   name='conv8_3')(h)
        # shape is [batch, 16, 24, 64, chn_base]

        # Group 9
        h = Conv3D(filters=chn_in,
                   kernel_size=(3, 3, 3),
                   padding='SAME',
                   name='conv9')(h)
    return h
Example no. 18

import os

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (InputLayer, Conv3D, Conv3DTranspose,
                                     Dropout, BatchNormalization, Cropping3D)
from tensorflow.keras.optimizers import Adam

import config  # project-specific module with NUM_CHANNELS, IMG_HEIGHT, IMG_WIDTH
def conv3d(filters,
           latentDim,
           path,
           batch=False,
           dropout=False,
           filter_size=3):

    checkpoint_dir = os.path.dirname(path)
    model = Sequential()
    # encoder
    # input shape: (NUM_CHANNELS, IMG_HEIGHT, IMG_WIDTH, 1)
    model.add(
        InputLayer(input_shape=(config.NUM_CHANNELS, config.IMG_HEIGHT,
                                config.IMG_WIDTH, 1)))

    for f in filters:
        model.add(
            Conv3D(f,
                   filter_size,
                   strides=2,
                   activation='relu',
                   padding="same"))
        if dropout:
            model.add(Dropout(0.2))
        if batch:
            model.add(BatchNormalization())
        # model.add(MaxPool3D(pool_size=(1,2,2)))
    if latentDim is not None:
        # model.add(Flatten())
        model.add(
            Conv3D(latentDim, 1, strides=1, activation='relu', padding="same"))
        if batch:
            model.add(BatchNormalization())
    # model.add(Flatten())
    # model.add(Dense(latentDim))
    for f in reversed(filters):
        # apply a CONV_TRANSPOSE => RELU => BN operation
        model.add(
            Conv3DTranspose(f,
                            filter_size,
                            activation='relu',
                            strides=2,
                            padding="same"))
        if dropout:
            model.add(Dropout(0.2))
        if batch:
            model.add(BatchNormalization())

    model.add(Conv3D(1, filter_size, activation='sigmoid', padding='same'))
    if config.NUM_CHANNELS % (2**len(filters)) != 0:
        # Each stride-2 stage rounds an odd size up, so after decoding the
        # depth can overshoot NUM_CHANNELS; crop the surplus evenly.
        dim = config.NUM_CHANNELS
        for i in range(len(filters)):
            dim = (dim + 1) // 2
        croppingFactor = int(
            (dim * (2**len(filters)) - config.NUM_CHANNELS) / 2)
        model.add(
            Cropping3D(cropping=((croppingFactor, croppingFactor), (0, 0),
                                 (0, 0))))

    model.summary()

    model.compile(loss='mean_squared_error', optimizer=Adam())
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=path,
        # monitor='val_loss',
        save_weights_only=True,
        # save_best_only=True,
        verbose=1,
        save_freq='epoch')

    latest = tf.train.latest_checkpoint(checkpoint_dir)
    print('latest checkpoint:', latest)
    if latest is not None:
        model.load_weights(latest)
        print('weights loaded')

    return model, cp_callback
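# A hypothetical usage sketch (assumes a config module with NUM_CHANNELS=6,
# IMG_HEIGHT=64 and IMG_WIDTH=64, and a writable checkpoint path):
#
#     autoencoder, ckpt_cb = conv3d(filters=[16, 32], latentDim=8,
#                                   path='checkpoints/ae.ckpt', batch=True)
#     autoencoder.fit(x_train, x_train, epochs=10, callbacks=[ckpt_cb])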
Example no. 19
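# Fragment of a larger 3D U-Net construction script: `hidden` is the running
# list of layer outputs, `ccidx` holds the indices of the encoder layers used
# for skip connections, `nFeatMapsList` the feature-map counts per level, and
# `t` is the batch-normalization training flag (all defined in the elided
# part of the script).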
print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after conv conv bn')
print('...')

# up
for i in range(len(nFeatMapsList) - 1):
    nFeatMaps = nFeatMapsList[-i - 2]
    hidden.append(
        Conv3DTranspose(nFeatMaps, 3,
                        strides=2,
                        padding='same',
                        activation='relu')(hidden[-1]))
    print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after upconv')
    toCrop = int((hidden[ccidx[-1 - i]].shape[1] - hidden[-1].shape[1]) // 2)
    hidden.append(
        concatenate([hidden[-1],
                     Cropping3D(toCrop)(hidden[ccidx[-1 - i]])]))
    print('layer',
          len(hidden) - 1, ':', hidden[-1].shape,
          'after concat with cropped layer %d' % ccidx[-1 - i])

    #     hidden.append(Dropout(0.5)(hidden[-1], training=t))

    hidden.append(
        Conv3D(nFeatMaps, 3, padding='valid', activation=None)(hidden[-1]))
    hidden.append(
        tf.compat.v1.layers.batch_normalization(hidden[-1], training=t))
    hidden.append(
        Conv3D(nFeatMaps, 3, padding='valid', activation=None)(hidden[-1]))
    hidden.append(
        tf.compat.v1.layers.batch_normalization(hidden[-1], training=t))
    hidden.append(Activation('relu')(hidden[-1]))
Example no. 20

import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (Input, Conv3D, Cropping3D, PReLU,
                                     Concatenate, Add, ReLU)
def apetnet_vv5_onnx(input_tensor=None,
                     n_ind_layers=1,
                     n_common_layers=7,
                     n_kernels_ind=15,
                     n_kernels_common=30,
                     kernel_shape=(3, 3, 3),
                     add_final_relu=False,
                     debug=False):
    """ Stacked single channel version of apetnet
        
        For description of input parameters see apetnet

        The input_tensor argument is only used determine the input shape.
        If None the input shape us set to (32,16,16,1).
    """
    # define input (stacked PET and MRI image)
    if input_tensor is not None:
        ipt = Input(input_tensor.shape[1:5], name='input')
    else:
        ipt = Input(shape=(32, 16, 16, 1), name='input')

    # extract pet and mri image
    # - first image in order is pet
    ipt_dim_crop = int(ipt.shape[1] // 2)
    mri_image = Cropping3D(cropping=((ipt_dim_crop, 0), (0, 0), (0, 0)),
                           name='extract_mri')(ipt)
    pet_image = Cropping3D(cropping=((0, ipt_dim_crop), (0, 0), (0, 0)),
                           name='extract_pet')(ipt)

    # create the full model
    if not debug:
        # individual paths
        if n_ind_layers > 0:
            init_val_ind = RandomNormal(
                mean=0.0,
                stddev=np.sqrt(2 / (np.prod(kernel_shape) * n_kernels_ind)))

            pet_image_ind = pet_image
            mri_image_ind = mri_image

            for i in range(n_ind_layers):
                pet_image_ind = Conv3D(
                    n_kernels_ind,
                    kernel_shape,
                    padding='same',
                    name='conv3d_pet_ind_' + str(i),
                    kernel_initializer=init_val_ind)(pet_image_ind)
                pet_image_ind = PReLU(shared_axes=[1, 2, 3],
                                      name='prelu_pet_ind_' +
                                      str(i))(pet_image_ind)
                mri_image_ind = Conv3D(
                    n_kernels_ind,
                    kernel_shape,
                    padding='same',
                    name='conv3d_mri_ind_' + str(i),
                    kernel_initializer=init_val_ind)(mri_image_ind)
                mri_image_ind = PReLU(shared_axes=[1, 2, 3],
                                      name='prelu_mri_ind_' +
                                      str(i))(mri_image_ind)

            # concatenate inputs
            net = Concatenate(name='concat_0')([pet_image_ind, mri_image_ind])

        else:
            # concatenate inputs
            net = Concatenate(name='concat_0')([pet_image, mri_image])

        # common path
        init_val_common = RandomNormal(
            mean=0.0,
            stddev=np.sqrt(2 / (np.prod(kernel_shape) * n_kernels_common)))

        for i in range(n_common_layers):
            net = Conv3D(n_kernels_common,
                         kernel_shape,
                         padding='same',
                         name='conv3d_' + str(i),
                         kernel_initializer=init_val_common)(net)
            net = PReLU(shared_axes=[1, 2, 3], name='prelu_' + str(i))(net)

        # layers that adds all features
        net = Conv3D(1, (1, 1, 1),
                     padding='valid',
                     name='conv_final',
                     kernel_initializer=RandomNormal(mean=0.0,
                                                     stddev=np.sqrt(2)))(net)

        # add pet_image to prediction
        net = Add(name='add_0')([net, pet_image])

        # ensure that output is non-negative
        if add_final_relu:
            net = ReLU(name='final_relu')(net)

    # in debug mode only add up pet and mri image
    else:
        net = Concatenate(name='add_0')([pet_image, mri_image])

    # create model
    model = Model(inputs=ipt, outputs=net)

    # return the model
    return model
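
# A minimal smoke test (a sketch, assuming TensorFlow 2.x and the imports
# above): the default input stacks a 16-deep PET volume on top of a 16-deep
# MRI volume, so an all-zeros batch of shape (1, 32, 16, 16, 1) should yield
# an output of shape (1, 16, 16, 16, 1).
if __name__ == '__main__':
    model = apetnet_vv5_onnx()
    dummy = np.zeros((1, 32, 16, 16, 1), dtype='float32')
    print(model.predict(dummy).shape)  # expected: (1, 16, 16, 16, 1)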