def assemble_model_multi_slice(ms_ndim_out=3, **model_kwargs):
    """
    Build a multi-slice model around a shared 2D base model.

    A 2D base model (from `assemble_base_model` with `ndim=2`) is applied to
    each of the first three slices along axis 2 of a channels-first 3D input.
    All slices share one model instance, so weights are shared.
    NOTE: batch norm statistics are shared too!

    The two per-slice outputs of the base model (lesion features, liver
    features) are each stacked back along the slice axis. With
    `ms_ndim_out == 2`, the slice axis is then folded into channels.
    Next, one `kernel_size=3` convolution plus nonlinearity mixes
    information across slices. A sigmoid classifier is added on top,
    unless `model_kwargs['num_classes']` is None.

    ms_ndim_out : dimensionality (2 or 3) of the cross-slice convolutions
        and classifiers.
    model_kwargs : keyword arguments for the base model. Must contain
        'ndim' (== 3), 'input_shape', 'input_num_filters', 'nonlinearity',
        'weight_norm', 'weight_decay' and 'num_classes'.

    Returns a keras Model with one input and outputs
    [lesion_output ('output_0'), liver_output ('output_1')].
    """
    assert (model_kwargs['ndim'] == 3)

    # 2D base model: same kwargs, but the slice axis (index 1) is removed
    # from the input shape.
    input_shape = model_kwargs['input_shape']
    shape_2d = (input_shape[0], ) + input_shape[2:]
    kwargs_2d = copy.copy(model_kwargs)
    kwargs_2d['ndim'] = 2
    kwargs_2d['input_shape'] = shape_2d
    base_model = assemble_base_model(**kwargs_2d)

    num_filters = model_kwargs['input_num_filters']
    slice_axis = 2
    # Assumes base model outputs have `input_num_filters` channels and the
    # same spatial size as the input -- TODO confirm against the base model.
    stacked_shape = (num_filters, 1, ) + shape_2d[1:]

    def take_slice(index):
        # Pick slice `index` along axis 2 of the 5D input tensor.
        return Lambda(lambda t: t[:, :, index, :, :], output_shape=shape_2d)

    def add_slice_axis():
        # Re-insert a singleton slice axis so slices can be concatenated.
        return Lambda(lambda t: K.expand_dims(t, axis=slice_axis),
                      output_shape=stacked_shape)

    # Run the shared base model on each slice.
    model_input = Input(input_shape)
    lesion_slices = []
    liver_slices = []
    for index in range(3):
        lesion_feat, liver_feat = base_model(take_slice(index)(model_input))
        lesion_slices.append(add_slice_axis()(lesion_feat))
        liver_slices.append(add_slice_axis()(liver_feat))
    lesion_pre = merge_concatenate(lesion_slices, axis=slice_axis)
    liver_pre = merge_concatenate(liver_slices, axis=slice_axis)

    if ms_ndim_out == 2:
        # Fold the 3 slices into the channel axis.
        flat_shape = (num_filters * 3, ) + shape_2d[1:]
        lesion_pre = Reshape(flat_shape)(lesion_pre)
        liver_pre = Reshape(flat_shape)(liver_pre)

    nonlinearity = model_kwargs['nonlinearity']

    def mix_slices(tensor, name):
        # Convolution + nonlinearity combining information across slices.
        conv = Convolution(filters=num_filters,
                           kernel_size=3,
                           ndim=ms_ndim_out,
                           padding='same',
                           weight_norm=model_kwargs['weight_norm'],
                           kernel_regularizer=_l2(
                               model_kwargs['weight_decay']),
                           name=name)
        mixed = conv(tensor)
        return get_nonlinearity(nonlinearity)(mixed)

    lesion_pre = mix_slices(lesion_pre, 'conv_3D_0')
    liver_pre = mix_slices(liver_pre, 'conv_3D_1')

    # Permutations that move the channel axis last (for the sigmoid) and
    # back to first.
    if ms_ndim_out == 2:
        channels_last, channels_first = (2, 3, 1), (3, 1, 2)
    else:
        channels_last, channels_first = (2, 3, 4, 1), (4, 1, 2, 3)

    def classify(features, conv_name, sigmoid_name, output_name):
        # 1x1 linear conv -> sigmoid -> named output; a plain named
        # linear passthrough if there is no classifier.
        if model_kwargs['num_classes'] is None:
            return Activation('linear', name=output_name)(features)
        classifier = Convolution(filters=1,
                                 kernel_size=1,
                                 ndim=ms_ndim_out,
                                 activation='linear',
                                 kernel_regularizer=_l2(
                                     model_kwargs['weight_decay']),
                                 name=conv_name)
        scores = classifier(features)
        scores = Permute(channels_last)(scores)
        scores = Activation('sigmoid', name=sigmoid_name)(scores)
        restore = Permute(channels_first)
        restore.name = output_name
        return restore(scores)

    lesion_output = classify(lesion_pre,
                             'classifier_conv_0', 'sigmoid_0', 'output_0')
    liver_output = classify(liver_pre,
                            'classifier_conv_1', 'sigmoid_1', 'output_1')

    return Model(inputs=model_input, outputs=[lesion_output, liver_output])
