def skip_merge(skip_layers, upsampled_layers, skip_merge_type,
               data_format, num_dims, padding):
    """Skip connection concatenate/add to upsampled layer

    :param keras.layer skip_layers: as named
    :param keras.layer upsampled_layers: as named
    :param str skip_merge_type: [add, concat]
    :param str data_format: [channels_first, channels_last]
    :param int num_dims: as named
    :param str padding: same or valid
    :return: keras.layer skip merged layer
    """
    channel_axis = get_channel_axis(data_format)
    # crop input if padding='valid'
    if padding == 'valid':
        skip_layers = Lambda(
            _crop_layer,
            arguments={'final_layer': upsampled_layers,
                       'data_format': data_format,
                       'num_dims': num_dims})(skip_layers)
    if skip_merge_type == 'concat':
        layer = Concatenate(axis=channel_axis)([upsampled_layers,
                                                skip_layers])
    else:
        skip_layers = Lambda(
            pad_channels,
            arguments={'final_layer': upsampled_layers,
                       'channel_axis': channel_axis})(skip_layers)
        layer = Add()([upsampled_layers, skip_layers])
    return layer

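
# A minimal numpy sketch (not part of the library) contrasting the two merge
# modes in skip_merge, assuming channels_first data with matching spatial
# shapes: 'concat' stacks feature maps along the channel axis, while 'add'
# needs equal channel counts (hence the pad_channels step) and sums
# element-wise. The exact before/after split pad_channels uses may differ;
# a symmetric split is assumed here.
def _skip_merge_demo():
    import numpy as np
    upsampled = np.ones((1, 4, 8, 8))   # batch, channels, y, x
    skip = np.full((1, 2, 8, 8), 2.0)
    # concat: channel dimension grows from 4 to 6
    concat = np.concatenate([upsampled, skip], axis=1)
    assert concat.shape == (1, 6, 8, 8)
    # add: zero-pad skip to 4 channels first, then sum
    skip_padded = np.pad(skip, ((0, 0), (1, 1), (0, 0), (0, 0)))
    added = upsampled + skip_padded
    assert added.shape == (1, 4, 8, 8)
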
def kl_divergence_loss(y_true, y_pred):
    """KL divergence loss
    D(y||y') = sum(p(y) * log(p(y) / p(y')))

    :param y_true: Ground truth
    :param y_pred: Prediction
    :return float: KL divergence loss
    """
    y_true = K.clip(y_true, K.epsilon(), 1)
    y_pred = K.clip(y_pred, K.epsilon(), 1)
    channel_axis = get_channel_axis(K.image_data_format())
    return K.sum(y_true * K.log(y_true / y_pred), axis=channel_axis)

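
# A minimal numpy sketch (not part of the library) of the formula above;
# `eps` stands in for K.epsilon(). KL divergence is non-negative and zero
# only when the two distributions agree.
def _kl_divergence_demo():
    import numpy as np
    eps = 1e-7
    y_true = np.clip(np.array([0.7, 0.2, 0.1]), eps, 1)
    y_pred = np.clip(np.array([0.6, 0.3, 0.1]), eps, 1)
    kl = np.sum(y_true * np.log(y_true / y_pred))
    print(kl)  # small positive value; 0 only when y_true == y_pred
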
def mse_loss(y_true, y_pred, mean_loss=True):
    """Mean squared error loss

    :param y_true: Ground truth
    :param y_pred: Prediction
    :param bool mean_loss: if False, return the per-element squared error
        without reduction
    :return float: Mean squared error loss
    """
    if not mean_loss:
        return K.square(y_pred - y_true)
    channel_axis = get_channel_axis(K.image_data_format())
    return K.mean(K.square(y_pred - y_true), axis=channel_axis)

def mae_loss(y_true, y_pred, mean_loss=True):
    """Mean absolute error loss

    Keras losses by default calculate metrics along axis=-1, which works
    with image_format='channels_last'. The arrays do not seem to be batch
    flattened, so change the axis if using 'channels_first'.

    :param y_true: Ground truth
    :param y_pred: Prediction
    :param bool mean_loss: if False, return the per-element absolute error
        without reduction
    :return float: Mean absolute error loss
    """
    if not mean_loss:
        return K.abs(y_pred - y_true)
    channel_axis = get_channel_axis(K.image_data_format())
    return K.mean(K.abs(y_pred - y_true), axis=channel_axis)

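
# A minimal numpy sketch (not part of the library) of why the reduction axis
# matters for mae_loss and mse_loss above: with channels_first data the
# channel axis is 1, so reducing along Keras' default axis=-1 would average
# over x instead of over channels.
def _reduction_axis_demo():
    import numpy as np
    y_true = np.zeros((1, 3, 4, 4))    # batch, channels, y, x
    y_pred = np.ones((1, 3, 4, 4))
    abs_err = np.abs(y_pred - y_true)
    per_channel = np.mean(abs_err, axis=1)    # reduces channels
    per_column = np.mean(abs_err, axis=-1)    # reduces x
    assert per_channel.shape == (1, 4, 4)
    assert per_column.shape == (1, 3, 4)
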
def binary_crossentropy_loss(y_true, y_pred, mean_loss=True):
    """Binary cross entropy loss

    :param y_true: Ground truth
    :param y_pred: Prediction
    :param bool mean_loss: if False, return the per-element cross entropy
        without reduction
    :return float: Binary cross entropy loss
    """
    # these checks require concrete (numpy-compatible) inputs: both arrays
    # must contain at most two distinct values
    assert len(np.unique(y_true).tolist()) <= 2
    assert len(np.unique(y_pred).tolist()) <= 2
    if not mean_loss:
        return K.binary_crossentropy(y_true, y_pred)
    channel_axis = get_channel_axis(K.image_data_format())
    return K.mean(K.binary_crossentropy(y_true, y_pred), axis=channel_axis)

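
# A minimal numpy sketch (not part of the library) of the underlying binary
# cross entropy formula, -y*log(p) - (1-y)*log(1-p); `eps` stands in for the
# clipping Keras applies internally to avoid log(0).
def _binary_crossentropy_demo():
    import numpy as np
    eps = 1e-7
    y_true = np.array([0.0, 1.0, 1.0, 0.0])
    y_pred = np.clip(np.array([0.1, 0.9, 0.8, 0.3]), eps, 1 - eps)
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    print(bce.mean())  # ~0.2 for these values
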
def _crop_layer(input_layer, final_layer, data_format, num_dims):
    """Crop input layer to match shape of final layer

    ONLY SYMMETRIC CROPPING IS HANDLED HERE!

    :param keras.layers input_layer: input_layer to the block
    :param keras.layers final_layer: last layer of conv block or skip
        layers in Unet
    :param str data_format: [channels_first, channels_last]
    :param int num_dims: as named
    :return: keras.layer, input layer cropped if shape is different than
        final layer, else input layer as is
    """
    input_shape = tf.shape(input_layer)
    final_shape = tf.shape(final_layer)
    # offsets for the top left corner of the crop
    if data_format == 'channels_first':
        offsets = [0, 0,
                   (input_shape[2] - final_shape[2]) // 2,
                   (input_shape[3] - final_shape[3]) // 2]
        crop_shape = [-1, input_shape[1], final_shape[2], final_shape[3]]
        if num_dims == 3:
            offsets.append((input_shape[4] - final_shape[4]) // 2)
            crop_shape.append(final_shape[4])
    else:
        offsets = [0,
                   (input_shape[1] - final_shape[1]) // 2,
                   (input_shape[2] - final_shape[2]) // 2]
        crop_shape = [-1, final_shape[1], final_shape[2]]
        if num_dims == 3:
            offsets.append((input_shape[3] - final_shape[3]) // 2)
            crop_shape.append(final_shape[3])
        offsets.append(0)
        crop_shape.append(input_shape[-1])
    # https://github.com/tensorflow/tensorflow/issues/19376
    input_cropped = tf.slice(input_layer, offsets, crop_shape)
    op_shape = final_layer.get_shape().as_list()
    channel_axis = get_channel_axis(data_format)
    op_shape[channel_axis] = input_layer.get_shape().as_list()[channel_axis]
    input_cropped.set_shape(tuple(op_shape))
    return input_cropped

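
# A minimal numpy sketch (not part of the library) of the symmetric crop
# above, assuming channels_last 2D data: the top-left offset along each
# spatial axis is (input - target) // 2, and target elements are kept.
def _symmetric_crop_demo():
    import numpy as np
    image = np.arange(100).reshape(1, 10, 10, 1)  # batch, y, x, channels
    target_y, target_x = 6, 6
    off_y = (image.shape[1] - target_y) // 2      # 2
    off_x = (image.shape[2] - target_x) // 2      # 2
    cropped = image[:, off_y:off_y + target_y, off_x:off_x + target_x, :]
    assert cropped.shape == (1, 6, 6, 1)
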
def test_pad_channels(self):
    """Test pad_channels()

    zero-pads the layer along the channel dimension when padding=same,
    zero-pads + crops when padding=valid
    """
    for idx, in_shape in enumerate([self.in_shape_2d, self.in_shape_3d]):
        # create a model that gives padded layer as output
        self.network_config['num_dims'] += idx
        in_layer = k_layers.Input(shape=in_shape, dtype='float32')
        conv_layer = get_keras_layer('conv',
                                     self.network_config['num_dims'])
        out_layer = conv_layer(
            filters=self.network_config['num_filters_per_block'][0],
            kernel_size=self.network_config['filter_size'],
            padding='same',
            data_format=self.network_config['data_format'])(in_layer)
        channel_axis = get_channel_axis(
            self.network_config['data_format'])
        layer_padded = k_layers.Lambda(
            conv_blocks.pad_channels,
            arguments={'final_layer': out_layer,
                       'channel_axis': channel_axis})(in_layer)
        # layer padded has zeros in all channels except 8
        model = Model(in_layer, layer_padded)
        test_shape = list(in_shape)
        test_shape.insert(0, 1)
        test_image = np.ones(shape=test_shape)
        # forward pass
        out = model.predict(test_image, batch_size=1)
        # test shape: should be the same as conv_layer
        out_shape = list(in_shape)
        out_shape[0] = self.network_config['num_filters_per_block'][0]
        np.testing.assert_array_equal(
            out_layer.get_shape().as_list()[1:], out_shape)
        out = np.squeeze(out)
        # only slice 8 is not zero
        nose.tools.assert_equal(np.sum(out), np.sum(out[8]))
        np.testing.assert_array_equal(out[8], np.squeeze(test_image))
        nose.tools.assert_equal(np.sum(out[8]), np.prod(in_shape))

def downsample_conv_block(layer,
                          network_config,
                          block_idx,
                          downsample_shape=None):
    """Conv-BN-activation block with strided downsampling

    :param keras.layers layer: current input layer
    :param dict network_config: please check conv_block()
    :param int block_idx: block index in the network
    :param tuple downsample_shape: anisotropic downsampling kernel shape
    :return: keras.layers after downsampling and conv_block
    """
    conv = get_keras_layer(type='conv', num_dims=network_config['num_dims'])
    block_sequence = network_config['block_sequence'].split('-')
    for conv_idx in range(network_config['num_convs_per_block']):
        for cur_layer_type in block_sequence:
            if cur_layer_type == 'conv':
                # downsample with stride 2 (or downsample_shape) in the
                # first conv of every block except the first
                if block_idx > 0 and conv_idx == 0:
                    if downsample_shape is None:
                        stride = (2, ) * network_config['num_dims']
                    else:
                        stride = downsample_shape
                else:
                    stride = (1, ) * network_config['num_dims']
                layer = conv(
                    filters=network_config['num_filters_per_block'][block_idx],
                    kernel_size=network_config['filter_size'],
                    strides=stride,
                    padding=network_config['padding'],
                    kernel_initializer=network_config['init'],
                    data_format=network_config['data_format'])(layer)
            elif cur_layer_type == 'bn' and network_config['batch_norm']:
                layer = BatchNormalization(
                    axis=get_channel_axis(
                        network_config['data_format']))(layer)
            else:
                activation_layer_instance = create_activation_layer(
                    network_config['activation'])
                layer = activation_layer_instance(layer)
        if network_config['dropout']:
            layer = Dropout(network_config['dropout'])(layer)
    return layer

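
# A minimal sketch (not part of the library) of the spatial-size arithmetic
# behind the strided downsampling above: with padding='same' a stride-2 conv
# halves each spatial dimension (rounding up), while padding='valid' also
# loses (kernel_size - 1) border pixels.
def _downsample_shape_demo():
    import math

    def conv_out_size(in_size, kernel_size, stride, padding):
        if padding == 'same':
            return math.ceil(in_size / stride)
        return math.floor((in_size - kernel_size) / stride) + 1

    assert conv_out_size(64, 3, 2, 'same') == 32
    assert conv_out_size(64, 3, 2, 'valid') == 31
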
def _split_ytrue_mask(y_true, n_channels):
    """Split the mask concatenated with y_true

    :param keras.tensor y_true: if channels_first, ytrue has shape
        [batch_size, n_channels, y, x]. mask is concatenated as the
        (n_channels + 1)th channel, shape:
        [batch_size, n_channels + 1, y, x].
    :param int n_channels: number of channels in y_true
    :return: keras.tensor ytrue_split - ytrue with the mask removed
        keras.tensor mask_image - bool mask
    """
    try:
        split_axis = get_channel_axis(K.image_data_format())
        y_true_split, mask_image = tf.split(y_true, [n_channels, 1],
                                            axis=split_axis)
        return y_true_split, mask_image
    except Exception as e:
        print('cannot separate mask and y_true: ' + str(e))

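
# A minimal numpy sketch (not part of the library) of the split above,
# assuming channels_first data: the first n_channels channels carry the
# target and the last channel carries the mask.
def _split_ytrue_mask_demo():
    import numpy as np
    n_channels = 2
    y_and_mask = np.random.rand(1, n_channels + 1, 8, 8)
    y_true = y_and_mask[:, :n_channels]    # (1, 2, 8, 8)
    mask = y_and_mask[:, n_channels:]      # (1, 1, 8, 8)
    assert y_true.shape == (1, 2, 8, 8)
    assert mask.shape == (1, 1, 8, 8)
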
def _merge_residual(final_layer,
                    input_layer,
                    data_format,
                    num_dims,
                    kernel_init,
                    padding):
    """Add residual connection from input to last layer

    :param keras.layers final_layer: last layer
    :param keras.layers input_layer: input_layer
    :param str data_format: [channels_first, channels_last]
    :param int num_dims: as named
    :param str kernel_init: kernel initializer from config
    :param str padding: same or valid
    :return: input_layer 1x1 convolved / padded to match the shape of
        final_layer and added
    """
    channel_axis = get_channel_axis(data_format)
    conv_object = get_keras_layer(type='conv', num_dims=num_dims)
    num_final_layers = int(final_layer.get_shape()[channel_axis])
    num_input_layers = int(input_layer.get_shape()[channel_axis])
    # crop input if padding='valid'
    if padding == 'valid':
        input_layer = Lambda(
            _crop_layer,
            arguments={'final_layer': final_layer,
                       'data_format': data_format,
                       'num_dims': num_dims})(input_layer)
    if num_input_layers > num_final_layers:
        # use 1x1 conv to get to the desired num of feature maps
        input_layer = conv_object(
            filters=num_final_layers,
            kernel_size=(1, ) * num_dims,
            padding='same',
            kernel_initializer=kernel_init,
            data_format=data_format)(input_layer)
    elif num_input_layers < num_final_layers:
        # padding with zeros along channels
        input_layer = Lambda(
            pad_channels,
            arguments={'final_layer': final_layer,
                       'channel_axis': channel_axis})(input_layer)
    layer = Add()([final_layer, input_layer])
    return layer

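
# A minimal sketch (not part of the library) of the channel-matching rule
# above: a 1x1 conv shrinks the channel count when the input has more
# feature maps than the output, zero-padding grows it when it has fewer,
# and matching counts are added as-is.
def _residual_channel_match_demo(num_input_layers, num_final_layers):
    if num_input_layers > num_final_layers:
        return '1x1 conv: {} -> {}'.format(num_input_layers,
                                           num_final_layers)
    if num_input_layers < num_final_layers:
        return 'zero-pad channels: {} -> {}'.format(num_input_layers,
                                                    num_final_layers)
    return 'shapes match, add directly'

assert _residual_channel_match_demo(32, 16) == '1x1 conv: 32 -> 16'
assert _residual_channel_match_demo(16, 32) == 'zero-pad channels: 16 -> 32'
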
def conv_block(layer, network_config, block_idx):
    """Convolution block

    Allowed block-seq: [conv-BN-activation, conv-activation-BN,
    BN-activation-conv]
    To accommodate params of advanced activations, activation is a dict
    with keys 'type' and 'params'.
    For a complete list of keys in network_config, refer to
    BaseConvNet.__init__() in base_conv_net.py

    :param keras.layers layer: current input layer
    :param dict network_config: dict with network related keys
    :param int block_idx: block index in the network
    :return: keras.layers after performing operations in block-sequence
        repeated for num_convs_per_block times
    TODO: data_format from network_config won't work for full 3D models
        in predict if depth is set to None
    """
    conv = get_keras_layer(type='conv', num_dims=network_config['num_dims'])
    block_sequence = network_config['block_sequence'].split('-')
    for _ in range(network_config['num_convs_per_block']):
        for cur_layer_type in block_sequence:
            if cur_layer_type == 'conv':
                layer = conv(
                    filters=network_config['num_filters_per_block'][block_idx],
                    kernel_size=network_config['filter_size'],
                    padding=network_config['padding'],
                    kernel_initializer=network_config['init'],
                    data_format=network_config['data_format'])(layer)
            elif cur_layer_type == 'bn' and network_config['batch_norm']:
                layer = BatchNormalization(
                    axis=get_channel_axis(
                        network_config['data_format']))(layer)
            else:
                activation_layer_instance = create_activation_layer(
                    network_config['activation'])
                layer = activation_layer_instance(layer)
        if network_config['dropout']:
            layer = Dropout(network_config['dropout'])(layer)
    return layer

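
# A minimal sketch (not part of the library) of how the block_sequence
# string above drives layer ordering: each pass over the split sequence
# emits one layer of each type, repeated num_convs_per_block times.
def _block_sequence_demo():
    block_sequence = 'conv-bn-activation'.split('-')
    num_convs_per_block = 2
    ops = [layer_type
           for _ in range(num_convs_per_block)
           for layer_type in block_sequence]
    assert ops == ['conv', 'bn', 'activation',
                   'conv', 'bn', 'activation']
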
def test_get_channel_axis_first():
    channel_axis = aux_utils.get_channel_axis('channels_first')
    nose.tools.assert_equal(channel_axis, 1)