def _pad_or_crop_to_shape_3D(x, in_shape, tgt_shape): ''' in_shape, tgt_shape are both 2x1 numpy arrays ''' im_diff = np.asarray(in_shape[:3]) - np.asarray(tgt_shape[:3]) if im_diff[0] < 0: pad_amt = (int(np.ceil(abs(im_diff[0]) / 2.0)), int(np.floor(abs(im_diff[0]) / 2.0))) x = ZeroPadding3D((pad_amt, (0, 0), (0, 0)))(x) if im_diff[1] < 0: pad_amt = (int(np.ceil(abs(im_diff[1]) / 2.0)), int(np.floor(abs(im_diff[1]) / 2.0))) x = ZeroPadding3D(((0, 0), pad_amt, (0, 0)))(x) if im_diff[2] < 0: pad_amt = (int(np.ceil(abs(im_diff[2]) / 2.0)), int(np.floor(abs(im_diff[2]) / 2.0))) x = ZeroPadding3D(((0, 0), (0, 0), pad_amt))(x) if im_diff[0] > 0: crop_amt = (int(np.ceil(im_diff[0] / 2.0)), int(np.floor(im_diff[0] / 2.0))) x = Cropping3D((crop_amt, (0, 0), (0, 0)))(x) if im_diff[1] > 0: crop_amt = (int(np.ceil(im_diff[1] / 2.0)), int(np.floor(im_diff[1] / 2.0))) x = Cropping3D(((0, 0), crop_amt, (0, 0)))(x) if im_diff[2] > 0: crop_amt = (int(np.ceil(im_diff[2] / 2.0)), int(np.floor(im_diff[2] / 2.0))) x = Cropping3D(((0, 0), (0, 0), crop_amt))(x) return x
def make_flood_fill_network(input_fov_shape, output_fov_shape, network_config):
    """Construct a stacked convolution module flood filling network.

    The model takes an image input and a mask input of shape
    ``input_fov_shape + (1,)``, runs them through an initial convolution and
    ``network_config.num_modules`` convolution modules, and crops the feature
    maps symmetrically so that the final ``mask_output`` has the spatial
    size ``output_fov_shape``.

    Args:
        input_fov_shape: Three spatial sizes of the input field of view.
        output_fov_shape: Three spatial sizes of the output field of view.
        network_config: Configuration object; this function reads
            ``convolution_padding``, ``rescale_image``, ``convolution_filters``,
            ``convolution_dim``, ``initialization``, ``convolution_activation``,
            ``batch_normalization``, ``num_modules`` and ``output_activation``.

    Returns:
        A Keras ``Model`` with inputs ``[image_input, mask_input]`` and output
        ``[mask_output]``.

    Raises:
        ValueError: If ``convolution_padding`` is not ``'same'`` or the output
            FOV is larger than the input FOV along any axis.
    """
    if network_config.convolution_padding != 'same':
        raise ValueError('ResNet implementation only supports same padding.')

    def _symmetric_cropping(amounts):
        # Cropping3D needs a sized sequence. A bare ``zip`` is an iterator
        # without ``__len__`` on Python 3 and is rejected by Keras, so build
        # a concrete tuple of plain-int pairs instead.
        return tuple((int(a), int(a)) for a in amounts)

    image_input = Input(shape=tuple(input_fov_shape) + (1, ),
                        dtype='float32',
                        name='image_input')
    if network_config.rescale_image:
        ffn = Lambda(lambda x: (x - 0.5) * 2.0)(image_input)
    else:
        ffn = image_input
    mask_input = Input(shape=tuple(input_fov_shape) + (1, ),
                       dtype='float32',
                       name='mask_input')
    ffn = concatenate([ffn, mask_input])

    # Convolve and activate before beginning the skip connection modules,
    # as discussed in the Appendix of He et al 2016.
    ffn = Conv3D(network_config.convolution_filters,
                 tuple(network_config.convolution_dim),
                 kernel_initializer=network_config.initialization,
                 activation=network_config.convolution_activation,
                 padding='same')(ffn)
    if network_config.batch_normalization:
        ffn = BatchNormalization()(ffn)

    # Total number of voxels to remove from each side, per axis.
    contraction = (np.asarray(input_fov_shape) -
                   np.asarray(output_fov_shape)) // 2
    if np.any(np.less(contraction, 0)):
        raise ValueError(
            'Output FOV shape can not be larger than input FOV shape.')

    # Spread the contraction over the modules; whatever has not been cropped
    # after the last module is removed in one final step below.
    contraction_cumu = np.zeros(3, dtype=np.int32)
    contraction_step = np.divide(contraction,
                                 float(network_config.num_modules))
    for i in range(0, network_config.num_modules):
        ffn = add_convolution_module(ffn, network_config)
        contraction_dims = np.floor(i * contraction_step -
                                    contraction_cumu).astype(np.int32)
        if np.count_nonzero(contraction_dims):
            ffn = Cropping3D(_symmetric_cropping(contraction_dims))(ffn)
            contraction_cumu += contraction_dims

    if np.any(np.less(contraction_cumu, contraction)):
        remainder = contraction - contraction_cumu
        ffn = Cropping3D(_symmetric_cropping(remainder))(ffn)

    mask_output = Conv3D(1,
                         tuple(network_config.convolution_dim),
                         kernel_initializer=network_config.initialization,
                         padding='same',
                         name='mask_output',
                         activation=network_config.output_activation)(ffn)
    ffn = Model(inputs=[image_input, mask_input], outputs=[mask_output])

    return ffn