def assemble_model(input_shape, num_classes, num_init_blocks, num_main_blocks,
                   main_block_depth, input_num_filters, num_cycles=1,
                   preprocessor_network=None, postprocessor_network=None,
                   mainblock=None, initblock=None, nonlinearity='relu',
                   dropout=0., normalization=BatchNormalization,
                   weight_norm=False, weight_decay=None, norm_kwargs=None,
                   init='he_normal', ndim=2, cycles_share_weights=True,
                   num_residuals=1, num_first_conv=1, num_final_conv=1,
                   num_classifier=1, num_outputs=1, use_first_conv=True,
                   use_final_conv=True):
    """
    Assemble a (possibly cycled) down/up segmentation network with long
    skip connections between matching depths.

    input_shape : tuple specifying the input shape, channels first
        (number of channels followed by `ndim` spatial dimensions).
    num_classes : number of classes in the segmentation output. If None,
        no classifier is built and the pre-classifier features are returned.
    num_init_blocks : the number of blocks of type initblock, above
        mainblocks. These blocks always have the same number of channels as
        the first convolutional layer in the model.
    num_main_blocks : the number of blocks of type mainblock, below
        initblocks. These blocks double (halve) in number of channels at each
        downsampling (upsampling).
    main_block_depth : an integer or list of integers specifying the number
        of repetitions of each mainblock. A list must contain as many values
        as there are main_blocks in the downward (or upward -- it's mirrored)
        path plus one for the across path.
    input_num_filters : the number channels in the first (last)
        convolutional layer in the model (and of each initblock).
    num_cycles : number of times to cycle the down/up processing pair.
    preprocessor_network : a neural network for preprocessing the input data.
    postprocessor_network : a neural network for postprocessing the data fed
        to the classifier.
    mainblock : a layer defining the mainblock (basic_block by default).
    initblock : a layer defining the initblock (basic_block_mp by default).
    nonlinearity : string or function specifying/defining the nonlinearity.
    dropout : the dropout probability, introduced in every block.
    normalization : the normalization to apply to layers (by default: batch
        normalization). If None, no normalization is applied and summed
        merges are divided by two instead.
    weight_norm : boolean, whether to use weight norm on conv layers.
    weight_decay : the weight decay (L2 penalty) used in every convolution.
    norm_kwargs : keyword arguments to pass to batch norm layers.
    init : string or function specifying the initializer for layers.
    ndim : the spatial dimensionality of the input and output (2 or 3)
    cycles_share_weights : share network weights across cycles.
    num_residuals : the number of parallel residual functions per block.
    num_first_conv : the number of parallel first convolutions.
    num_final_conv : the number of parallel final convolutions (+BN).
    num_classifier : the number of parallel linear classifiers.
    num_outputs : the number of model outputs, each with num_classifier
        classifiers.
    use_first_conv : if False, the first convolution is skipped.
    use_final_conv : if False, the final convolution is skipped.

    Returns a keras Model whose input is `Input(shape=input_shape)`. If
    `num_classes` is not None, it has `num_outputs` outputs named
    'output_0', 'output_1', ... (sigmoid if num_classes == 1, else softmax,
    over the channel axis). Otherwise it has a single linear output named
    'output_0'.

    Raises ValueError if main_block_depth (or any entry of it) is zero, or
    if a list main_block_depth does not have num_main_blocks+1 entries.
    """
    # By default, use depth 2 basic_block for mainblock.
    if mainblock is None:
        mainblock = basic_block
    if initblock is None:
        initblock = basic_block_mp
    # Normalize before building `block_kwargs` so the residual blocks and
    # the final convolution both receive the same (dict) norm_kwargs.
    if norm_kwargs is None:
        norm_kwargs = {}

    # main_block_depth can be a list per block or a single value -- ensure
    # the list length is correct (if list) and that no length is 0.
    if not hasattr(main_block_depth, '__len__'):
        if main_block_depth == 0:
            raise ValueError("main_block_depth must never be zero")
    else:
        if len(main_block_depth) != num_main_blocks + 1:
            raise ValueError("main_block_depth must have "
                             "`num_main_blocks+1` values when "
                             "passed as a list")
        for d in main_block_depth:
            if d == 0:
                raise ValueError("main_block_depth must never be zero")

    def get_repetitions(level):
        # Depth of a mainblock for a given pooling level.
        if hasattr(main_block_depth, '__len__'):
            return main_block_depth[level]
        return main_block_depth

    def merge_into(x, into, skips, cycle, direction, depth):
        # Merge `x` into `into` by summation. If the channel counts (axis 1)
        # differ, `x` is first projected by a 1x1 convolution. With weight
        # sharing, that convolution is reused from the previous cycle.
        if x._keras_shape[1] != into._keras_shape[1]:
            if cycles_share_weights and depth in skips[cycle - 1][direction]:
                conv_layer = skips[cycle - 1][direction][depth]
            else:
                name = _unique('long_skip_' + str(direction) + '_'
                               + str(depth))
                conv_layer = Convolution(filters=into._keras_shape[1],
                                         kernel_size=1,
                                         ndim=ndim,
                                         weight_norm=weight_norm,
                                         kernel_initializer=init,
                                         padding='valid',
                                         kernel_regularizer=_l2(weight_decay),
                                         name=name)
            skips[cycle][direction][depth] = conv_layer
            x = conv_layer(x)
        out = merge_add([x, into])
        if normalization is None:
            # Divide sum by two.
            out = Lambda(lambda x: x / 2., output_shape=lambda x: x)(out)
        return out

    def make_block(block_func, x):
        # Wrap `block_func` in a reusable Model (for weight sharing). The
        # Model accepts x's channel count and any spatial size.
        x_filters = x._keras_shape[1]
        block_input = Input(shape=(x_filters, ) + tuple([None] * ndim))
        model = Model(block_input, block_func(block_input))
        return model

    # Constant kwargs passed to the init and main blocks.
    block_kwargs = {
        'skip': True,
        'dropout': dropout,
        'weight_norm': weight_norm,
        'weight_decay': weight_decay,
        'num_residuals': num_residuals,
        'norm_kwargs': norm_kwargs,
        'nonlinearity': nonlinearity,
        'init': init,
        'ndim': ndim
    }

    # INPUT -- keep the Input tensor separate from the (possibly
    # preprocessed) feature tensor. Model() must receive the Input tensor.
    model_input = Input(shape=input_shape)
    x = model_input

    # Preprocessing
    if preprocessor_network is not None:
        x = preprocessor_network(x)

    # Build the blocks for all cycles, contracting and expanding in each
    # cycle.
    tensors = []  # feature tensors
    blocks = []   # residual block layers
    skips = []    # 1x1 kernel convolution layers on long skip connections
    for cycle in range(num_cycles):
        # Create tensors and layer lists for this cycle.
        tensors.append({'down': {}, 'up': {}, 'across': {}})
        blocks.append({'down': {}, 'up': {}, 'across': {}})
        skips.append({'down': {}, 'up': {}, 'across': {}})

        # First convolution
        if cycle > 0:
            x = merge_into(x, tensors[cycle - 1]['up'][0], skips=skips,
                           cycle=cycle, direction='down', depth=0)
        if cycles_share_weights and cycle > 1:
            block = blocks[cycle - 1]['down'][0]
        else:
            def first_block(x):
                outputs = []
                for i in range(num_first_conv):
                    out = Convolution(filters=input_num_filters,
                                      kernel_size=3,
                                      ndim=ndim,
                                      weight_norm=weight_norm,
                                      kernel_initializer=init,
                                      padding='same',
                                      kernel_regularizer=_l2(weight_decay),
                                      name=_unique('first_conv_' + str(i)))(x)
                    outputs.append(out)
                if len(outputs) > 1:
                    out = merge_add(outputs)
                    if normalization is None:
                        # Divide sum by two.
                        out = Lambda(lambda x: x / 2.,
                                     output_shape=lambda x: x)(out)
                else:
                    out = outputs[0]
                return out
            block = make_block(first_block, x)
        if use_first_conv:
            x = block(x)
            blocks[cycle]['down'][0] = block
        else:
            blocks[cycle]['down'][0] = lambda x: x
        tensors[cycle]['down'][0] = x
        print("Cycle {} - FIRST DOWN: {}".format(cycle, x._keras_shape))

        # DOWN (initial subsampling blocks)
        for b in range(num_init_blocks):
            depth = b + 1
            if cycle > 0:
                x = merge_into(x, tensors[cycle - 1]['up'][depth],
                               skips=skips, cycle=cycle, direction='down',
                               depth=depth)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['down'][depth]
            else:
                block_func = residual_block(initblock,
                                            filters=input_num_filters,
                                            repetitions=1,
                                            subsample=True,
                                            upsample=False,
                                            normalization=normalization,
                                            name='d' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['down'][depth] = block
            tensors[cycle]['down'][depth] = x
            print("Cycle {} - INIT DOWN {}: {}".format(cycle, b,
                                                       x._keras_shape))

        # DOWN (resnet blocks)
        for b in range(num_main_blocks):
            depth = b + 1 + num_init_blocks
            if cycle > 0:
                x = merge_into(x, tensors[cycle - 1]['up'][depth],
                               skips=skips, cycle=cycle, direction='down',
                               depth=depth)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['down'][depth]
            else:
                block_func = residual_block(mainblock,
                                            filters=input_num_filters * (2**b),
                                            repetitions=get_repetitions(b),
                                            subsample=True,
                                            upsample=False,
                                            normalization=normalization,
                                            name='d' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['down'][depth] = block
            tensors[cycle]['down'][depth] = x
            print("Cycle {} - MAIN DOWN {}: {}".format(cycle, b,
                                                       x._keras_shape))

        # ACROSS
        if num_main_blocks:
            if cycle > 0:
                x = merge_into(x, tensors[cycle - 1]['across'][0],
                               skips=skips, cycle=cycle, direction='across',
                               depth=0)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['across'][0]
            else:
                # NOTE: `b` is left over from the MAIN DOWN loop
                # (num_main_blocks - 1), so the across block uses the same
                # filter count as the deepest down block.
                block_func = residual_block(
                    mainblock,
                    filters=input_num_filters * (2**b),
                    repetitions=get_repetitions(num_main_blocks),
                    subsample=True,
                    upsample=True,
                    normalization=normalization,
                    name='a',
                    **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['across'][0] = block
            tensors[cycle]['across'][0] = x
            print("Cycle {} - ACROSS: {}".format(cycle, x._keras_shape))

        # UP (resnet blocks)
        for b in range(num_main_blocks - 1, -1, -1):
            depth = b + 1 + num_init_blocks
            x = merge_into(x, tensors[cycle]['down'][depth], skips=skips,
                           cycle=cycle, direction='up', depth=depth)
            if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
                block = blocks[cycle - 1]['up'][depth]
            else:
                block_func = residual_block(mainblock,
                                            filters=input_num_filters * (2**b),
                                            repetitions=get_repetitions(b),
                                            subsample=False,
                                            upsample=True,
                                            normalization=normalization,
                                            name='u' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['up'][depth] = block
            tensors[cycle]['up'][depth] = x
            print("Cycle {} - MAIN UP {}: {}".format(cycle, b,
                                                     x._keras_shape))

        # UP (final upsampling blocks)
        for b in range(num_init_blocks - 1, -1, -1):
            depth = b + 1
            x = merge_into(x, tensors[cycle]['down'][depth], skips=skips,
                           cycle=cycle, direction='up', depth=depth)
            if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
                block = blocks[cycle - 1]['up'][depth]
            else:
                block_func = residual_block(initblock,
                                            filters=input_num_filters,
                                            repetitions=1,
                                            subsample=False,
                                            upsample=True,
                                            normalization=normalization,
                                            name='u' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['up'][depth] = block
            tensors[cycle]['up'][depth] = x
            print("Cycle {} - INIT UP {}: {}".format(cycle, b,
                                                     x._keras_shape))

        # Final convolution.
        x = merge_into(x, tensors[cycle]['down'][0], skips=skips,
                       cycle=cycle, direction='up', depth=0)
        if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
            block = blocks[cycle - 1]['up'][0]
        else:
            def final_block(x):
                outputs = []
                for i in range(num_final_conv):
                    out = Convolution(filters=input_num_filters,
                                      kernel_size=3,
                                      ndim=ndim,
                                      weight_norm=weight_norm,
                                      kernel_initializer=init,
                                      padding='same',
                                      kernel_regularizer=_l2(weight_decay),
                                      name=_unique('final_conv_' + str(i)))(x)
                    if normalization is not None:
                        out = normalization(
                            name=_unique('final_norm_' + str(i)),
                            **norm_kwargs)(out)
                    out = get_nonlinearity(nonlinearity)(out)
                    outputs.append(out)
                if len(outputs) > 1:
                    out = merge_add(outputs)
                else:
                    out = outputs[0]
                return out
            block = make_block(final_block, x)
        if use_final_conv:
            x = block(x)
            blocks[cycle]['up'][0] = block
        else:
            blocks[cycle]['up'][0] = lambda x: x
        tensors[cycle]['up'][0] = x
        if cycle > 0:
            # Merge preclassifier outputs across all cycles.
            x = merge_into(x, tensors[cycle - 1]['up'][0], skips=skips,
                           cycle=cycle, direction='up', depth=-1)
        print("Cycle {} - FIRST UP: {}".format(cycle, x._keras_shape))

    # Postprocessing
    if postprocessor_network is not None:
        x = postprocessor_network(x)

    # OUTPUTs (SIGMOID)
    all_outputs = []
    if num_classes is not None:
        for i in range(num_outputs):
            # Linear classifier
            classifiers = []
            for j in range(num_classifier):
                name = 'classifier_conv_' + str(j)
                if i > 0:
                    # backwards compatibility
                    name += '_out' + str(i)
                output = Convolution(filters=num_classes,
                                     kernel_size=1,
                                     ndim=ndim,
                                     activation='linear',
                                     kernel_regularizer=_l2(weight_decay),
                                     name=_unique(name))(x)
                classifiers.append(output)
            if len(classifiers) > 1:
                output = merge_add(classifiers)
            else:
                output = classifiers[0]
            # Move channels last so the activation acts over classes, then
            # move them back.
            if ndim == 2:
                output = Permute((2, 3, 1))(output)
            else:
                output = Permute((2, 3, 4, 1))(output)
            if num_classes == 1:
                output = Activation('sigmoid', name='sigmoid' + str(i))(output)
            else:
                output = Activation(_softmax, name='softmax' + str(i))(output)
            if ndim == 2:
                output_layer = Permute((3, 1, 2))
            else:
                output_layer = Permute((4, 1, 2, 3))
            output_layer.name = 'output_' + str(i)
            output = output_layer(output)
            all_outputs.append(output)
    else:
        # No classifier
        all_outputs = Activation('linear', name='output_0')(x)

    # MODEL
    model = Model(inputs=model_input, outputs=all_outputs)
    return model
def assemble_model_two_levels(adversarial=False, num_residuals_bottom=None,
                              discriminator_kwargs=None, **model_kwargs):
    """
    Assemble a two-stage model.

    A first cycled model (liver), built without a classifier, produces
    features. These are concatenated with the raw input along the channel
    axis (axis 1) and fed to a second cycled model (lesion). A 1x1 sigmoid
    classifier on the first model's features produces the liver output.

    adversarial : if True, also build two discriminators with
        `assemble_cnn` and expose their outputs, for real and generated
        segmentations, as extra model inputs/outputs.
    num_residuals_bottom : if not None, overrides `num_residuals` for the
        first (liver) model.
    discriminator_kwargs : keyword arguments passed to `assemble_cnn` for
        each discriminator.
    model_kwargs : keyword arguments for `assemble_cycled_model`. Must
        contain 'input_shape' (channels first), 'num_outputs' (== 2),
        'num_classes', 'ndim' and 'weight_decay'.

    Returns a keras Model. Non-adversarial: input `model_input`, outputs
    [lesion_output, liver_output]. Adversarial: inputs [model_input,
    input_disc_0_seg, input_disc_1_seg], and the two segmentation outputs
    followed by six named discriminator outputs.
    """
    assert (model_kwargs['num_outputs'] == 2)
    if discriminator_kwargs is None:
        discriminator_kwargs = {}
    input_shape = model_kwargs['input_shape']
    model_input = Input(shape=input_shape, name='model_input')

    # Assemble first model (liver)
    model_liver_kwargs = copy.copy(model_kwargs)
    model_liver_kwargs['num_classes'] = None
    if num_residuals_bottom is not None:
        model_liver_kwargs['num_residuals'] = num_residuals_bottom
    model_liver = assemble_cycled_model(**model_liver_kwargs)
    liver_output_pre = model_liver(model_input)

    # Assemble second model on top (lesion). Its input is the channel-axis
    # concatenation of the raw input and the liver features. Its channel
    # count is therefore input_shape[0] + liver channels. (input_shape[-3]
    # was only equal to that for ndim=2; for ndim=3 it is the depth axis.)
    model_lesion_kwargs = copy.copy(model_kwargs)
    model_lesion_kwargs['num_outputs'] = 1
    model_lesion_kwargs['input_shape'] = (
        (liver_output_pre._keras_shape[1] + input_shape[0], )
        + input_shape[1:])
    model_lesion = assemble_cycled_model(**model_lesion_kwargs)

    # Connect first model to second
    lesion_input = merge_concatenate([model_input, liver_output_pre], axis=1)

    # Create classifier for liver
    if model_kwargs['num_classes'] is not None:
        liver_output = Convolution(filters=1,
                                   kernel_size=1,
                                   ndim=model_kwargs['ndim'],
                                   activation='linear',
                                   kernel_regularizer=_l2(
                                       model_kwargs['weight_decay']),
                                   name='classifier_conv_1')
        liver_output = liver_output(liver_output_pre)
        # Channels last for the sigmoid, then back to channels first.
        if model_kwargs['ndim'] == 2:
            liver_output = Permute((2, 3, 1))(liver_output)
        else:
            liver_output = Permute((2, 3, 4, 1))(liver_output)
        liver_output = Activation('sigmoid', name='sigmoid_1')(liver_output)
        if model_kwargs['ndim'] == 2:
            liver_output_layer = Permute((3, 1, 2))
        else:
            liver_output_layer = Permute((4, 1, 2, 3))
        liver_output_layer.name = 'output_1'
        liver_output = liver_output_layer(liver_output)
    else:
        liver_output = Activation('linear', name='output_1')(liver_output_pre)

    # Lesion classifier output
    model_lesion.name = 'output_0'
    lesion_output = model_lesion(lesion_input)

    # Create discriminators
    if adversarial:
        def make_trainable(model, trainable=True):
            # Recursively set `trainable` on all non-Model layers.
            for layer in model.layers:
                if isinstance(layer, Model):
                    make_trainable(layer, trainable)
                else:
                    layer.trainable = trainable

        # Assemble discriminators.
        disc_0 = assemble_cnn(**discriminator_kwargs)
        disc_1 = assemble_cnn(**discriminator_kwargs)

        # Create discriminator outputs for real data.
        # NOTE(review): the segmentation inputs use the full `input_shape`.
        # Confirm this matches the channel count of the generated outputs
        # concatenated below.
        input_disc_0_seg = Input(input_shape, name='input_disc_0_seg')
        input_disc_1_seg = Input(input_shape, name='input_disc_1_seg')
        input_disc_0 = merge_concatenate([input_disc_0_seg, model_input],
                                         axis=1)
        input_disc_1 = merge_concatenate([input_disc_1_seg, model_input],
                                         axis=1)
        out_disc_0 = disc_0(input_disc_0)
        out_disc_1 = disc_1(input_disc_1)

        # Create untrainable segmentation generator output.
        model_gen = Model(inputs=model_input,
                          outputs=[lesion_output, liver_output])
        make_trainable(model_gen, False)
        outputs_gen = model_gen(model_input)

        # Create discriminator outputs for training the discriminators.
        input_disc_0 = merge_concatenate([outputs_gen[0], model_input],
                                         axis=1)
        input_disc_1 = merge_concatenate([outputs_gen[1], model_input],
                                         axis=1)
        out_adv_0_d = disc_0(input_disc_0)
        out_adv_1_d = disc_1(input_disc_1)

        # Make discriminators untrainable, generator trainable.
        make_trainable(model_gen, True)
        make_trainable(disc_0, False)
        make_trainable(disc_1, False)

        # Create discriminator outputs for training the generator.
        outputs_gen = model_gen(model_input)
        input_disc_0 = merge_concatenate([outputs_gen[0], model_input],
                                         axis=1)
        input_disc_1 = merge_concatenate([outputs_gen[1], model_input],
                                         axis=1)
        out_adv_0_g = disc_0(input_disc_0)
        out_adv_1_g = disc_1(input_disc_1)

        # Name the outputs.
        def name_layer(tensor, name):
            return Activation('linear', name=name)(tensor)
        out_adv_0_d = name_layer(out_adv_0_d, 'out_adv_0_d')
        out_adv_1_d = name_layer(out_adv_1_d, 'out_adv_1_d')
        out_adv_0_g = name_layer(out_adv_0_g, 'out_adv_0_g')
        out_adv_1_g = name_layer(out_adv_1_g, 'out_adv_1_g')
        out_disc_0 = name_layer(out_disc_0, 'out_disc_0')
        out_disc_1 = name_layer(out_disc_1, 'out_disc_1')

    # Create aggregate model
    if adversarial:
        model = Model(
            inputs=[model_input, input_disc_0_seg, input_disc_1_seg],
            outputs=[lesion_output, liver_output,
                     out_adv_0_d, out_adv_1_d,
                     out_adv_0_g, out_adv_1_g,
                     out_disc_0, out_disc_1])
    else:
        model = Model(inputs=model_input,
                      outputs=[lesion_output, liver_output])
    return model