def func(input_tensor, skip_tensor):
    """Upsample ``input_tensor``, merge it with the skip tensor and convolve.

    ``input_tensor`` is upsampled by a stride-2 transposed-convolution block,
    ``skip_tensor`` is cropped so that it matches the upsampled tensor on the
    three middle axes, both are concatenated, and one ``_bn_act_conv_3D``
    block is applied per entry of ``dilation_rates``.

    ``filters``, ``kernel_size``, ``strides``, ``dilation_rates``,
    ``activation``, ``batch_norm``, ``dropout`` and ``train`` come from the
    enclosing scope, which is not visible in this chunk.

    Returns:
        tuple: ``(last_output, outputs)`` where ``outputs`` lists the result
        of every convolution block in order.
    """
    upsampled = _bn_act_deconv_3D(filters=filters,
                                  kernel_size=(2, 2, 2),
                                  strides=(2, 2, 2))(input_tensor)

    # How much larger the skip tensor is than the upsampled one, per axis.
    _, up_h, up_w, up_d, _ = K.int_shape(upsampled)
    _, skip_h, skip_w, skip_d, _ = K.int_shape(skip_tensor)
    surplus = (skip_h - up_h, skip_w - up_w, skip_d - up_d)
    for extra in surplus:
        assert extra >= 0

    if any(surplus):
        # Odd surplus: the extra voxel is removed from the trailing side.
        cropping = tuple((extra // 2, extra - extra // 2) for extra in surplus)
        skip = Cropping3D(cropping=cropping)(skip_tensor)
    else:
        skip = skip_tensor

    merged = Concatenate()([upsampled, skip])

    outputs = []
    for rate in dilation_rates:
        merged = _bn_act_conv_3D(filters=filters,
                                 kernel_size=kernel_size,
                                 strides=strides,
                                 dilation_rate=rate,
                                 activation=activation,
                                 batch_norm=batch_norm,
                                 dropout=dropout,
                                 train=train)(merged)
        outputs.append(merged)

    return merged, outputs
def deconv3d(layer_input, skip_input, filters, f_size=4, dropout_rate='self'):
    """Layers used during upsampling.

    Upsamples ``layer_input`` by 2, convolves it, optionally applies dropout,
    normalises (AdaIN when ``self.adain`` is set, batch normalisation
    otherwise), applies ReLU, crops the result with the amounts returned by
    ``get_crop_shape`` and concatenates it with ``skip_input``.

    ``self``, ``w``, ``adain`` and ``get_crop_shape`` are resolved from the
    enclosing scope, which is not visible in this chunk.

    Args:
        layer_input: Tensor to upsample.
        skip_input: Tensor concatenated after the cropped upsampled tensor.
        filters: Number of convolution filters (and AdaIN scale/shift units).
        f_size: Convolution kernel size.
        dropout_rate: Dropout rate; the default ``'self'`` resolves to 0.5
            when ``self.dropout`` is truthy and to no dropout otherwise.

    Returns:
        The concatenated tensor.
    """
    upsampled = UpSampling3D(size=2)(layer_input)
    upsampled = Conv3D(filters, kernel_size=f_size, strides=1,
                       padding='same')(upsampled)

    if dropout_rate == 'self':
        dropout_rate = 0.5 if self.dropout else False
    if dropout_rate:
        upsampled = Dropout(dropout_rate)(upsampled)

    if self.adain:
        # Scale and shift are predicted from ``w`` by two dense layers.
        scale = Dense(filters, bias_initializer='ones')(w)
        shift = Dense(filters)(w)
        upsampled = Lambda(adain)([upsampled, scale, shift])
    else:
        upsampled = BatchNormalization(momentum=0.8)(upsampled)

    # NOTE(review): the original indentation was lost; ReLU is assumed to
    # follow both normalisation branches - confirm against the source file.
    upsampled = ReLU()(upsampled)

    ch, cw, cd = get_crop_shape(upsampled, skip_input)
    cropped = Cropping3D(cropping=(ch, cw, cd),
                         data_format="channels_last")(upsampled)
    return Concatenate()([cropped, skip_input])
def crop_to_fit(inputs, outputs, n_layers=4):
    """Crop ``outputs`` by the padding computed for the size of ``inputs``.

    Axes 1-3 of the static shape of ``inputs`` are passed, together with
    ``n_layers``, to ``calc_fit_pad``; the three values it returns are used
    as the ``Cropping3D`` specification applied to ``outputs``.

    Args:
        inputs: Tensor whose static shape (axes 1, 2, 3) is measured.
        outputs: Tensor to crop.
        n_layers: Forwarded unchanged to ``calc_fit_pad``.

    Returns:
        The cropped ``outputs`` tensor.
    """
    shape = int_shape(inputs)
    w_pad, h_pad, d_pad = calc_fit_pad(shape[1], shape[2], shape[3], n_layers)
    return Cropping3D((w_pad, h_pad, d_pad))(outputs)
def SAAM(x, up_scale, N_svd):
    """Attention block followed by pixel shuffle and two 1-D style convolutions.

    The static shape of ``x`` is unpacked as ``(batch, ang, hei, wid, chn)``.
    Three 1x1x1 convolutions produce query/key/value features, which are
    transposed and flattened to ``(-1, ang * wid, channels)``. The attention
    map is the softmax of query x key^T; when ``N_svd > 0`` the query and key
    are first replaced by their rank-``N_svd`` truncated SVD factors. The
    attended values are added (scaled by a learned ``attn_sigma``) to a
    1x1x1 projection of ``x``, pixel-shuffled by ``up_scale``, cropped by
    ``up_scale - 1`` at the end of axis 1, and passed through two
    convolutions.

    Args:
        x: 5-D tensor with a fully unpackable static shape.
        up_scale: Upscaling factor used for the value channels, the pixel
            shuffle and the final crop.
        N_svd: Number of singular values to keep; ``<= 0`` disables the SVD
            path.

    Returns:
        The output tensor of the ``epi_6`` convolution.
    """
    _, ang, hei, wid, chn = x.get_shape().as_list()
    reduced = chn // 8
    value_chn = chn // 8 * up_scale

    def _project(n_out, name):
        # 1x1x1 conv -> swap axes 1 and 2 -> flatten to (-1, ang * wid, n_out)
        feat = Conv3D(n_out, (1, 1, 1), use_bias=False, padding='SAME',
                      name=name)(x)
        feat = tf.transpose(feat, [0, 2, 1, 3, 4])
        return tf.reshape(feat, [-1, ang * wid, n_out])

    def _truncated_svd(mat):
        # Keep only the first N_svd singular values / vectors.
        s, u, v = tf.svd(mat)
        s = tf.slice(s, [0, 0], [-1, N_svd])
        u = tf.slice(u, [0, 0, 0], [-1, -1, N_svd])
        v = tf.slice(v, [0, 0, 0], [-1, -1, N_svd])
        return s, u, v

    use_svd = N_svd > 0

    # \phi'_q
    h1 = _project(reduced, 'attn_epi_1')
    if use_svd:
        s1, u1, v1 = _truncated_svd(h1)

    # \phi'_k
    h2 = _project(reduced, 'attn_epi_2')
    if use_svd:
        s2, u2, v2 = _truncated_svd(h2)

    # \phi'_v
    h3 = _project(value_chn, 'attn_epi_3')

    # Map
    if use_svd:
        # u1 . diag(s1) . v1^T . v2 . diag(s2) . u2^T, grouped inside-out.
        core = tf.matmul(tf.matrix_diag(s1),
                         tf.matmul(v1, v2, transpose_a=True))
        core = tf.matmul(core, tf.matrix_diag(s2), transpose_b=True)
        attn_EPI = tf.matmul(tf.matmul(u1, core), u2, transpose_b=True)
    else:
        attn_EPI = tf.matmul(h1, h2, transpose_b=True)
    # NOTE(review): indentation was lost in the original; the softmax is
    # assumed to apply to both branches - confirm against the source file.
    attn_EPI = tf.nn.softmax(attn_EPI)

    # \phi'_a
    attended = tf.matmul(attn_EPI, h3)

    # \phi'_b
    attended = tf.reshape(attended, [-1, hei, ang, wid, value_chn])
    attended = tf.transpose(attended, [0, 2, 1, 3, 4])

    skip = Conv3D(value_chn, (1, 1, 1), use_bias=False, padding='SAME',
                  name='attn_epi_4')(x)
    sigma = beta_variable([1], name='attn_sigma')
    out = skip + sigma * attended

    # \phi_c
    out = pixel_shuffle(out, up_scale)
    out = Cropping3D(cropping=((0, up_scale - 1), (0, 0), (0, 0)))(out)

    out = Conv3D(filters=chn, kernel_size=(1, 1, 7), activation='relu',
                 padding='SAME', name='epi_5')(out)
    out = Conv3D(filters=chn, kernel_size=(7, 1, 1), activation='relu',
                 padding='SAME', name='epi_6')(out)

    return out
def crop_nodes_to_match(self, node1, node2):
    """
    If necessary, applies Cropping3D layer to node1 to match shape of node2

    Only the axes between the first and the last (``shape[1:-1]``) are
    compared. When they differ, ``node1`` is cropped by the per-axis
    difference, split evenly with any odd remainder taken from the trailing
    side, and the same crop amounts are accumulated into ``self.label_crop``.

    Args:
        node1: Tensor to crop; must expose ``get_shape().as_list()``.
        node2: Tensor providing the target shape.

    Returns:
        ``node1`` unchanged when the shapes already agree, otherwise the
        cropped tensor.
    """
    s1 = np.array(node1.get_shape().as_list())[1:-1]
    s2 = np.array(node2.get_shape().as_list())[1:-1]

    if not np.any(s1 != s2):
        return node1

    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # documented replacement and yields the same dtype.
    c = (s1 - s2).astype(int)
    cr = np.array([c // 2, c // 2]).T
    cr[:, 1] += c % 2  # odd remainder goes to the trailing side
    cropped_node1 = Cropping3D(cr)(node1)
    self.label_crop += cr
    return cropped_node1
def create_simplified_deepmedic_model(**kwargs):
    """Wrap the generalized DeepMedic model behind a single image input.

    The generalized model is built with ``input_interpolation="mean"``. One
    shared input, large enough for every pathway, is then created; for each
    pathway it is centre-cropped and average-pooled by that pathway's
    subsample factors before being fed to the generalized model. Any further
    inputs of the generalized model (those beyond the pathway inputs) are
    passed through unchanged.

    Args:
        **kwargs: Forwarded to ``create_generalized_deepmedic_model``.
            ``number_input_features_per_pathway`` and
            ``subsample_factors_per_pathway`` must be given explicitly, all
            entries of ``number_input_features_per_pathway`` must be equal,
            and ``input_interpolation``, if given, must be ``"mean"``.

    Returns:
        The wrapping Keras ``Model``.

    Raises:
        AssertionError: If the constraints on ``kwargs`` above are violated.
        KeyError: If one of the two required keys is missing.
    """
    from tensorflow.keras.layers import Input, AveragePooling3D, Cropping3D
    from tensorflow.keras import backend as K
    from tensorflow.keras import Model

    n_input_features = kwargs["number_input_features_per_pathway"]
    assert np.all([n_i_f_p_p == n_input_features[0]
                   for n_i_f_p_p in n_input_features])
    if "input_interpolation" in kwargs:
        assert kwargs["input_interpolation"] == "mean"
    else:
        kwargs["input_interpolation"] = "mean"

    model = create_generalized_deepmedic_model(**kwargs)
    subsample_factors = kwargs["subsample_factors_per_pathway"]

    # Smallest full-resolution input that covers every pathway's input once
    # that pathway's subsampling is undone.
    input_size = np.max(
        [[s_f_p_p[i] * K.int_shape(input_path)[i + 1] for i in range(3)]
         for s_f_p_p, input_path in zip(subsample_factors, model.inputs)],
        axis=0)
    # Named ``image_input`` rather than ``input`` to avoid shadowing the builtin.
    image_input = Input(tuple(input_size) + (n_input_features[0], ))

    paths = []
    for s_f_p_p, input_path in zip(subsample_factors, model.inputs):
        crop_size = [
            input_size[i] - s_f_p_p[i] * K.int_shape(input_path)[i + 1]
            for i in range(3)
        ]
        path = Cropping3D(
            tuple([(c_z // 2, c_z - c_z // 2) for c_z in crop_size]))(image_input)
        path = AveragePooling3D(tuple(s_f_p_p))(path)
        paths.append(path)

    # Inputs of the generalized model beyond the pathway inputs.
    extra_inputs = model.inputs[len(subsample_factors):]
    inputs = [image_input] + extra_inputs
    outputs = model(paths + extra_inputs)
    model = Model(inputs=inputs,
                  outputs=outputs if isinstance(outputs, list) else [outputs])

    print("\nNetwork summary of simplified DeepMedic model:")
    # ``summary()`` prints by itself and returns None, so wrapping it in
    # ``print`` only appended a stray "None" line.
    model.summary()
    return model
def func(input_tensor, skip_tensor):
    """Upsample ``input_tensor``, merge it with the skip tensor and convolve.

    ``input_tensor`` goes through a stride-2 ``_deconv_bn_act_3D`` block,
    ``skip_tensor`` is cropped to the upsampled tensor's size on the three
    middle axes, both are concatenated, and one ``_conv_bn_act_3D`` block is
    applied per entry of ``dilation_rates``.

    ``filters``, ``kernel_size``, ``strides``, ``dilation_rates``,
    ``activation``, ``batch_norm`` and ``dropout`` come from the enclosing
    scope, which is not visible in this chunk.

    Returns:
        tuple: ``(last_output, outputs)`` where ``outputs`` lists the result
        of every convolution block in order.
    """
    upsampled = _deconv_bn_act_3D(filters=filters,
                                  kernel_size=(2, 2, 2),
                                  strides=(2, 2, 2))(input_tensor)

    # Size excess of the skip tensor over the upsampled tensor, per axis.
    _, up_h, up_w, up_d, _ = K.int_shape(upsampled)
    _, skip_h, skip_w, skip_d, _ = K.int_shape(skip_tensor)
    excess = (skip_h - up_h, skip_w - up_w, skip_d - up_d)
    for amount in excess:
        assert amount >= 0

    if any(excess):
        # Odd excess: the extra voxel is removed from the trailing side.
        cropping = tuple((amount // 2, amount - amount // 2)
                         for amount in excess)
        skip = Cropping3D(cropping=cropping)(skip_tensor)
    else:
        skip = skip_tensor
    # A 2-D variant of this crop (Cropping2D) was left disabled in the
    # original because of the error: "could not create a view primitive
    # descriptor, in file tensorflow/core/kernels/mkl_slice_op.cc:300".

    merged = Concatenate()([upsampled, skip])

    outputs = []
    for rate in dilation_rates:
        merged = _conv_bn_act_3D(filters=filters,
                                 kernel_size=kernel_size,
                                 strides=strides,
                                 dilation_rate=rate,
                                 activation=activation,
                                 batch_norm=batch_norm,
                                 dropout=dropout)(merged)
        outputs.append(merged)

    return merged, outputs
def test_delete_channels_cropping3d(channel_index, data_format):
    """Run the flatten-3D layer test helper on a ``Cropping3D([2, 3, 2])`` layer."""
    cropping_layer = Cropping3D([2, 3, 2], data_format=data_format)
    layer_test_helper_flatten_3d(
        cropping_layer,
        channel_index,
        data_format=data_format,
    )
def bn_feature_net_3D(receptive_field=61,
                      n_frames=5,
                      input_shape=(5, 256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True,
                      temporal=None,
                      residual=False,
                      temporal_kernel_size=3):
    """Creates a 3D featurenet.

    Args:
        receptive_field (int): the receptive field of the neural network.
        n_frames (int): Number of frames.
        input_shape (tuple): Shape of the input tensor. Only used as given
            when ``dilated`` is set; otherwise it is rebuilt from
            ``n_frames``, ``receptive_field`` and ``n_channels``.
        n_features (int): Number of output features
        n_channels (int): number of input channels
        reg (float): L2 regularization factor
        n_conv_filters (int): number of convolutional filters
        n_dense_filters (int): number of dense filters
        VGG_mode (bool): Doubles ``n_conv_filters`` after every pooling step
        init (str): Method for initializing weights.
        norm_method (str): Normalization method passed to the
            ``ImageNormalization3D`` layer.
        location (bool): Whether to include a ``Location3D`` layer.
        dilated (bool): Whether to use dilated pooling (forces ``padding``).
        padding (bool): Whether to use padding.
        padding_mode (str): Type of padding, one of 'reflect' or 'zero'
        multires (bool): Enables multi-resolution mode
        include_top (bool): Whether to include the final softmax layer
        temporal (str): Type of temporal operation
        residual (bool): Whether to use temporal information as a residual
        temporal_kernel_size (int): size of 2D kernel used in temporal
            convolutions

    Returns:
        tensorflow.keras.Model: 3D FeatureNet

    Raises:
        ValueError: If ``temporal`` is not "conv", "lstm", "gru" or None.
    """
    # Every intermediate tensor is kept in ``net`` so that the
    # multiresolution mode can reach back to earlier layers by index.
    win = (receptive_field - 1) // 2
    win_z = (n_frames - 1) // 2

    if dilated:
        padding = True

    channels_first = K.image_data_format() == 'channels_first'
    if channels_first:
        channel_axis, time_axis, row_axis, col_axis = 1, 2, 3, 4
    else:
        channel_axis, time_axis, row_axis, col_axis = -1, 1, 2, 3

    if not dilated:
        if channels_first:
            input_shape = (n_channels, n_frames,
                           receptive_field, receptive_field)
        else:
            input_shape = (n_frames, receptive_field,
                           receptive_field, n_channels)

    net = [Input(shape=input_shape)]
    net.append(
        ImageNormalization3D(norm_method=norm_method,
                             filter_size=receptive_field)(net[-1]))

    if padding:
        pad_widths = (win_z, win, win)
        if padding_mode == 'reflect':
            net.append(ReflectionPadding3D(padding=pad_widths)(net[-1]))
        elif padding_mode == 'zero':
            net.append(ZeroPadding3D(padding=pad_widths)(net[-1]))

    if location:
        net.append(Location3D()(net[-1]))
        net.append(Concatenate(axis=channel_axis)([net[-2], net[-1]]))

    layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    # Conv/BN/ReLU blocks; every second block is followed by a pooling step.
    while rf_counter > 4:
        filter_size = 3 if rf_counter % 2 == 0 else 4
        net.append(
            Conv3D(n_conv_filters, (1, filter_size, filter_size),
                   dilation_rate=(1, d, d),
                   kernel_initializer=init,
                   padding='valid',
                   kernel_regularizer=l2(reg))(net[-1]))
        net.append(BatchNormalization(axis=channel_axis)(net[-1]))
        net.append(Activation('relu')(net[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 != 0:
            continue

        if dilated:
            net.append(
                DilatedMaxPool3D(dilation_rate=(1, d, d),
                                 pool_size=(1, 2, 2))(net[-1]))
            d *= 2
        else:
            net.append(MaxPool3D(pool_size=(1, 2, 2))(net[-1]))

        if VGG_mode:
            n_conv_filters *= 2

        rf_counter = rf_counter // 2

        if multires:
            layers_to_concat.append(len(net) - 1)

    if multires:

        def _split_crop(total):
            # (before, after) crop; an odd total puts the extra on the end.
            return (total // 2, total - total // 2)

        cropped = []
        for idx in layers_to_concat:
            source_shape = net[idx].get_shape().as_list()
            target_shape = net[-1].get_shape().as_list()

            row_crop = _split_crop(
                int(source_shape[row_axis] - target_shape[row_axis]))
            col_crop = _split_crop(
                int(source_shape[col_axis] - target_shape[col_axis]))
            cropping = ((0, 0), row_crop, col_crop)

            cropped.append(Cropping3D(cropping=cropping)(net[idx]))

        net.append(Concatenate(axis=channel_axis)(cropped))

    net.append(
        Conv3D(n_dense_filters, (1, rf_counter, rf_counter),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(net[-1]))
    net.append(BatchNormalization(axis=channel_axis)(net[-1]))
    net.append(Activation('relu')(net[-1]))

    net.append(
        Conv3D(n_dense_filters, (n_frames, 1, 1),
               dilation_rate=(1, d, d),
               kernel_initializer=init,
               padding='valid',
               kernel_regularizer=l2(reg))(net[-1]))
    net.append(BatchNormalization(axis=channel_axis)(net[-1]))
    feature = Activation('relu')(net[-1])

    def _merge_temporal_features(feature, mode='conv', residual=False,
                                 n_filters=256, n_frames=3, padding=True,
                                 temporal_kernel_size=3):
        # Apply the requested temporal operation, optionally as a residual,
        # then batch-normalise. ``mode=None`` returns the feature untouched.
        if mode is None:
            return feature

        mode = str(mode).lower()
        if mode == 'conv':
            merged = Conv3D(
                n_filters,
                (n_frames, temporal_kernel_size, temporal_kernel_size),
                kernel_initializer=init,
                padding='same',
                activation='relu',
                kernel_regularizer=l2(reg))(feature)
        elif mode in ('lstm', 'gru'):
            recurrent_layer = ConvLSTM2D if mode == 'lstm' else ConvGRU2D
            merged = recurrent_layer(filters=n_filters,
                                     kernel_size=temporal_kernel_size,
                                     padding='same',
                                     kernel_initializer=init,
                                     activation='relu',
                                     kernel_regularizer=l2(reg),
                                     return_sequences=True)(feature)
        else:
            raise ValueError(
                '`temporal` must be one of "conv", "lstm", "gru" or None')

        if residual is True:
            merged = Add()([feature, merged])
        return BatchNormalization(axis=channel_axis)(merged)

    net.append(
        _merge_temporal_features(feature,
                                 mode=temporal,
                                 residual=residual,
                                 n_filters=n_dense_filters,
                                 n_frames=n_frames,
                                 padding=padding,
                                 temporal_kernel_size=temporal_kernel_size))

    net.append(
        TensorProduct(n_dense_filters,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(net[-1]))
    net.append(BatchNormalization(axis=channel_axis)(net[-1]))
    net.append(Activation('relu')(net[-1]))

    net.append(
        TensorProduct(n_features,
                      kernel_initializer=init,
                      kernel_regularizer=l2(reg))(net[-1]))

    if not dilated:
        net.append(Flatten()(net[-1]))

    if include_top:
        net.append(Softmax(axis=channel_axis, dtype=K.floatx())(net[-1]))

    return Model(inputs=net[0], outputs=net[-1])
def create_generalized_deepmedic_model(
        number_input_features_per_pathway=(1, 1),
        subsample_factors_per_pathway=((1, 1, 1), (3, 3, 3)),
        kernel_sizes_per_pathway=(((3, 3, 1), ) * 5 + ((3, 3, 3), ) * 5, ((3, 3, 1), ) * 5 + ((3, 3, 3), ) * 5),
        number_features_per_pathway=((32, ) * 5 + (48, ) * 5, (32, ) * 5 + (48, ) * 5),
        kernel_sizes_common_pathway=((1, 1, 1), ) * 3,
        number_features_common_pathway=(150, 150, 1),
        dropout_common_pathway=(0, 0.5, 0.5),
        output_size=(22, 15, 9),
        metadata_sizes=None,
        metadata_number_features=None,
        metadata_dropout=None,
        metadata_at_common_pathway_layer=None,
        padding='valid',
        pooling='avg',  # Not used yet; set to 'avg' to allow use with dense_connection=<int>
        upsampling='copy',
        activation='prelu',
        activation_final_layer='sigmoid',
        kernel_initializer='he_normal',
        batch_normalization=False,
        batch_normalization_on_input=False,
        instance_normalization=False,
        instance_normalization_on_input=False,
        relaxed_normalization_scheme=False,  # Every <int> layers we request normalization
        mask_output=False,
        residual_connections=False,  # We group <int> layers into a residual block
        dense_connections=False,  # We group <int> layers into a densely connected block (<int> - 1 layers have dense connections and there is always 1 transition layer in between)
        add_extra_dimension=False,
        l1_reg=0.0,
        l2_reg=0.0,
        verbose=True,
        input_interpolation="nearest"):
    """Build a multi-pathway 3D convolutional Keras model (generalized DeepMedic).

    One convolutional pathway is built per entry of
    ``subsample_factors_per_pathway``. Each pathway has its own input, is
    upsampled by its subsample factors, cropped to the size of pathway 0 and
    concatenated with the others; a common pathway of further convolutions
    (with optional dropout and metadata inputs) follows. The required input
    size of every pathway is derived from ``output_size``.

    Args:
        number_input_features_per_pathway: Input channels, one per pathway.
        subsample_factors_per_pathway: Three odd factors per pathway.
        kernel_sizes_per_pathway: Per pathway, a tuple of 3-D kernel sizes
            (tuples are required: an empty tuple is prepended internally).
        number_features_per_pathway: Per pathway, filters per layer.
        kernel_sizes_common_pathway: Kernel sizes of the common pathway.
        number_features_common_pathway: Filters per common-pathway layer.
        dropout_common_pathway: Dropout rate before each common-pathway layer
            (falsy means no dropout).
        output_size: Requested spatial output size (three values).
        metadata_sizes, metadata_number_features, metadata_dropout,
        metadata_at_common_pathway_layer: Optional metadata inputs; all four
            must have the same length, or all be ``None``/empty.
        padding: 'valid' or 'same'.
        pooling: 'max' or 'avg' (only validated; see inline comment).
        upsampling: 'copy', 'linear' or 'conv'.
        activation: 'relu', 'lrelu', 'prelu' or 'linear'.
        activation_final_layer: Activation of the last common-pathway layer.
        kernel_initializer: Initializer passed to every ``Conv3D``.
        batch_normalization / instance_normalization (and ``*_on_input``):
            Normalisation switches; batch and instance cannot be mixed.
        relaxed_normalization_scheme: If set, normalise only every that many
            layers.
        mask_output: Adds a mask input multiplied onto the output.
        residual_connections / dense_connections: Block size for residual or
            dense shortcuts (mutually exclusive).
        add_extra_dimension: Appends a trailing axis of size 1 to the output.
        l1_reg, l2_reg: Kernel regularisation factors.
        verbose: Enables the informational prints guarded by it below.
        input_interpolation: When "mean", the field of view starts at the
            subsample factors instead of 1.

    Returns:
        The constructed Keras ``Model``.

    Raises:
        ValueError: On inconsistent or unsupported arguments, or when
            ``output_size`` cannot be produced by the architecture.
        NotImplementedError: From the inner factories on unsupported options.
    """
    from tensorflow.keras.layers import Input, Dropout, MaxPooling3D, Concatenate, Multiply, Add, Reshape, Conv3DTranspose, AveragePooling3D, Conv3D, UpSampling3D, Cropping3D, LeakyReLU, PReLU, BatchNormalization
    from tensorflow.keras import regularizers
    from tensorflow.keras import backend as K
    from tensorflow_addons.layers import InstanceNormalization
    from tensorflow.keras import Model

    # Define some in-house functions
    def normalization_function():
        # Returns a fresh normalisation layer; batch takes precedence.
        if batch_normalization_on_input or batch_normalization:
            normalization_function_ = BatchNormalization()
        elif instance_normalization_on_input or instance_normalization:
            normalization_function_ = InstanceNormalization()
        else:
            raise NotImplementedError
        return normalization_function_

    def introduce_metadata(path, i):
        # Adds metadata input ``i`` (registered in the outer ``inputs`` list),
        # runs it through 1x1x1 convolutions, tiles it to the size of
        # ``path`` and concatenates it on the channel axis.
        metadata_input_ = metadata_path = Input(shape=(1, 1, 1, metadata_sizes[i]))
        inputs.append(metadata_input_)
        for j, (m_n_f, m_d) in enumerate(
                zip(metadata_number_features[i], metadata_dropout[i])):
            if m_d:
                metadata_path = Dropout(m_d)(metadata_path)
            metadata_path = Conv3D(filters=m_n_f,
                                   kernel_size=(1, 1, 1),
                                   padding=padding,
                                   kernel_initializer=kernel_initializer,
                                   kernel_regularizer=regularizers.l1_l2(
                                       l1_reg, l2_reg))(metadata_path)
            metadata_path = activation_function("m{}_activation{}".format(
                i, j))(metadata_path)
        # When the metadata is inserted, every voxel most likely corresponds to an output.
        # Currently I cannot think of a good reason it wouldn't be the case, hence the hard assert.
        path_size = K.int_shape(path)[1:4]
        assert tuple(output_size) == tuple(path_size)
        metadata_path = UpSampling3D(tuple([int(s) for s in path_size
                                            ]))(metadata_path)
        return Concatenate(axis=-1)([path, metadata_path])

    def activation_function(name):
        # Returns a named activation layer, or an identity function for
        # "linear" (in which case ``name`` is unused).
        if activation == "relu":
            activation_function_ = LeakyReLU(alpha=0, name=name)
        elif activation == "lrelu":
            activation_function_ = LeakyReLU(alpha=0.01, name=name)
        elif activation == "prelu":
            activation_function_ = PReLU(shared_axes=[1, 2, 3], name=name)
        elif activation == "linear":

            def activation_function_(path):
                return path
        else:
            raise NotImplementedError
        return activation_function_

    def pooling_function(pool_size):
        # NOTE(review): not called anywhere in this function as visible here.
        if pooling == "max":
            pooling_function_ = MaxPooling3D(pool_size, pool_size)
        elif pooling == "avg":
            pooling_function_ = AveragePooling3D(pool_size, pool_size)
        else:
            raise NotImplementedError
        return pooling_function_

    def upsampling_function(upsample_size):
        # Returns a callable that upsamples a tensor by ``upsample_size``.
        if upsampling == "copy":
            upsampling_function_ = UpSampling3D(upsample_size)
        elif upsampling == "linear":

            def upsampling_function_(path):
                # Copy-upsample, then smooth with a stride-1 average pool.
                path = UpSampling3D(upsample_size)(path)
                path = AveragePooling3D(upsample_size,
                                        strides=(1, 1, 1),
                                        padding='valid')(path)
                return path
        elif upsampling == "conv":

            def upsampling_function_(path):
                # Transposed convolution that keeps the channel count.
                path = Conv3DTranspose(
                    K.int_shape(path)[-1], upsample_size, upsample_size)(path)
                return path
        else:
            raise NotImplementedError
        return upsampling_function_

    # Define some in-house variables
    nb_pathways = len(subsample_factors_per_pathway)
    supported_activations = ["relu", "lrelu", "prelu", "linear"]
    supported_poolings = ["max", "avg"]
    supported_upsamplings = ["copy", "linear", "conv"]
    supported_paddings = ["valid", "same"]

    # Do some sanity checks
    if not len(number_input_features_per_pathway) == len(
            kernel_sizes_per_pathway) == len(
                number_features_per_pathway) == nb_pathways:
        raise ValueError("Inconsistent number of pathways.")
    for p in range(nb_pathways):
        if not len(kernel_sizes_per_pathway[p]) == len(
                number_features_per_pathway[p]):
            raise ValueError("Inconsistent depth of pathway #{}.".format(p))
        for ssf in subsample_factors_per_pathway[p]:
            if ssf % 2 != 1:
                raise ValueError("Subsample factors must be odd.")
    # NOTE(review): this repeats the per-pathway depth check performed above.
    for k_s_p_p, n_f_p_p in zip(kernel_sizes_per_pathway,
                                number_features_per_pathway):
        if not len(k_s_p_p) == len(n_f_p_p):
            raise ValueError(
                "Each kernel size in each element from kernel_sizes_per_pathway must correspond with a number of features in each element of number_features_per_pathway."
            )
    if not len(kernel_sizes_common_pathway) == len(
            dropout_common_pathway) == len(number_features_common_pathway):
        raise ValueError("Inconsistent depth of common pathway.")
    # Without metadata, every metadata_* argument must be None or empty and
    # is normalised to an empty list.
    if metadata_sizes in [None, []]:
        metadata_sizes = []
        if metadata_number_features in [None, []]:
            metadata_number_features = []
        else:
            raise ValueError(
                "Invalid value for metadata_number_features when there is no metadata"
            )
        if metadata_dropout in [None, []]:
            metadata_dropout = []
        else:
            raise ValueError(
                "Invalid value for metadata_dropout when there is no metadata")
        if metadata_at_common_pathway_layer in [None, []]:
            metadata_at_common_pathway_layer = []
        else:
            raise ValueError(
                "Invalid value for metadata_at_common_pathway_layer when there is no metadata"
            )
    else:
        if not len(metadata_sizes) == len(metadata_dropout) == len(
                metadata_number_features) == len(
                    metadata_at_common_pathway_layer):
            raise ValueError("Inconsistent depth of metadata pathway.")
    if residual_connections and dense_connections:
        raise ValueError(
            "Residual connections and Dense connections should not be used together."
        )
    if dense_connections and not pooling == "avg":
        raise ValueError(
            "According to Huang et al. a densely connected network should have average pooling."
        )
    if activation not in supported_activations:
        raise ValueError("The chosen activation is not supported.")
    if pooling not in supported_poolings:
        raise ValueError("The chosen pooling is not supported.")
    if upsampling not in supported_upsamplings:
        raise ValueError("The chosen upsampling is not supported.")
    if padding not in supported_paddings:
        raise ValueError("The chosen padding is not supported.")
    if (batch_normalization_on_input
            or batch_normalization) and (instance_normalization_on_input
                                         or instance_normalization):
        raise ValueError(
            "You have to choose between batch or instance normalization.")
    if relaxed_normalization_scheme and not (batch_normalization
                                             or instance_normalization):
        raise ValueError(
            "The relaxed normalization scheme can only be used if you also do (batch or instance) normalization."
        )

    # Calculate the field of view
    field_of_views = []
    input_sizes = []
    for p, (s_f_p_p, k_s_p_p) in enumerate(
            zip(subsample_factors_per_pathway, kernel_sizes_per_pathway)):
        field_of_view = np.ones(3, dtype=int)
        if input_interpolation == "mean":
            field_of_view *= np.array(s_f_p_p)
        # Pathway kernels act on subsampled data, so their reach is scaled.
        for k_s in k_s_p_p:
            field_of_view += (np.array(k_s) - 1) * s_f_p_p
        for k_s in kernel_sizes_common_pathway:
            field_of_view += np.array(k_s) - 1
        field_of_views.append(list(field_of_view))
        input_sizes.append(list(field_of_view - 1 + output_size))
    output_size = list(output_size)
    if verbose:
        for p in range(nb_pathways):
            print(
                "\nfield of view for pathway {}:\t{}\t(theoretical (less meaningful if padding='same' and output size is small))"
                .format(p, field_of_views[p]))
        print("output size:\t{}\t(user defined)".format(output_size))
        for p in range(nb_pathways):
            print(
                "input size for pathway {}:\t{}\t(inferred with theoretical field of view (less meaningful if padding='same'))"
                .format(p, input_sizes[p]))

    # What are the possible input and output sizes?
    # Candidate output sizes 0..149 per axis; rows that cannot be realised
    # are zeroed out in ``output_sizes`` below.
    # NOTE(review): both names start out bound to the same array; they only
    # diverge once the loop below rebinds ``input_sizes`` - confirm this is
    # intended when kernel_sizes_common_pathway is empty.
    input_sizes = output_sizes = np.stack([np.arange(150)] * 3, axis=-1)
    for k_s in reversed(kernel_sizes_common_pathway):
        input_sizes = input_sizes + (np.array(k_s) -
                                     1 if padding == 'valid' else 0)
    input_sizes_per_pathway = []
    sizes_per_pathway_after_upsampling = []
    for p, (s_f_p_p, k_s_p_p) in enumerate(
            zip(subsample_factors_per_pathway, kernel_sizes_per_pathway)):
        sizes_p_before_upsampling = np.ceil(
            (input_sizes -
             (1 if upsampling == "linear" else 0)) / s_f_p_p) + (
                 1 if upsampling == "linear" else 0)
        sizes_p_after_upsampling = (
            sizes_p_before_upsampling -
            (1 if upsampling == "linear" else 0)) * s_f_p_p + (
                1 if upsampling == "linear" else 0)
        input_sizes_p = sizes_p_before_upsampling
        for k_s in k_s_p_p:
            input_sizes_p = input_sizes_p + (np.array(k_s) -
                                             1 if padding == 'valid' else 0)
        input_sizes_per_pathway.append(input_sizes_p)
        sizes_per_pathway_after_upsampling.append(sizes_p_after_upsampling)
        if p > 0:
            # The size difference to pathway 0 must be even so that it can be
            # cropped symmetrically; otherwise the candidate is discarded.
            output_sizes[(sizes_p_after_upsampling -
                          sizes_per_pathway_after_upsampling[0]) % 2 > 0] = 0
    possible_sizes_per_pathway_after_upsampling = [[
        list(sizes_per_pathway_after_upsampling[p][output_sizes[:, i] > 0,
                                                   i].astype(int))
        for i in range(3)
    ] for p in range(nb_pathways)]
    possible_input_sizes_per_pathway = [[
        list(input_sizes_per_pathway[p][output_sizes[:, i] > 0,
                                        i].astype(int)) for i in range(3)
    ] for p in range(nb_pathways)]
    possible_output_sizes = [
        list(output_sizes[output_sizes[:, i] > 0, i].astype('int'))
        for i in range(3)
    ]
    # NOTE(review): the explicit error is only raised when ``verbose`` is
    # true; with verbose=False an impossible output_size still fails, but via
    # the ``.index(o_s)`` lookups in the else-branch (ValueError from list).
    if verbose and not all([
            o_s in p_o_s
            for o_s, p_o_s in zip(output_size, possible_output_sizes)
    ]):
        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        for p in range(nb_pathways):
            print(
                "\npossible input sizes for pathway {} (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
                .format(p, *possible_input_sizes_per_pathway[p]))
        raise ValueError(
            "The user defined output_size is not possible. Please choose from list above."
        )
    else:
        # Look up the true per-pathway input size (and post-upsampling size)
        # that corresponds to the requested output size.
        input_sizes = [[
            int(possible_input_sizes_per_pathway[p][i][
                possible_output_sizes[i].index(o_s)])
            for i, o_s in enumerate(output_size)
        ] for p in range(nb_pathways)]
        sizes_after_upsampling = [[
            int(possible_sizes_per_pathway_after_upsampling[p][i][
                possible_output_sizes[i].index(o_s)])
            for i, o_s in enumerate(output_size)
        ] for p in range(nb_pathways)]
        # NOTE(review): these prints are not guarded by ``verbose``.
        for p in range(nb_pathways):
            print(
                "input size for pathway {}:\t{}\t(true input size of the network)"
                .format(p, input_sizes[p]))
        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        for p in range(nb_pathways):
            print(
                "\npossible input sizes for pathway {} (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
                .format(p, *possible_input_sizes_per_pathway[p]))

    # Construct model
    inputs = []
    paths = []

    # 1. Construct parallel pathways
    for p in range(nb_pathways):
        input_ = path = Input(shape=tuple(input_sizes[p]) +
                              (number_input_features_per_pathway[p], ),
                              name="p{}_input".format(p))
        inputs.append(input_)
        # Index 0 is a dummy entry standing for the raw input, so that
        # shortcuts and input normalisation can be handled in the same loop;
        # real convolution layers are i = 1 .. depth.
        for i, (k_s_p_p, n_f_p_p) in enumerate(
                zip(((), ) + kernel_sizes_per_pathway[p],
                    ((), ) + number_features_per_pathway[p])):
            if i != 0:
                path = Conv3D(filters=n_f_p_p,
                              kernel_size=k_s_p_p,
                              padding=padding,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=regularizers.l1_l2(
                                  l1_reg, l2_reg))(path)
            if dense_connections:
                if i % dense_connections != 0:
                    # Crop the stored shortcut to the current size and
                    # concatenate it on the channel axis.
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    path = Concatenate(axis=-1)([path, shortcut])
                if i + 1 != len(kernel_sizes_per_pathway[p]):
                    shortcut = path
            if residual_connections and i % residual_connections == 0:
                if i != 0:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    # Project the shortcut when channel counts differ.
                    if K.int_shape(path)[-1] != K.int_shape(shortcut)[-1]:
                        shortcut = Conv3D(
                            filters=K.int_shape(path)[-1],
                            kernel_size=(1, 1, 1),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=regularizers.l1_l2(
                                l1_reg, l2_reg))(shortcut)
                    path = Add()([path, shortcut])
                if i + 1 != len(kernel_sizes_per_pathway[p]):
                    shortcut = path
            if not relaxed_normalization_scheme or i % relaxed_normalization_scheme == 0:
                if (i == 0 and
                    (batch_normalization_on_input
                     or instance_normalization_on_input)) or (
                         i != 0 and
                         (batch_normalization or instance_normalization)):
                    path = normalization_function()(path)
            if i != 0:
                path = activation_function("p{}_activation{}".format(p, i -
                                                                     1))(path)
        path = upsampling_function(subsample_factors_per_pathway[p])(path)
        if p > 0:
            # Centre-crop to the post-upsampling size of pathway 0.
            path = Cropping3D([
                int((l - r) / 2) for l, r in zip(sizes_after_upsampling[p],
                                                 sizes_after_upsampling[0])
            ])(path)
        paths.append(path)

    # 2. Construct common pathway
    path = Concatenate(axis=-1)(paths) if len(paths) > 1 else paths[0]
    for i, (n_f_c_p, k_s_c_p, d_c_p) in enumerate(
            zip(number_features_common_pathway, kernel_sizes_common_pathway,
                dropout_common_pathway)):
        # Insert every metadata input registered for this layer index.
        for j, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
            if m_a_c_p_l == i:
                path = introduce_metadata(path, j)
        if d_c_p:
            path = Dropout(d_c_p)(path)
        # Only the last layer uses the final activation inside the Conv3D.
        path = Conv3D(filters=n_f_c_p,
                      kernel_size=k_s_c_p,
                      activation=activation_final_layer
                      if i + 1 == len(number_features_common_pathway) else None,
                      padding=padding,
                      kernel_initializer=kernel_initializer,
                      kernel_regularizer=regularizers.l1_l2(l1_reg,
                                                            l2_reg))(path)
        if i + 1 < len(number_features_common_pathway):
            if batch_normalization or instance_normalization:
                path = normalization_function()(path)
            path = activation_function("c_activation{}".format(i))(path)

    # 3. Mask the output (optionally)
    if mask_output:
        mask_input_ = mask_path = Input(shape=tuple(output_size) +
                                        (K.int_shape(path)[-1], ))
        inputs.append(mask_input_)
        path = Multiply()([path, mask_path])

    # 4. For example: Correct for segment sampling changes to P(X|Y) --> this adds an extra dimension because the correction is done inside loss function and weights are given with y_creator in extra dimension (can only be done for binary like this)
    if add_extra_dimension:
        path = Reshape(K.int_shape(path)[1:] + (1, ))(path)
    model = Model(inputs=inputs, outputs=[path])

    # Final sanity check: were our calculations correct?
    # NOTE(review): the original indentation was lost; everything up to the
    # return is assumed to sit under ``if verbose`` - confirm. Also,
    # ``model.summary()`` prints by itself and returns None, so the wrapping
    # ``print`` emits an extra "None" line.
    if verbose:
        print("\nNetwork summary:")
        print(model.summary())
        model_input_shape = model.input_shape
        if not isinstance(model_input_shape, list):
            model_input_shape = [model_input_shape]
        for p in range(nb_pathways):
            assert list(model_input_shape[p][1:-1]) == input_sizes[p]
        print('With a batch size of {} this model needs {} GB on the GPU.'.
              format(1, get_model_memory_usage(1, model)))
    return model
def add_unet_layer(model, network_config, remaining_layers, output_shape, n_channels=None):
    """Recursively append one level of a 3D U-Net to ``model``.

    Each level applies a convolution module, crops the result into a forward
    (skip) link and, unless this is the terminal level, downsamples with a
    strided convolution, recurses one level deeper, upsamples with a
    transposed convolution, concatenates with the forward link and applies a
    second convolution module.

    Args:
        model: Tensor to build on; must provide ``get_shape().as_list()``.
        network_config: Configuration object. The attributes read here are
            ``unet_downsample_rate``, ``convolution_padding``,
            ``convolution_dim``, ``num_layers_per_module``, ``initialization``,
            ``convolution_activation``, ``batch_normalization`` and
            ``dropout_probability``.
        remaining_layers: Number of levels still to add below this one; a
            value ``<= 0`` makes this the terminal level.
        output_shape: Spatial shape this level must produce.
            NOTE(review): presumably a length-3 numpy array, since it is
            combined arithmetically with numpy arrays - confirm with callers.
        n_channels: Channel count to start from; defaults to the size of the
            last axis of ``model``.

    Returns:
        The output tensor of this level (the cropped forward link for the
        terminal level).
    """
    cfg = network_config
    layers_per_module = cfg.num_layers_per_module
    kernel = tuple(cfg.convolution_dim)

    def conv_bn(tensor, filters, layer_cls=Conv3D, **layer_kwargs):
        # One (transposed) convolution followed by optional batch norm.
        tensor = layer_cls(filters,
                           kernel,
                           kernel_initializer=cfg.initialization,
                           activation=cfg.convolution_activation,
                           **layer_kwargs)(tensor)
        if cfg.batch_normalization:
            tensor = BatchNormalization()(tensor)
        return tensor

    if n_channels is None:
        n_channels = model.get_shape().as_list()[-1]

    # An axis is downsampled at this level when its rate is non-zero and
    # divides the number of remaining layers.
    downsample = np.array([
        rate != 0 and remaining_layers % rate == 0
        for rate in cfg.unet_downsample_rate
    ])
    strides = list(downsample + 1)

    # 'same' padding keeps the spatial size; otherwise each convolution
    # shrinks every axis by (kernel - 1).
    same_padding = cfg.convolution_padding == 'same'
    conv_contract = (np.zeros(3, dtype=np.int32)
                     if same_padding else cfg.convolution_dim - 1)

    # First U convolution module.
    final_index = layers_per_module - 1
    for layer_index in range(layers_per_module):
        if layer_index == final_index:
            # Increase the number of channels before downsampling to avoid
            # bottleneck (identical to 3D U-Net paper).
            n_channels *= 2
        model = conv_bn(model, n_channels, padding=cfg.convolution_padding)

    # Crop and pass forward to upsampling.
    forward_link_shape = output_shape
    if remaining_layers > 0:
        forward_link_shape = output_shape + layers_per_module * conv_contract
    current_shape = np.array(model.get_shape().as_list()[1:4])
    contraction = (current_shape - forward_link_shape) // 2
    symmetric_crop = list(zip(list(contraction), list(contraction)))
    forward = Cropping3D(symmetric_crop)(model)
    if cfg.dropout_probability > 0.0:
        forward = Dropout(cfg.dropout_probability)(forward)

    # Terminal layer of the U.
    if remaining_layers <= 0:
        return forward

    # Downsample and recurse.
    model = conv_bn(model, n_channels, strides=strides, padding='same')
    pooled_shape = np.divide(forward_link_shape,
                             downsample.astype(np.float32) + 1.0)
    next_output_shape = np.ceil(pooled_shape).astype(np.int32)
    model = add_unet_layer(model, cfg, remaining_layers - 1, next_output_shape)

    # Upsample output of previous layer and merge with forward link.
    model = conv_bn(model, n_channels * 2, layer_cls=Conv3DTranspose,
                    strides=strides, padding='same')

    # Must crop output because Keras wrongly pads the output shape for odd
    # array sizes.
    odd_correction = 1 - np.mod(forward_link_shape, 2)
    stride_pad = (cfg.convolution_dim // 2) * np.array(downsample) + odd_correction
    tf_pad_start = stride_pad // 2  # Tensorflow puts odd padding at end.
    tf_pad_end = stride_pad - tf_pad_start
    model = Cropping3D(list(zip(list(tf_pad_start), list(tf_pad_end))))(model)

    model = concatenate([forward, model])

    # Second U convolution module.
    for _ in range(layers_per_module):
        model = conv_bn(model, n_channels, padding=cfg.convolution_padding)

    return model
def get_net():
    """Build and compile a four-level 3D U-Net.

    The network takes a single-channel cube of side ``input_dim`` (a
    module-level name). The contracting path uses 'same'-padded 3x3x3
    convolutions with batch normalization and 2x2x2 max pooling; the
    expanding path uses transposed convolutions, cropped skip connections
    and convolutions with the layer's default padding. The last layer is a
    1x1x1 sigmoid convolution.

    Returns:
        The compiled Keras ``Model``.
    """
    cube_kernel = (3, 3, 3)
    double = (2, 2, 2)

    def conv_bn(tensor, filters, **conv_kwargs):
        # 3x3x3 relu convolution followed by batch normalization.
        tensor = Conv3D(filters, cube_kernel, activation="relu", **conv_kwargs)(tensor)
        return BatchNormalization()(tensor)

    def cropped(tensor, margin):
        # Remove `margin` voxels from both ends of every spatial axis.
        return Cropping3D(cropping=((margin, margin),) * 3)(tensor)

    net_input = Input((input_dim, input_dim, input_dim, 1))

    # Contracting path: levels 1 to 4, pooling between consecutive levels.
    encoder_filters = ((32, 64), (64, 128), (128, 256), (256, 512))
    skips = []
    features = net_input
    for level, (first_filters, second_filters) in enumerate(encoder_filters):
        if level > 0:
            features = MaxPooling3D(double)(features)
        features = conv_bn(features, first_filters, padding="same")
        features = conv_bn(features, second_filters, padding="same")
        skips.append(features)
    level1, level2, level3, bottom = skips

    # Level 3
    up5 = Conv3DTranspose(512, double, strides=double, padding="same",
                          activation="relu")(bottom)
    decoded = concatenate([up5, level3])
    decoded = conv_bn(decoded, 256)
    decoded = conv_bn(decoded, 256)

    # Level 2
    up6 = Conv3DTranspose(256, double, strides=double,
                          activation="relu")(decoded)
    decoded = concatenate([up6, cropped(level2, 4)])
    decoded = conv_bn(decoded, 128)
    decoded = conv_bn(decoded, 128)

    # Level 1
    up7 = Conv3DTranspose(128, double, strides=double, padding="same",
                          activation="relu")(decoded)
    decoded = concatenate([up7, cropped(level1, 12)])
    decoded = conv_bn(decoded, 64)
    decoded = conv_bn(decoded, 64)

    # Output dim is (36, 36, 36) (original author's note).
    preds = Conv3D(1, (1, 1, 1), activation="sigmoid")(decoded)

    net = Model(inputs=net_input, outputs=preds)
    net.compile(optimizer=Adam(lr=0.001, decay=0.00),
                loss=weighted_binary_crossentropy,
                metrics=[
                    axon_precision, axon_recall, f1_score,
                    artifact_precision, edge_axon_precision,
                    adjusted_accuracy
                ])
    return net
hidden.append(Activation('relu')(hidden[-1])) # hidden.append(tf.nn.batch_normalization(hidden[-1], bnm[len(bna)], bns[len(bna)], 0.0, 1.0, 0.000001)) # hidden.append((hidden[-1]-bnm[len(bna)])/bns[len(bna)]) # bna.append(hidden[-1]) # print('len bna',len(bna)) print('layer',len(hidden)-1,':',hidden[-1].shape,'after conv conv bn') print('...') # up for i in range(len(nFeatMapsList)-1): nFeatMaps = nFeatMapsList[-i-2] hidden.append(Conv3DTranspose(nFeatMaps,(3),strides=(2),padding='same',activation='relu')(hidden[-1])) print('layer',len(hidden)-1,':',hidden[-1].shape,'after upconv') toCrop = int((hidden[ccidx[-1-i]].shape[1]-hidden[-1].shape[1])//2) hidden.append(concatenate([hidden[-1], Cropping3D(toCrop)(hidden[ccidx[-1-i]])])) print('layer',len(hidden)-1,':',hidden[-1].shape,'after concat with cropped layer %d' % ccidx[-1-i]) # hidden.append(Dropout(0.5)(hidden[-1], training=t)) hidden.append(Conv3D(nFeatMaps,(3),padding='valid',activation=None)(hidden[-1])) hidden.append(tf.layers.batch_normalization(hidden[-1], training=t)) hidden.append(Conv3D(nFeatMaps,(3),padding='valid',activation=None)(hidden[-1])) hidden.append(tf.layers.batch_normalization(hidden[-1], training=t)) hidden.append(Activation('relu')(hidden[-1])) # hidden.append(tf.nn.batch_normalization(hidden[-1], bnm[len(bna)], bns[len(bna)], 0.0, 1.0, 0.000001)) # hidden.append((hidden[-1]-bnm[len(bna)])/bns[len(bna)]) # bna.append(hidden[-1]) # print('len bna',len(bna)) print('layer',len(hidden)-1,':',hidden[-1].shape,'after conv conv bn') print('...')
def create_generalized_unet_v2_model(
        number_input_features=4,
        subsample_factors_per_pathway=((1, 1, 1), (2, 2, 2), (4, 4, 4), (8, 8, 8), (16, 16, 16)),
        kernel_sizes_per_pathway=((((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
                                  (((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
                                  (((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
                                  (((3, 3, 3), (3, 3, 3)), ((3, 3, 3), (3, 3, 3))),
                                  (((3, 3, 3), (3, 3, 3)), ())),
        number_features_per_pathway=(((30, 30), (30, 30)),
                                     ((60, 60), (60, 30)),
                                     ((120, 120), (120, 60)),
                                     ((240, 240), (240, 120)),
                                     ((480, 240), ())),
        kernel_sizes_common_pathway=((1, 1, 1), ) * 1,
        number_features_common_pathway=(1, ),
        dropout_common_pathway=(0, ),
        output_size=(128, 128, 128),
        metadata_sizes=None,
        metadata_number_features=None,
        metadata_dropout=None,
        metadata_at_common_pathway_layer=None,
        padding='same',
        pooling='max',
        upsampling='copy',
        activation='lrelu',
        activation_final_layer='sigmoid',
        kernel_initializer='he_normal',
        batch_normalization=False,
        batch_normalization_on_input=False,
        instance_normalization=True,
        instance_normalization_on_input=False,
        relaxed_normalization_scheme=False,
        mask_output=False,
        residual_connections=False,
        dense_connections=False,
        add_extra_dimension=False,
        l1_reg=0.0,
        l2_reg=1e-5,
        verbose=True,
        input_interpolation="nearest",
        number_siam_pathways=1,
        extra_output_kernel_sizes=None,
        extra_output_number_features=None,
        extra_output_dropout=None,
        extra_output_at_common_pathway_layer=None,
        extra_output_activation_final_layer=None,
        dynamic_input_shapes=False):
    """Build a configurable 3D U-Net style Keras model.

    The model consists of a U part (one downward and one upward convolution
    stack per pathway, with cropped skip connections), optionally replicated
    as siamese pathways that share weights, followed by a common pathway
    that produces the main output ``"s0"`` and optional extra outputs.

    Args:
        number_input_features: Number of channels of the image input.
        subsample_factors_per_pathway: Per pathway, the cumulative subsample
            factor along the three spatial axes. Must be a tuple (it is
            concatenated with tuples).
        kernel_sizes_per_pathway: Per pathway, a pair (downward, upward) of
            sequences of kernel sizes.
        number_features_per_pathway: Same structure as
            ``kernel_sizes_per_pathway``, giving the number of filters.
        kernel_sizes_common_pathway, number_features_common_pathway,
        dropout_common_pathway: Per-layer settings of the common pathway;
            all three must have the same length.
        output_size: Requested spatial output size; must be one of the sizes
            the architecture can produce.
        metadata_sizes, metadata_number_features, metadata_dropout,
        metadata_at_common_pathway_layer: Optional metadata inputs and the
            layer at which each is concatenated (an index into the common
            pathway, or ``"x"`` for right after the downward path).
        padding: ``"valid"`` or ``"same"``.
        pooling: ``"max"`` or ``"avg"``.
        upsampling: ``"copy"``, ``"linear"`` or ``"conv"``.
        activation: ``"relu"``, ``"lrelu"``, ``"prelu"`` or ``"linear"``.
        activation_final_layer: Activation of the last common-pathway layer.
        kernel_initializer: Initializer passed to every convolution.
        batch_normalization, batch_normalization_on_input,
        instance_normalization, instance_normalization_on_input,
        relaxed_normalization_scheme: Normalization options; batch and
            instance normalization are mutually exclusive.
        mask_output: If true, add a mask input that multiplies every output.
        residual_connections, dense_connections: Mutually exclusive shortcut
            schemes inside each convolution stack.
        add_extra_dimension: If true, append a trailing axis to every output.
        l1_reg, l2_reg: Kernel regularization strengths.
        verbose: If true, print size diagnostics, the model summary and
            (for static shapes) the estimated memory usage.
        input_interpolation: ``"mean"`` scales the theoretical field of view
            by the first subsample factor.
        number_siam_pathways: Number of weight-sharing copies of the U part.
        extra_output_*: Optional extra output heads branching off the common
            pathway.
        dynamic_input_shapes: If true, spatial input dimensions are left
            unspecified (``None``).

    Returns:
        A ``tensorflow.keras.Model`` whose first output is the main output.

    Raises:
        ValueError: If the configuration is inconsistent or ``output_size``
            cannot be produced.
        NotImplementedError: For unsupported option combinations.
    """
    from tensorflow.keras.layers import Input, Dropout, MaxPooling3D, Concatenate, Multiply, Add, Reshape, AveragePooling3D, Conv3D, UpSampling3D, Cropping3D, LeakyReLU, PReLU, BatchNormalization, Conv3DTranspose
    from tensorflow.keras import regularizers
    from tensorflow.keras import backend as K
    from tensorflow_addons.layers import InstanceNormalization
    from tensorflow.keras import Model

    # Define some in-house functions
    def normalization_function():
        # Return a fresh normalization layer of the configured kind.
        if batch_normalization_on_input or batch_normalization:
            normalization_function_ = BatchNormalization()
        elif instance_normalization_on_input or instance_normalization:
            normalization_function_ = InstanceNormalization()
        else:
            raise NotImplementedError
        return normalization_function_

    def introduce_metadata(path, i):
        # Create the i-th metadata input, process it with 1x1x1 convolutions
        # and concatenate it to `path` along the channel axis.
        if dynamic_input_shapes:
            metadata_input_ = metadata_path = Input(
                shape=(None, None, None,
                       metadata_sizes[i][-1] if isinstance(metadata_sizes[i], tuple) else metadata_sizes[i]))
        else:
            metadata_input_ = metadata_path = Input(
                shape=metadata_sizes[i] if isinstance(metadata_sizes[i], tuple) else (1, 1, 1, metadata_sizes[i]))
        metadata_inputs[i] = metadata_input_
        for j, (m_n_f, m_d) in enumerate(zip(metadata_number_features[i], metadata_dropout[i])):
            metadata_path = Dropout(m_d)(metadata_path)
            metadata_path = Conv3D(filters=m_n_f,
                                   kernel_size=(1, 1, 1),
                                   padding=padding,
                                   kernel_initializer=kernel_initializer,
                                   kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg))(metadata_path)
            metadata_path = activation_function("m{}_activation{}".format(i, j))(metadata_path)
        if not isinstance(metadata_sizes[i], tuple):
            # Scalar-sized metadata has spatial shape (1, 1, 1): tile it to
            # the spatial shape of `path` before concatenating.
            broadcast_shape = K.concatenate([
                K.constant([1], dtype="int32"),
                K.shape(path)[1:-1],
                K.constant([1], dtype="int32")
            ])
            metadata_path = K.tile(metadata_path, broadcast_shape)
        return Concatenate(axis=-1)([path, metadata_path])

    def retrieve_extra_output(path, i):
        # Build the i-th extra output head on top of `path`.
        extra_output_path = path
        for j, (e_o_k_s, e_o_n_f, e_o_d) in enumerate(
                zip(extra_output_kernel_sizes[i], extra_output_number_features[i], extra_output_dropout[i])):
            is_last = j + 1 == len(extra_output_number_features[i])
            extra_output_path = Dropout(e_o_d)(extra_output_path)
            extra_output_path = Conv3D(
                filters=e_o_n_f,
                kernel_size=e_o_k_s,
                activation=extra_output_activation_final_layer[i] if is_last else None,
                padding=padding,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg),
                name=f"s{i + 1}" if is_last else None)(extra_output_path)
            if not is_last:
                if (batch_normalization or instance_normalization) and not relaxed_normalization_scheme:
                    extra_output_path = normalization_function()(extra_output_path)
                extra_output_path = activation_function("eo{}_activation{}".format(i, j))(extra_output_path)
        return extra_output_path

    def activation_function(name):
        # Return a callable applying the configured activation.
        if activation == "relu":
            activation_function_ = LeakyReLU(alpha=0, name=name)
        elif activation == "lrelu":
            activation_function_ = LeakyReLU(alpha=0.01, name=name)
        elif activation == "prelu":
            activation_function_ = PReLU(shared_axes=[1, 2, 3], name=name)
        elif activation == "linear":
            def activation_function_(path):
                return path
        else:
            raise NotImplementedError
        return activation_function_

    def pooling_function(pool_size):
        # Return a pooling layer with stride equal to the pool size.
        if pooling == "max":
            pooling_function_ = MaxPooling3D(pool_size, pool_size)
        elif pooling == "avg":
            pooling_function_ = AveragePooling3D(pool_size, pool_size)
        else:
            raise NotImplementedError
        return pooling_function_

    def upsampling_function(upsample_size):
        # Return a callable that upsamples by `upsample_size`.
        if upsampling == "copy":
            upsampling_function_ = UpSampling3D(upsample_size)
        elif upsampling == "linear":
            def upsampling_function_(path):
                path = UpSampling3D(upsample_size)(path)
                path = AveragePooling3D(upsample_size, strides=(1, 1, 1), padding='valid')(path)
                return path
        elif upsampling == "conv":
            def upsampling_function_(path):
                path = Conv3DTranspose(K.int_shape(path)[-1], upsample_size, upsample_size)(path)
                return path
        else:
            raise NotImplementedError
        return upsampling_function_

    # Define some in-house variables
    nb_pathways = len(subsample_factors_per_pathway)
    supported_activations = ["relu", "lrelu", "prelu", "linear"]
    supported_poolings = ["max", "avg"]
    supported_upsamplings = ["copy", "linear", "conv"]
    supported_paddings = ["valid", "same"]

    # Do some sanity checks
    if not len(kernel_sizes_per_pathway) == len(number_features_per_pathway) == nb_pathways:
        raise ValueError("Inconsistent number of pathways.")

    for p in range(nb_pathways - 1):
        if not [
                f_next % f for f, f_next in zip(subsample_factors_per_pathway[p],
                                                subsample_factors_per_pathway[p + 1])
        ] == [0, 0, 0]:
            raise ValueError(
                "Each subsample factor must be a integer factor of the subsample factor of the previous pathway."
            )

    for k_s_p_p, n_f_p_p in zip(kernel_sizes_per_pathway, number_features_per_pathway):
        if not len(k_s_p_p) == len(n_f_p_p) == 2:
            raise ValueError(
                "Each element in kernel_sizes_per_pathway and number_features_per_pathway must be a list of two elements, giving information about the left (downwards) and right (upwards) paths of the U-Net."
            )
        for k_s, n_f in zip(k_s_p_p, n_f_p_p):
            if not len(k_s) == len(n_f):
                raise ValueError(
                    "Each kernel size in each element from kernel_sizes_per_pathway must correspond with a number of features in each element of number_features_per_pathway."
                )

    if not len(kernel_sizes_common_pathway) == len(dropout_common_pathway) == len(number_features_common_pathway):
        raise ValueError("Inconsistent depth of common pathway.")

    if metadata_sizes in [None, []]:
        metadata_sizes = []
        if metadata_number_features in [None, []]:
            metadata_number_features = []
        else:
            raise ValueError(
                "Invalid value for metadata_number_features when there is no metadata"
            )
        if metadata_dropout in [None, []]:
            metadata_dropout = []
        else:
            raise ValueError(
                "Invalid value for metadata_dropout when there is no metadata")
        if metadata_at_common_pathway_layer in [None, []]:
            metadata_at_common_pathway_layer = []
        else:
            raise ValueError(
                "Invalid value for metadata_at_common_pathway_layer when there is no metadata"
            )
    else:
        if not len(metadata_sizes) == len(metadata_dropout) == len(metadata_number_features) == len(
                metadata_at_common_pathway_layer):
            raise ValueError("Inconsistent depth of metadata pathway.")

    if extra_output_number_features in [None, [], ()]:
        extra_output_number_features = ()
        if extra_output_kernel_sizes in [None, [], ()]:
            extra_output_kernel_sizes = ()
        else:
            raise ValueError(
                "Invalid value for extra_output_kernel_sizes when there is no extra_output"
            )
        if extra_output_dropout in [None, [], ()]:
            extra_output_dropout = ()
        else:
            raise ValueError(
                "Invalid value for extra_output_dropout when there is no extra_output"
            )
        if extra_output_at_common_pathway_layer in [None, [], ()]:
            extra_output_at_common_pathway_layer = ()
        else:
            raise ValueError(
                "Invalid value for extra_output_at_common_pathway_layer when there is no extra_output"
            )
        if extra_output_activation_final_layer in [None, [], ()]:
            extra_output_activation_final_layer = ()
        else:
            raise ValueError(
                "Invalid value for extra_output_activation_final_layer when there is no extra_output"
            )

    if not len(extra_output_activation_final_layer) == len(extra_output_dropout) == len(
            extra_output_number_features) == len(extra_output_kernel_sizes) == len(
                extra_output_at_common_pathway_layer):
        raise ValueError("Inconsistent depth of extra_output pathway.")

    if residual_connections and dense_connections:
        raise ValueError(
            "Residual connections and Dense connections should not be used together."
        )

    if dynamic_input_shapes and (residual_connections or dense_connections):
        raise NotImplementedError(
            "Currently using residual or dense connections are not supported when using dynamic input shapes!"
        )

    if dense_connections and pooling != "avg":
        raise ValueError(
            "According to Huang et al. a densely connected network should have average pooling."
        )

    if activation not in supported_activations:
        raise ValueError("The chosen activation is not supported.")

    if pooling not in supported_poolings:
        raise ValueError("The chosen pooling is not supported.")

    if upsampling not in supported_upsamplings:
        raise ValueError("The chosen upsampling is not supported.")

    if padding not in supported_paddings:
        raise ValueError("The chosen padding is not supported.")

    if (batch_normalization_on_input or batch_normalization) and (instance_normalization_on_input
                                                                   or instance_normalization):
        raise ValueError(
            "You have to choose between batch or instance normalization.")

    if relaxed_normalization_scheme and not (batch_normalization or instance_normalization):
        raise ValueError(
            "The relaxed normalization scheme can only be used if you also do (batch or instance) normalization."
        )

    # Calculate the field of view
    field_of_view = np.ones(3, dtype=int)
    if input_interpolation == "mean":
        field_of_view *= np.array(subsample_factors_per_pathway[0])

    for s_f_p_p, k_s_p_p in zip(subsample_factors_per_pathway, kernel_sizes_per_pathway):
        for k_s in k_s_p_p[0]:
            field_of_view += (np.array(k_s) - 1) * s_f_p_p

    for s_f_p_p, k_s_p_p in reversed(list(zip(subsample_factors_per_pathway, kernel_sizes_per_pathway))):
        for k_s in k_s_p_p[1]:
            field_of_view += (np.array(k_s) - 1) * s_f_p_p

    for k_s in kernel_sizes_common_pathway:
        field_of_view += (np.array(k_s) - 1) * subsample_factors_per_pathway[0]

    # Theoretical input size; replaced by the true one once the possible
    # sizes have been enumerated below.
    input_size = list(field_of_view - 1 + output_size)
    field_of_view = list(field_of_view)
    output_size = list(output_size)
    if verbose:
        print("\nfield of view:\t{}\t(theoretical)".format(field_of_view))
        print("output size:\t{}\t(user defined)".format(output_size))
        print(
            "input size:\t{}\t(inferred with theoretical field of view (less meaningful if padding='same'))"
            .format(input_size))

    # What are the possible input and output sizes?
    # Every candidate input size 0..499 is propagated through the size
    # arithmetic of the network; sizes that are not realizable become <= 0.
    path_left_output_positions = []
    path_right_input_positions = []
    path_left_output_sizes = []
    path_right_input_sizes = []
    input_sizes = output_sizes = np.stack([np.arange(500)] * 3, axis=-1)
    prev_s_f_p_p = np.ones(3)
    prev_positions = np.zeros(3)
    for i, (s_f_p_p, k_s_p_p) in enumerate(
            zip(((1, 1, 1), ) + subsample_factors_per_pathway,
                (((), ()), ) + kernel_sizes_per_pathway)):
        s_f = s_f_p_p // prev_s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        prev_positions = np.abs(prev_positions - (1 - (s_f_p_p * (s_f - 1) + 1) % 2))
        output_sizes = (output_sizes // s_f) * ((output_sizes % s_f) == 0)
        for k_s in k_s_p_p[0]:
            output_sizes -= (np.array(k_s) - 1) if padding == 'valid' else 0
            prev_positions = np.abs(prev_positions - (1 - (s_f_p_p * (np.array(k_s) - 1) + 1) % 2))
        path_left_output_positions.append(prev_positions)
        path_left_output_sizes.append(output_sizes)

    prev_s_f_p_p = np.array(subsample_factors_per_pathway[-1])
    for i, (s_f_p_p, k_s_p_p) in reversed(
            list(
                enumerate(
                    zip(((1, 1, 1), (1, 1, 1)) + subsample_factors_per_pathway[:-1],
                        (((), ()), ) + kernel_sizes_per_pathway)))):
        path_right_input_positions.append(prev_positions)
        path_right_input_sizes.append(output_sizes)
        output_sizes = output_sizes - np.abs(path_left_output_positions[i] - prev_positions)
        output_sizes *= ((path_left_output_sizes[i] - output_sizes) % 2) == 0
        prev_positions = path_left_output_positions[i]
        for k_s in k_s_p_p[1]:
            output_sizes -= (np.array(k_s) - 1) if padding == 'valid' else 0
            prev_positions = np.abs(prev_positions - (1 - (s_f_p_p * (np.array(k_s) - 1) + 1) % 2))
        s_f = prev_s_f_p_p // s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        output_sizes *= s_f
        output_sizes = output_sizes - s_f + 1 if upsampling == "linear" else output_sizes

    for k_s in kernel_sizes_common_pathway:
        output_sizes -= (np.array(k_s) - 1) if padding == 'valid' else 0

    path_right_input_positions.reverse()
    possible_input_sizes = [
        list(input_sizes[output_sizes[:, i] > 0, i]) for i in range(3)
    ]
    possible_output_sizes = [
        list(output_sizes[output_sizes[:, i] > 0, i].astype('int'))
        for i in range(3)
    ]
    possible_path_left_output_sizes, possible_path_right_input_sizes, possible_crops = [], [], []
    for i in range(3):
        possible_path_left_output_sizes_, possible_path_right_input_sizes_, possible_crops_ = [], [], []
        for j, o_s in enumerate(output_sizes[:, i]):
            if o_s > 0:
                possible_path_left_output_sizes_.append([
                    path_left_output_sizes_[j, i]
                    for path_left_output_sizes_ in path_left_output_sizes
                ])
                possible_path_right_input_sizes_.append([
                    path_right_input_sizes_[j, i]
                    for path_right_input_sizes_ in reversed(path_right_input_sizes)
                ])
                possible_crops_.append([
                    int((pplos - ppris) / 2)
                    for pplos, ppris in zip(possible_path_left_output_sizes_[-1],
                                            possible_path_right_input_sizes_[-1])
                ])
        possible_path_left_output_sizes.append(possible_path_left_output_sizes_)
        possible_path_right_input_sizes.append(possible_path_right_input_sizes_)
        possible_crops.append(possible_crops_)

    if not [
            o_s in p_o_s for o_s, p_o_s in zip(output_size, possible_output_sizes)
    ] == [True] * 3:
        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        print(
            "\npossible input sizes (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
            .format(*possible_input_sizes))
        raise ValueError(
            "The user defined output_size is not possible. Please choose from list above."
        )

    # The true input size must be used to build the network whether or not
    # diagnostics are printed, so it is computed outside the verbose branch.
    input_size = [
        possible_input_sizes[i][possible_output_sizes[i].index(o_s)]
        for i, o_s in enumerate(output_size)
    ]
    if verbose:
        print("input size:\t{}\t(true input size of the network)".format(
            input_size))
        print("\npossible output sizes:\nx: {}\ny: {}\nz: {}".format(
            *possible_output_sizes))
        print(
            "\npossible input sizes (corresponding with the possible output sizes):\nx: {}\ny: {}\nz: {}"
            .format(*possible_input_sizes))

    # Per axis: the skip-connection crop of every level for the chosen size.
    crops = [
        possible_crops[i][possible_output_sizes[i].index(output_size[i])]
        for i in range(3)
    ]
    indices_with_identical_cropping = []
    for i in range(3):
        indices_with_identical_cropping_ = []
        for j, crops_ in enumerate(possible_crops[i]):
            if all([crop == crop_ for crop, crop_ in zip(crops[i], crops_)]):
                indices_with_identical_cropping_.append(j)
        indices_with_identical_cropping.append(indices_with_identical_cropping_)

    print(
        "\npossible output sizes when using dynamic shapes (based on output size {}):\nx: {}\ny: {}\nz: {}"
        .format(
            output_size, *[[
                possible_output_sizes[i][j]
                for j in indices_with_identical_cropping[i]
            ] for i in range(3)]))
    print(
        "\npossible input sizes when using dynamic shapes (based on output size {}):\nx: {}\ny: {}\nz: {}"
        .format(
            output_size, *[[
                possible_input_sizes[i][j]
                for j in indices_with_identical_cropping[i]
            ] for i in range(3)]))

    # Construct model
    inputs = []
    paths = []
    metadata_inputs = [None] * len(metadata_at_common_pathway_layer)

    # 1. Construct U part
    if dynamic_input_shapes:
        input_ = path = Input(shape=(None, None, None, number_input_features),
                              name="siam0_input")
    else:
        input_ = path = Input(shape=list(input_size) + [number_input_features],
                              name="siam0_input")
    inputs.append(input_)
    path_left_output_paths = []
    path_right_output_paths = []

    # Downwards
    prev_s_f_p_p = np.ones(3)
    for i, (s_f_p_p, k_s_p_p, n_f_p_p) in enumerate(
            zip(((1, 1, 1), ) + subsample_factors_per_pathway,
                (((), ()), ) + kernel_sizes_per_pathway,
                (((), ()), ) + number_features_per_pathway)):
        s_f = s_f_p_p // prev_s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        path = pooling_function(s_f)(path)
        if i > 0:
            if i == 1 and (batch_normalization_on_input or instance_normalization_on_input):
                path = normalization_function()(path)
            for j, (k_s, n_f) in enumerate(zip(k_s_p_p[0], n_f_p_p[0])):
                if j == 0 and (residual_connections or dense_connections):
                    shortcut = path
                elif (0 < j < len(k_s_p_p[0]) - 1) and dense_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    shortcut = Concatenate(axis=-1)([path, shortcut])
                    path = shortcut
                if ((i > 1 or j > 0) and (batch_normalization or instance_normalization)) and (
                        not relaxed_normalization_scheme or j == len(k_s_p_p[0]) - 1):
                    path = normalization_function()(path)
                if i > 1 or j > 0:
                    path = activation_function("down_p{}_activation{}".format(i, j))(path)
                path = Conv3D(filters=n_f,
                              kernel_size=k_s,
                              padding=padding,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg),
                              name="{}_{}".format(i, j))(path)
                if j == len(k_s_p_p[0]) - 1 and residual_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    if K.int_shape(path)[-1] != K.int_shape(shortcut)[-1]:
                        shortcut = Conv3D(
                            filters=K.int_shape(path)[-1],
                            kernel_size=(1, 1, 1),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg))(shortcut)
                    path = Add()([path, shortcut])
        path_left_output_paths.append(path)
    paths.append(path)

    # Metadata
    for i, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
        if m_a_c_p_l == "x":
            assert number_siam_pathways == 1, "Insertion of metadata at 'x' is currently not supported with the use of siam pathways."
            path = introduce_metadata(path, i)

    # Upwards
    prev_s_f_p_p = np.array(subsample_factors_per_pathway[-1])
    for i, (s_f_p_p, k_s_p_p, n_f_p_p) in reversed(
            list(
                enumerate(
                    zip(((1, 1, 1), (1, 1, 1)) + subsample_factors_per_pathway[:-1],
                        (((), ()), ) + kernel_sizes_per_pathway,
                        (((), ()), ) + number_features_per_pathway)))):
        path = AveragePooling3D(
            np.abs(path_left_output_positions[i] - path_right_input_positions[i]) + 1,
            strides=(1, 1, 1))(path)
        if i > 0:
            if i < nb_pathways:
                path_left = Cropping3D([crops_[i] for crops_ in crops])(path_left_output_paths[i])
                path = Concatenate(axis=-1)([path_left, path])
            for j, (k_s, n_f) in enumerate(zip(k_s_p_p[1], n_f_p_p[1])):
                if j == 0 and (residual_connections or dense_connections):
                    shortcut = path
                elif (0 < j < len(k_s_p_p[1]) - 1) and dense_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    shortcut = Concatenate(axis=-1)([path, shortcut])
                    path = shortcut
                if (batch_normalization or instance_normalization) and (
                        not relaxed_normalization_scheme or j == len(k_s_p_p[1]) - 1):
                    path = normalization_function()(path)
                path = activation_function("up_p{}_activation{}".format(i, j))(path)
                path = Conv3D(filters=n_f,
                              kernel_size=k_s,
                              padding=padding,
                              kernel_initializer=kernel_initializer,
                              kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg))(path)
                if j == len(k_s_p_p[1]) - 1 and residual_connections:
                    shortcut = Cropping3D([
                        int((l - r) / 2) for l, r in zip(
                            K.int_shape(shortcut)[1:-1],
                            K.int_shape(path)[1:-1])
                    ])(shortcut)
                    if K.int_shape(path)[-1] != K.int_shape(shortcut)[-1]:
                        shortcut = Conv3D(
                            filters=K.int_shape(path)[-1],
                            kernel_size=(1, 1, 1),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg))(shortcut)
                    path = Add()([path, shortcut])
        path_right_output_paths.append(path)
        s_f = prev_s_f_p_p // s_f_p_p
        prev_s_f_p_p = np.array(s_f_p_p)
        path = upsampling_function(s_f)(path)
    paths.append(path)

    # SIAMESE networks
    if number_siam_pathways > 1:
        outputs = [path]
        siam_model = Model(inputs=inputs, outputs=path)
        for i in range(number_siam_pathways - 1):
            inputs_ = []
            # Each extra pathway gets its own fresh input in both the dynamic
            # and the static case.
            if dynamic_input_shapes:
                siam_input = Input(shape=(None, None, None, number_input_features),
                                   name=f"siam{i + 1}_input")
            else:
                siam_input = Input(shape=list(input_size) + [number_input_features],
                                   name=f"siam{i + 1}_input")
            inputs_.append(siam_input)
            inputs.append(siam_input)
            for j, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
                if m_a_c_p_l == "x":
                    meta_input_ = Input(shape=(1, 1, 1, metadata_sizes[j]),
                                        name=f"siam{i + 1}_meta{j + 1}")
                    inputs_.append(meta_input_)
                    inputs.append(meta_input_)
            outputs.append(
                siam_model(inputs_ if len(inputs_) > 1 else inputs_[0]))
        path = Concatenate(axis=-1)(outputs)

    outputs = [None] * len(extra_output_at_common_pathway_layer)

    # 2. Construct common pathway
    for i, (n_f_c_p, k_s_c_p, d_c_p) in enumerate(
            zip(number_features_common_pathway, kernel_sizes_common_pathway,
                dropout_common_pathway)):
        if (batch_normalization or instance_normalization) and (
                i == 0 or not relaxed_normalization_scheme):
            path = normalization_function()(path)
        path = activation_function("c_activation{}".format(i))(path)
        for j, m_a_c_p_l in enumerate(metadata_at_common_pathway_layer):
            if m_a_c_p_l == i:
                path = introduce_metadata(path, j)
        for j, e_o_a_c_p_l in enumerate(extra_output_at_common_pathway_layer):
            if e_o_a_c_p_l == i:
                outputs[j] = retrieve_extra_output(path, j)
        if d_c_p:
            path = Dropout(d_c_p)(path)
        path = Conv3D(filters=n_f_c_p,
                      kernel_size=k_s_c_p,
                      activation=activation_final_layer
                      if i + 1 == len(number_features_common_pathway) else None,
                      padding=padding,
                      kernel_initializer=kernel_initializer,
                      kernel_regularizer=regularizers.l1_l2(l1_reg, l2_reg),
                      name="s0" if i + 1 == len(number_features_common_pathway) else None)(path)

    inputs = inputs + metadata_inputs
    outputs.insert(0, path)

    # 3. Mask the output (optionally)
    if mask_output:
        if dynamic_input_shapes:
            mask_input_ = mask_path = Input(shape=(None, None, None, K.int_shape(path)[-1]))
        else:
            mask_input_ = mask_path = Input(shape=tuple(output_size) + (K.int_shape(path)[-1], ))
        inputs.append(mask_input_)
        # enumerate is required: `outputs` is a plain list of tensors.
        for i, output in enumerate(outputs):
            outputs[i] = Multiply()([output, mask_path])

    # 4. For example: Correct for segment sampling changes to P(X|Y) --> this
    # adds an extra dimension because the correction is done inside loss
    # function and weights are given with y_creator in extra dimension (can
    # only be done for binary like this)
    if add_extra_dimension:
        for i, output in enumerate(outputs):
            outputs[i] = K.expand_dims(output, -1)

    model = Model(inputs=inputs, outputs=outputs)

    # Final sanity check: were our calculations correct?
    if verbose:
        print("\nNetwork summary:")
        print(model.summary())
        model_input_shape = model.input_shape
        if not isinstance(model_input_shape, list):
            model_input_shape = [model_input_shape]
        if not dynamic_input_shapes:
            assert list(model_input_shape[0][1:-1]) == input_size
            print('With a batch size of {} this model needs {} GB on the GPU.'.
                  format(1, get_model_memory_usage(1, model)))
        else:
            print(
                "Since you are using dynamic_input_shapes=True we cannot calculate the memory usage, nor can we truely check the correctness of the shapes..."
            )

    return model
def model(x, N_svd=0):
    """Assemble the 'ASR' network on top of the 5-D tensor ``x``.

    The graph is an encoder/decoder: two down-sampling groups, a 1x1x1
    shrinking layer, two mapping blocks, the ``SAAM`` stage, and two
    up-sampling groups whose inputs are concatenated (channel axis) with
    the skip branches ``h2`` and ``h1`` taken from the encoder.  Every
    layer is created inside the ``'ASR'`` variable scope.

    Args:
        x: input tensor; the size of its axis 4 is used as the number of
            filters of the final convolution.  The original author's notes
            assume a shape of ``[batch, 6, 24, 64, 1]`` -- not verified
            here, TODO confirm against the caller.
        N_svd: forwarded unchanged to ``SAAM``.

    Returns:
        The output tensor of the last layer, ``conv9`` (no activation).
    """
    up_scale = 4

    def relu_conv(n_filters, kernel, layer_name, **extra):
        # Conv3D with the ReLU / 'SAME' settings shared by most layers.
        return Conv3D(filters=n_filters,
                      kernel_size=kernel,
                      activation='relu',
                      padding='SAME',
                      name=layer_name,
                      **extra)

    def conv_pair(tensor, n_filters, first_name, second_name):
        # A (3, 1, 3) convolution followed by a (3, 3, 1) convolution.
        tensor = relu_conv(n_filters, (3, 1, 3), first_name)(tensor)
        return relu_conv(n_filters, (3, 3, 1), second_name)(tensor)

    def skip_branch(tensor, deconv_filters, deconv_name, conv_filters,
                    conv_name):
        # Transposed conv along axis 1 (stride up_scale), drop the last
        # three slices of that axis, then mix channels with a 1x1x1 conv.
        branch = Conv3DTranspose(deconv_filters, (7, 1, 3),
                                 (up_scale, 1, 1), 'SAME',
                                 activation='relu',
                                 name=deconv_name)(tensor)
        branch = Cropping3D(cropping=((0, 3), (0, 0), (0, 0)))(branch)
        return relu_conv(conv_filters, (1, 1, 1), conv_name)(branch)

    with tf.variable_scope('ASR'):
        chn_in = x.get_shape().as_list()[4]
        chn_base = 6 * up_scale
        chn_double = chn_base * 2
        chn_quad = chn_base * 4

        # Group 1: conv pair, skip branch h1, then stride-2 on axes 2 and 3.
        h = conv_pair(x, chn_base, 'conv1_1', 'conv1_2')
        h1 = skip_branch(h, chn_base, 'deconv1', chn_base, 'conv1_3')
        h = relu_conv(chn_double, (1, 3, 3), 'conv1_4',
                      strides=(1, 2, 2))(h)

        # Group 2: conv pair, skip branch h2, then stride-2 on axis 3 only.
        h = conv_pair(h, chn_double, 'conv2_1', 'conv2_2')
        h2 = skip_branch(h, chn_base, 'deconv2', chn_double, 'conv2_3')
        h = relu_conv(chn_quad, (1, 1, 3), 'conv2_4',
                      strides=(1, 1, 2))(h)

        # Layer 3: shrink the channel count with a 1x1x1 convolution.
        h = relu_conv(chn_double, (1, 1, 1), 'conv3')(h)

        # Group 4: two mapping blocks.
        for block in range(2):
            h = conv_pair(h, chn_double,
                          'conv4_1_' + str(block),
                          'conv4_2_' + str(block))

        # Group 5: attention stage (defined elsewhere in the project).
        h = SAAM(h, up_scale=up_scale, N_svd=N_svd)

        # Layer 6: expand the channel count again.
        h = relu_conv(chn_quad, (1, 1, 1), 'conv6')(h)

        # Group 7: undo the axis-3 stride, merge with h2.
        h = Conv3DTranspose(chn_double, (1, 1, 4), (1, 1, 2),
                            activation='relu',
                            padding='SAME',
                            name='conv7_1')(h)
        h = tf.concat([h, h2], axis=-1)
        h = conv_pair(h, chn_double, 'conv7_2', 'conv7_3')

        # Group 8: undo the axes-2/3 stride, merge with h1.
        h = Conv3DTranspose(chn_base, (1, 4, 4), (1, 2, 2),
                            activation='relu',
                            padding='SAME',
                            name='conv8_1')(h)
        h = tf.concat([h, h1], axis=-1)
        h = conv_pair(h, chn_base, 'conv8_2', 'conv8_3')

        # Group 9: linear output convolution back to the input channels.
        return Conv3D(filters=chn_in,
                      kernel_size=(3, 3, 3),
                      padding='SAME',
                      name='conv9')(h)
def conv3d(filters,
           latentDim,
           path,
           batch=False,
           dropout=False,
           filter_size=3):
    """Build and compile a 3-D convolutional autoencoder.

    The encoder stacks one stride-2 ``Conv3D`` per entry of ``filters``;
    the decoder mirrors it with stride-2 ``Conv3DTranspose`` layers and a
    final single-filter sigmoid ``Conv3D``.  If the depth axis
    (``config.NUM_CHANNELS``) is not divisible by ``2 ** len(filters)`` the
    decoder output is deeper than the input and is cropped back to
    ``config.NUM_CHANNELS``.  If a checkpoint exists in the directory of
    ``path`` its weights are loaded.

    Args:
        filters: sequence with the number of filters of each encoder layer
            (reused in reverse order for the decoder).
        latentDim: number of filters of the 1x1x1 bottleneck convolution,
            or None to omit the bottleneck.
        path: checkpoint file path handed to ``ModelCheckpoint``; its
            directory is searched for the latest checkpoint.
        batch: if true, add ``BatchNormalization`` after each block.
        dropout: if true, add ``Dropout(0.2)`` after each (de)convolution.
        filter_size: kernel size of the strided (de)convolutions and of
            the output convolution.

    Returns:
        Tuple ``(model, cp_callback)``: the compiled model and the
        ``ModelCheckpoint`` callback that saves weights every epoch.
    """
    checkpoint_dir = os.path.dirname(path)
    n_levels = len(filters)

    model = Sequential()
    model.add(
        InputLayer(input_shape=(config.NUM_CHANNELS, config.IMG_HEIGHT,
                                config.IMG_WIDTH, 1)))

    # Encoder: each level halves every spatial axis (rounding up).
    for f in filters:
        model.add(
            Conv3D(f,
                   filter_size,
                   strides=2,
                   activation='relu',
                   padding="same"))
        if dropout:
            model.add(Dropout(0.2))
        if batch:
            model.add(BatchNormalization())

    # Optional 1x1x1 bottleneck.
    if latentDim is not None:
        model.add(
            Conv3D(latentDim, 1, strides=1, activation='relu',
                   padding="same"))
        if batch:
            model.add(BatchNormalization())

    # Decoder: CONV_TRANSPOSE => RELU => (dropout) => (BN), doubling each
    # spatial axis per level.
    for f in reversed(filters):
        model.add(
            Conv3DTranspose(f,
                            filter_size,
                            activation='relu',
                            strides=2,
                            padding="same"))
        if dropout:
            model.add(Dropout(0.2))
        if batch:
            model.add(BatchNormalization())

    model.add(Conv3D(1, filter_size, activation='sigmoid', padding='same'))

    # NOTE(review): only the depth axis is corrected; IMG_HEIGHT/IMG_WIDTH
    # are presumably always divisible by 2 ** len(filters) -- confirm.
    if config.NUM_CHANNELS % (2**n_levels) != 0:
        # Depth at the bottleneck: ceil-halved once per encoder level.
        dim = config.NUM_CHANNELS
        for _ in range(n_levels):
            dim = (dim + 1) // 2
        print(dim)
        # Surplus slices produced by the decoder.  The surplus can be odd,
        # so split it unevenly instead of cropping floor(excess / 2) on
        # both sides, which would leave the output one slice too deep.
        excess = dim * (2**n_levels) - config.NUM_CHANNELS
        crop_front = excess // 2
        crop_back = excess - crop_front
        model.add(
            Cropping3D(cropping=((crop_front, crop_back), (0, 0), (0, 0))))

    model.summary()
    model.compile(loss='mean_squared_error', optimizer=Adam())

    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=path,
                                                     save_weights_only=True,
                                                     verbose=1,
                                                     save_freq='epoch')

    # Resume from the most recent checkpoint next to ``path``, if any.
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    print(latest)
    if latest is not None:
        model.load_weights(latest)
        print('weights loaded')

    return model, cp_callback
print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after conv conv bn') print('...') # up for i in range(len(nFeatMapsList) - 1): nFeatMaps = nFeatMapsList[-i - 2] hidden.append( Conv3DTranspose(nFeatMaps, (3), strides=(2), padding='same', activation='relu')(hidden[-1])) print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after upconv') toCrop = int((hidden[ccidx[-1 - i]].shape[1] - hidden[-1].shape[1]) // 2) hidden.append( concatenate([hidden[-1], Cropping3D(toCrop)(hidden[ccidx[-1 - i]])])) print('layer', len(hidden) - 1, ':', hidden[-1].shape, 'after concat with cropped layer %d' % ccidx[-1 - i]) # hidden.append(Dropout(0.5)(hidden[-1], training=t)) hidden.append( Conv3D(nFeatMaps, (3), padding='valid', activation=None)(hidden[-1])) hidden.append( tf.compat.v1.layers.batch_normalization(hidden[-1], training=t)) hidden.append( Conv3D(nFeatMaps, (3), padding='valid', activation=None)(hidden[-1])) hidden.append( tf.compat.v1.layers.batch_normalization(hidden[-1], training=t)) hidden.append(Activation('relu')(hidden[-1]))
def apetnet_vv5_onnx(input_tensor=None,
                     n_ind_layers=1,
                     n_common_layers=7,
                     n_kernels_ind=15,
                     n_kernels_common=30,
                     kernel_shape=(3, 3, 3),
                     add_final_relu=False,
                     debug=False):
    """Stacked single channel version of apetnet.

    For a description of the input parameters see apetnet.

    The PET and MRI images arrive stacked along axis 1 of a single input
    (PET first) and are separated again with two cropping layers.

    The input_tensor argument is only used to determine the input shape
    (its ``shape[1:5]``).  If it is None the input shape is set to
    (32, 16, 16, 1).

    Returns the Keras model.  With ``debug=True`` the model only
    concatenates the extracted PET and MRI images.
    """
    # Single stacked input holding both modalities.
    if input_tensor is None:
        ipt = Input(shape=(32, 16, 16, 1), name='input')
    else:
        ipt = Input(input_tensor.shape[1:5], name='input')

    # Split the stack: MRI drops the leading half, PET the trailing half.
    half_depth = int(ipt.shape[1] // 2)
    mri_image = Cropping3D(cropping=((half_depth, 0), (0, 0), (0, 0)),
                           name='extract_mri')(ipt)
    pet_image = Cropping3D(cropping=((0, half_depth), (0, 0), (0, 0)),
                           name='extract_pet')(ipt)

    # Debug mode: no convolutions, just join the two extracted images.
    if debug:
        net = Concatenate(name='add_0')([pet_image, mri_image])
        return Model(inputs=ipt, outputs=net)

    # Individual (per-modality) paths.
    pet_path, mri_path = pet_image, mri_image
    if n_ind_layers > 0:
        fan_ind = np.prod(kernel_shape) * n_kernels_ind
        init_val_ind = RandomNormal(mean=0.0, stddev=np.sqrt(2 / fan_ind))

        def individual_layer(tensor, modality, idx):
            # Conv3D + PReLU for one modality; names carry modality/index.
            tensor = Conv3D(n_kernels_ind,
                            kernel_shape,
                            padding='same',
                            name='conv3d_{}_ind_{}'.format(modality, idx),
                            kernel_initializer=init_val_ind)(tensor)
            return PReLU(shared_axes=[1, 2, 3],
                         name='prelu_{}_ind_{}'.format(modality,
                                                       idx))(tensor)

        for i in range(n_ind_layers):
            pet_path = individual_layer(pet_path, 'pet', i)
            mri_path = individual_layer(mri_path, 'mri', i)

    # Merge both modalities along the channel axis.
    net = Concatenate(name='concat_0')([pet_path, mri_path])

    # Common path.
    fan_common = np.prod(kernel_shape) * n_kernels_common
    init_val_common = RandomNormal(mean=0.0, stddev=np.sqrt(2 / fan_common))
    for i in range(n_common_layers):
        net = Conv3D(n_kernels_common,
                     kernel_shape,
                     padding='same',
                     name='conv3d_{}'.format(i),
                     kernel_initializer=init_val_common)(net)
        net = PReLU(shared_axes=[1, 2, 3], name='prelu_{}'.format(i))(net)

    # 1x1x1 convolution that combines all features into one channel.
    net = Conv3D(1, (1, 1, 1),
                 padding='valid',
                 name='conv_final',
                 kernel_initializer=RandomNormal(mean=0.0,
                                                 stddev=np.sqrt(2)))(net)

    # Residual connection: add the PET image to the prediction.
    net = Add(name='add_0')([net, pet_image])

    # Optionally force a non-negative output.
    if add_final_relu:
        net = ReLU(name='final_relu')(net)

    return Model(inputs=ipt, outputs=net)