Example #1
def assemble_model_multi_slice(ms_ndim_out=3, **model_kwargs):
    """
    Apply a weight-shared 2D base model to each of three adjacent slices,
    then combine the per-slice features with 2D or 3D convolutions,
    depending on ms_ndim_out.
    """
    assert model_kwargs['ndim'] == 3

    # Assemble base model.
    input_shape = model_kwargs['input_shape']
    # Drop the slice axis: (C, 3, H, W) -> (C, H, W).
    input_shape_2D = (input_shape[0], ) + input_shape[2:]
    model_kwargs_2D = copy.copy(model_kwargs)
    model_kwargs_2D['ndim'] = 2
    model_kwargs_2D['input_shape'] = input_shape_2D
    base_model = assemble_base_model(**model_kwargs_2D)

    # Instantiate parallel models, sharing weights.
    # NOTE: batch norm statistics are shared!
    input_multi_slice = Input(input_shape)
    lesion_output_pre = []
    liver_output_pre = []
    z_axis = 2

    def select(i):
        # Select slice i along the z axis: (B, C, 3, H, W) -> (B, C, H, W).
        return Lambda(lambda x: x[:, :, i, :, :], output_shape=input_shape_2D)

    def expand():
        # Re-insert a singleton z axis: (B, F, H, W) -> (B, F, 1, H, W).
        output_shape = (
            model_kwargs['input_num_filters'],
            1,
        ) + input_shape_2D[1:]
        return Lambda(lambda x: K.expand_dims(x, axis=z_axis),
                      output_shape=output_shape)

    for i in range(3):
        out_0, out_1 = base_model(select(i)(input_multi_slice))
        lesion_output_pre.append(expand()(out_0))
        liver_output_pre.append(expand()(out_1))
    lesion_output_pre = merge_concatenate(lesion_output_pre, axis=z_axis)
    liver_output_pre = merge_concatenate(liver_output_pre, axis=z_axis)
    if ms_ndim_out == 2:
        # Fold the slice axis into the channel axis:
        # (B, F, 3, H, W) -> (B, F*3, H, W).
        flat_shape = (model_kwargs['input_num_filters']*3,)\
                     + input_shape_2D[1:]
        lesion_output_pre = Reshape(flat_shape)(lesion_output_pre)
        liver_output_pre = Reshape(flat_shape)(liver_output_pre)

    # Add convolutions to combine information across slices.
    nonlinearity = model_kwargs['nonlinearity']
    lesion_output_pre = Convolution(
        filters=model_kwargs['input_num_filters'],
        kernel_size=3,
        ndim=ms_ndim_out,
        padding='same',
        weight_norm=model_kwargs['weight_norm'],
        kernel_regularizer=_l2(model_kwargs['weight_decay']),
        name='conv_3D_0')(lesion_output_pre)
    lesion_output_pre = get_nonlinearity(nonlinearity)(lesion_output_pre)
    liver_output_pre = Convolution(
        filters=model_kwargs['input_num_filters'],
        kernel_size=3,
        ndim=ms_ndim_out,
        padding='same',
        weight_norm=model_kwargs['weight_norm'],
        kernel_regularizer=_l2(model_kwargs['weight_decay']),
        name='conv_3D_1')(liver_output_pre)
    liver_output_pre = get_nonlinearity(nonlinearity)(liver_output_pre)

    # Create classifier for lesion.
    if model_kwargs['num_classes'] is not None:
        lesion_output = Convolution(filters=1,
                                    kernel_size=1,
                                    ndim=ms_ndim_out,
                                    activation='linear',
                                    kernel_regularizer=_l2(
                                        model_kwargs['weight_decay']),
                                    name='classifier_conv_0')
        lesion_output = lesion_output(lesion_output_pre)
        # Move channels last for the activation, then permute back to
        # channels-first afterwards.
        if ms_ndim_out == 2:
            lesion_output = Permute((2, 3, 1))(lesion_output)
        else:
            lesion_output = Permute((2, 3, 4, 1))(lesion_output)
        lesion_output = Activation('sigmoid', name='sigmoid_0')(lesion_output)
        if ms_ndim_out == 2:
            lesion_output_layer = Permute((3, 1, 2))
        else:
            lesion_output_layer = Permute((4, 1, 2, 3))
        lesion_output_layer.name = 'output_0'
        lesion_output = lesion_output_layer(lesion_output)
    else:
        lesion_output = Activation('linear',
                                   name='output_0')(lesion_output_pre)

    # Create classifier for liver.
    if model_kwargs['num_classes'] is not None:
        liver_output = Convolution(filters=1,
                                   kernel_size=1,
                                   ndim=ms_ndim_out,
                                   activation='linear',
                                   kernel_regularizer=_l2(
                                       model_kwargs['weight_decay']),
                                   name='classifier_conv_1')
        liver_output = liver_output(liver_output_pre)
        if ms_ndim_out == 2:
            liver_output = Permute((2, 3, 1))(liver_output)
        else:
            liver_output = Permute((2, 3, 4, 1))(liver_output)
        liver_output = Activation('sigmoid', name='sigmoid_1')(liver_output)
        if ms_ndim_out == 2:
            liver_output_layer = Permute((3, 1, 2))
        else:
            liver_output_layer = Permute((4, 1, 2, 3))
        liver_output_layer.name = 'output_1'
        liver_output = liver_output_layer(liver_output)
    else:
        liver_output = Activation('linear', name='output_1')(liver_output_pre)

    # Final model.
    model = Model(inputs=input_multi_slice,
                  outputs=[lesion_output, liver_output])
    return model
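
For context, a minimal usage sketch of this function, assuming channels-first
inputs of shape (channels, 3 slices, height, width) and that the host
codebase's assemble_base_model accepts the kwargs below; all values are
illustrative, not taken from the original codebase:

# Hypothetical usage sketch; values are illustrative only.
model = assemble_model_multi_slice(
    ms_ndim_out=3,                   # combine slices with 3D convolutions
    ndim=3,                          # required by the assert above
    input_shape=(1, 3, 256, 256),    # (channels, slices, height, width)
    input_num_filters=32,
    num_classes=1,
    nonlinearity='relu',
    weight_norm=False,
    weight_decay=0.0001)
model.summary()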
Example #2
def assemble_model(input_shape,
                   num_classes,
                   num_init_blocks,
                   num_main_blocks,
                   main_block_depth,
                   input_num_filters,
                   num_cycles=1,
                   preprocessor_network=None,
                   postprocessor_network=None,
                   mainblock=None,
                   initblock=None,
                   nonlinearity='relu',
                   dropout=0.,
                   normalization=BatchNormalization,
                   weight_norm=False,
                   weight_decay=None,
                   norm_kwargs=None,
                   init='he_normal',
                   ndim=2,
                   cycles_share_weights=True,
                   num_residuals=1,
                   num_first_conv=1,
                   num_final_conv=1,
                   num_classifier=1,
                   num_outputs=1,
                   use_first_conv=True,
                   use_final_conv=True):
    """
    input_shape : tuple specifying the 2D image input shape.
    num_classes : number of classes in the segmentation output.
    num_init_blocks : the number of blocks of type initblock, above mainblocks.
        These blocks always have the same number of channels as the first
        convolutional layer in the model.
    num_main_blocks : the number of blocks of type mainblock, below initblocks.
        These blocks double (halve) in number of channels at each downsampling
        (upsampling).
    main_block_depth : an integer or list of integers specifying the number of
        repetitions of each mainblock. A list must contain as many values as
        there are main_blocks in the downward (or upward -- it's mirrored) path
        plus one for the across path.
    input_num_filters : the number of channels in the first (last)
        convolutional layer in the model (and of each initblock).
    num_cycles : number of times to cycle the down/up processing pair.
    preprocessor_network : a neural network for preprocessing the input data.
    postprocessor_network : a neural network for postprocessing the data fed
        to the classifier.
    mainblock : a layer defining the mainblock (basic_block by default).
    initblock : a layer defining the initblock (basic_block_mp by default).
    nonlinearity : string or function specifying/defining the nonlinearity.
    dropout : the dropout probability, introduced in every block.
    normalization : the normalization to apply to layers (by default: batch
        normalization). If None, no normalization is applied.
    weight_norm : boolean, whether to use weight norm on conv layers.
    weight_decay : the weight decay (L2 penalty) used in every convolution.
    norm_kwargs : keyword arguments to pass to batch norm layers.
    init : string or function specifying the initializer for layers.
    ndim : the spatial dimensionality of the input and output (2 or 3).
    cycles_share_weights : share network weights across cycles.
    num_residuals : the number of parallel residual functions per block.
    num_first_conv : the number of parallel first convolutions.
    num_final_conv : the number of parallel final convolutions (+BN).
    num_classifier : the number of parallel linear classifiers.
    num_outputs : the number of model outputs, each with num_classifier
        classifiers.
    """
    '''
    By default, use basic_block for mainblock and basic_block_mp for
    initblock.
    '''
    if mainblock is None:
        mainblock = basic_block
    if initblock is None:
        initblock = basic_block_mp
    '''
    main_block_depth can be a single value or a list with one value per
    block -- if a list, ensure its length is correct and that no depth is
    zero.
    '''
    if not hasattr(main_block_depth, '__len__'):
        if main_block_depth == 0:
            raise ValueError("main_block_depth must never be zero")
    else:
        if len(main_block_depth) != num_main_blocks + 1:
            raise ValueError("main_block_depth must have "
                             "`num_main_blocks+1` values when "
                             "passed as a list")
        for d in main_block_depth:
            if d == 0:
                raise ValueError("main_block_depth must never be zero")
    '''
    Returns the number of repetitions of a mainblock at a given pooling
    level.
    '''
    def get_repetitions(level):
        if hasattr(main_block_depth, '__len__'):
            return main_block_depth[level]
        return main_block_depth

    '''
    Merge tensors, changing the number of feature maps in the first input
    to match that of the second input. Feature maps in the first input are
    reweighted.
    
    If weight sharing is enabled, reuse old convolutions.
    '''

    def merge_into(x, into, skips, cycle, direction, depth):
        if x._keras_shape[1] != into._keras_shape[1]:
            if cycles_share_weights and depth in skips[cycle - 1][direction]:
                conv_layer = skips[cycle - 1][direction][depth]
            else:
                name = _unique('long_skip_' + str(direction) + '_' +
                               str(depth))
                conv_layer = Convolution(filters=into._keras_shape[1],
                                         kernel_size=1,
                                         ndim=ndim,
                                         weight_norm=weight_norm,
                                         kernel_initializer=init,
                                         padding='valid',
                                         kernel_regularizer=_l2(weight_decay),
                                         name=name)
            skips[cycle][direction][depth] = conv_layer
            x = conv_layer(x)
        out = merge_add([x, into])
        if normalization is None:
            # Divide sum by two.
            out = Lambda(lambda x: x / 2., output_shape=lambda x: x)(out)
        return out

    '''
    Given some block function and an input tensor, return a reusable model
    instantiating that block function. This is to allow weight sharing.
    '''

    def make_block(block_func, x):
        x_filters = x._keras_shape[1]
        input = Input(shape=(x_filters, ) + tuple([None] * ndim))
        model = Model(input, block_func(input))
        return model

    '''
    Constant kwargs passed to the init and main blocks. The norm_kwargs
    default must be resolved before it is captured here.
    '''
    if norm_kwargs is None:
        norm_kwargs = {}
    block_kwargs = {
        'skip': True,
        'dropout': dropout,
        'weight_norm': weight_norm,
        'weight_decay': weight_decay,
        'num_residuals': num_residuals,
        'norm_kwargs': norm_kwargs,
        'nonlinearity': nonlinearity,
        'init': init,
        'ndim': ndim
    }

    # INPUT
    input = Input(shape=input_shape)

    # Preprocessing
    if preprocessor_network is not None:
        input = preprocessor_network(input)
    '''
    Build the blocks for all cycles, contracting and expanding in each cycle.
    '''
    tensors = []  # feature tensors
    blocks = []  # residual block layers
    skips = []  # 1x1 kernel convolution layers on long skip connections
    x = input
    for cycle in range(num_cycles):
        # Create tensors and layer lists for this cycle.
        tensors.append({'down': {}, 'up': {}, 'across': {}})
        blocks.append({'down': {}, 'up': {}, 'across': {}})
        skips.append({'down': {}, 'up': {}, 'across': {}})

        # First convolution
        if cycle > 0:
            x = merge_into(x,
                           tensors[cycle - 1]['up'][0],
                           skips=skips,
                           cycle=cycle,
                           direction='down',
                           depth=0)
        if cycles_share_weights and cycle > 1:
            block = blocks[cycle - 1]['down'][0]
        else:

            def first_block(x):
                outputs = []
                for i in range(num_first_conv):
                    out = Convolution(filters=input_num_filters,
                                      kernel_size=3,
                                      ndim=ndim,
                                      weight_norm=weight_norm,
                                      kernel_initializer=init,
                                      padding='same',
                                      kernel_regularizer=_l2(weight_decay),
                                      name=_unique('first_conv_' + str(i)))(x)
                    outputs.append(out)
                if len(outputs) > 1:
                    out = merge_add(outputs)
                    if normalization is None:
                        # Divide sum by two.
                        out = Lambda(lambda x: x / 2.,
                                     output_shape=lambda x: x)(out)
                else:
                    out = outputs[0]
                return out

            block = make_block(first_block, x)
        if use_first_conv:
            x = block(x)
            blocks[cycle]['down'][0] = block
        else:
            blocks[cycle]['down'][0] = lambda x: x
        tensors[cycle]['down'][0] = x
        print("Cycle {} - FIRST DOWN: {}".format(cycle, x._keras_shape))

        # DOWN (initial subsampling blocks)
        for b in range(num_init_blocks):
            depth = b + 1
            if cycle > 0:
                x = merge_into(x,
                               tensors[cycle - 1]['up'][depth],
                               skips=skips,
                               cycle=cycle,
                               direction='down',
                               depth=depth)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['down'][depth]
            else:
                block_func = residual_block(initblock,
                                            filters=input_num_filters,
                                            repetitions=1,
                                            subsample=True,
                                            upsample=False,
                                            normalization=normalization,
                                            name='d' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['down'][depth] = block
            tensors[cycle]['down'][depth] = x
            print("Cycle {} - INIT DOWN {}: {}".format(cycle, b,
                                                       x._keras_shape))

        # DOWN (resnet blocks)
        for b in range(num_main_blocks):
            depth = b + 1 + num_init_blocks
            if cycle > 0:
                x = merge_into(x,
                               tensors[cycle - 1]['up'][depth],
                               skips=skips,
                               cycle=cycle,
                               direction='down',
                               depth=depth)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['down'][depth]
            else:
                block_func = residual_block(mainblock,
                                            filters=input_num_filters * (2**b),
                                            repetitions=get_repetitions(b),
                                            subsample=True,
                                            upsample=False,
                                            normalization=normalization,
                                            name='d' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['down'][depth] = block
            tensors[cycle]['down'][depth] = x
            print("Cycle {} - MAIN DOWN {}: {}".format(cycle, b,
                                                       x._keras_shape))

        # ACROSS
        if num_main_blocks:
            if cycle > 0:
                x = merge_into(x,
                               tensors[cycle - 1]['across'][0],
                               skips=skips,
                               cycle=cycle,
                               direction='across',
                               depth=0)
            if cycles_share_weights and cycle > 1:
                block = blocks[cycle - 1]['across'][0]
            else:
                block_func = residual_block(
                    mainblock,
                    # Same width as the deepest down block.
                    filters=input_num_filters*(2**(num_main_blocks - 1)),
                    repetitions=get_repetitions(num_main_blocks),
                    subsample=True,
                    upsample=True,
                    normalization=normalization,
                    name='a',
                    **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['across'][0] = block
            tensors[cycle]['across'][0] = x
            print("Cycle {} - ACROSS: {}".format(cycle, x._keras_shape))

        # UP (resnet blocks)
        for b in range(num_main_blocks - 1, -1, -1):
            depth = b + 1 + num_init_blocks
            x = merge_into(x,
                           tensors[cycle]['down'][depth],
                           skips=skips,
                           cycle=cycle,
                           direction='up',
                           depth=depth)
            if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
                block = blocks[cycle - 1]['up'][depth]
            else:

                block_func = residual_block(mainblock,
                                            filters=input_num_filters * (2**b),
                                            repetitions=get_repetitions(b),
                                            subsample=False,
                                            upsample=True,
                                            normalization=normalization,
                                            name='u' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['up'][depth] = block
            tensors[cycle]['up'][depth] = x
            print("Cycle {} - MAIN UP {}: {}".format(cycle, b, x._keras_shape))

        # UP (final upsampling blocks)
        for b in range(num_init_blocks - 1, -1, -1):
            depth = b + 1
            x = merge_into(x,
                           tensors[cycle]['down'][depth],
                           skips=skips,
                           cycle=cycle,
                           direction='up',
                           depth=depth)
            if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
                block = blocks[cycle - 1]['up'][depth]
            else:
                block_func = residual_block(initblock,
                                            filters=input_num_filters,
                                            repetitions=1,
                                            subsample=False,
                                            upsample=True,
                                            normalization=normalization,
                                            name='u' + str(depth),
                                            **block_kwargs)
                block = make_block(block_func, x)
            x = block(x)
            blocks[cycle]['up'][depth] = block
            tensors[cycle]['up'][depth] = x
            print("Cycle {} - INIT UP {}: {}".format(cycle, b, x._keras_shape))

        # Final convolution.
        x = merge_into(x,
                       tensors[cycle]['down'][0],
                       skips=skips,
                       cycle=cycle,
                       direction='up',
                       depth=0)
        if cycles_share_weights and cycle > 0 and cycle < num_cycles - 1:
            block = blocks[cycle - 1]['up'][0]
        else:

            def final_block(x):
                outputs = []
                for i in range(num_final_conv):
                    out = Convolution(filters=input_num_filters,
                                      kernel_size=3,
                                      ndim=ndim,
                                      weight_norm=weight_norm,
                                      kernel_initializer=init,
                                      padding='same',
                                      kernel_regularizer=_l2(weight_decay),
                                      name=_unique('final_conv_' + str(i)))(x)
                    if normalization is not None:
                        out = normalization(name=_unique('final_norm_' +
                                                         str(i)),
                                            **norm_kwargs)(out)
                    out = get_nonlinearity(nonlinearity)(out)
                    outputs.append(out)
                if len(outputs) > 1:
                    out = merge_add(outputs)
                else:
                    out = outputs[0]
                return out

            block = make_block(final_block, x)
        if use_final_conv:
            x = block(x)
            blocks[cycle]['up'][0] = block
        else:
            blocks[cycle]['up'][0] = lambda x: x
        tensors[cycle]['up'][0] = x
        if cycle > 0:
            # Merge preclassifier outputs across all cycles.
            x = merge_into(x,
                           tensors[cycle - 1]['up'][0],
                           skips=skips,
                           cycle=cycle,
                           direction='up',
                           depth=-1)
        print("Cycle {} - FIRST UP: {}".format(cycle, x._keras_shape))

    # Postprocessing
    if postprocessor_network is not None:
        x = postprocessor_network(x)

    # OUTPUTs (SIGMOID)
    all_outputs = []
    if num_classes is not None:
        for i in range(num_outputs):
            # Linear classifier
            classifiers = []
            for j in range(num_classifier):
                name = 'classifier_conv_' + str(j)
                if i > 0:
                    # backwards compatibility
                    name += '_out' + str(i)
                output = Convolution(filters=num_classes,
                                     kernel_size=1,
                                     ndim=ndim,
                                     activation='linear',
                                     kernel_regularizer=_l2(weight_decay),
                                     name=_unique(name))(x)
                classifiers.append(output)
            if len(classifiers) > 1:
                output = merge_add(classifiers)
            else:
                output = classifiers[0]
            if ndim == 2:
                output = Permute((2, 3, 1))(output)
            else:
                output = Permute((2, 3, 4, 1))(output)
            if num_classes == 1:
                output = Activation('sigmoid', name='sigmoid' + str(i))(output)
            else:
                output = Activation(_softmax, name='softmax' + str(i))(output)
            if ndim == 2:
                output_layer = Permute((3, 1, 2))
            else:
                output_layer = Permute((4, 1, 2, 3))
            output_layer.name = 'output_' + str(i)
            output = output_layer(output)
            all_outputs.append(output)
    else:
        # No classifier
        all_outputs = Activation('linear', name='output_0')(x)

    # MODEL
    model = Model(inputs=input, outputs=all_outputs)

    return model
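
A minimal usage sketch of assemble_model, assuming channels-first 2D inputs
and that the default blocks (basic_block, basic_block_mp) are importable from
the host codebase; all values are illustrative only:

# Hypothetical usage sketch; parameter values are illustrative.
model = assemble_model(
    input_shape=(1, 256, 256),      # (channels, height, width)
    num_classes=1,                  # binary segmentation -> sigmoid output
    num_init_blocks=2,              # blocks at constant width
    num_main_blocks=3,              # channels double at each subsampling
    main_block_depth=[2, 2, 2, 2],  # num_main_blocks + 1 values
    input_num_filters=32,
    num_cycles=2,                   # two down/up passes with shared weights
    ndim=2)
model.compile(optimizer='adam', loss='binary_crossentropy')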
Example #3
def assemble_model_two_levels(adversarial=False,
                              num_residuals_bottom=None,
                              discriminator_kwargs=None,
                              **model_kwargs):
    """
    Stack two cycled models: the first produces liver features, the second
    segments lesions from those features concatenated with the input.
    Optionally attach adversarial discriminators to both outputs.
    """
    assert model_kwargs['num_outputs'] == 2

    if discriminator_kwargs is None:
        discriminator_kwargs = {}

    input_shape = model_kwargs['input_shape']
    model_input = Input(shape=input_shape, name='model_input')

    # Assemble first model (liver)
    model_liver_kwargs = copy.copy(model_kwargs)
    model_liver_kwargs['num_classes'] = None
    if num_residuals_bottom is not None:
        model_liver_kwargs['num_residuals'] = num_residuals_bottom
    model_liver = assemble_cycled_model(**model_liver_kwargs)
    liver_output_pre = model_liver(model_input)

    # Assemble second model on top (lesion)
    model_lesion_kwargs = copy.copy(model_kwargs)
    model_lesion_kwargs['num_outputs'] = 1
    # The lesion model's input is the liver features concatenated with the
    # input image along the channel axis.
    model_lesion_kwargs['input_shape'] = \
        (liver_output_pre._keras_shape[1] + input_shape[0],) + input_shape[1:]
    model_lesion = assemble_cycled_model(**model_lesion_kwargs)

    # Connect first model to second
    lesion_input = merge_concatenate([model_input, liver_output_pre], axis=1)

    # Create classifier for liver
    if model_kwargs['num_classes'] is not None:
        liver_output = Convolution(filters=1,
                                   kernel_size=1,
                                   ndim=model_kwargs['ndim'],
                                   activation='linear',
                                   kernel_regularizer=_l2(
                                       model_kwargs['weight_decay']),
                                   name='classifier_conv_1')
        liver_output = liver_output(liver_output_pre)
        if model_kwargs['ndim'] == 2:
            liver_output = Permute((2, 3, 1))(liver_output)
        else:
            liver_output = Permute((2, 3, 4, 1))(liver_output)
        liver_output = Activation('sigmoid', name='sigmoid_1')(liver_output)
        if model_kwargs['ndim'] == 2:
            liver_output_layer = Permute((3, 1, 2))
        else:
            liver_output_layer = Permute((4, 1, 2, 3))
        liver_output_layer.name = 'output_1'
        liver_output = liver_output_layer(liver_output)
    else:
        liver_output = Activation('linear', name='output_1')(liver_output_pre)

    # Lesion classifier output
    model_lesion.name = 'output_0'
    lesion_output = model_lesion(lesion_input)

    # Create discriminators
    if adversarial:

        def make_trainable(model, trainable=True):
            for l in model.layers:
                if isinstance(l, Model):
                    make_trainable(l, trainable)
                else:
                    l.trainable = trainable

        # Assemble discriminators.
        disc_0 = assemble_cnn(**discriminator_kwargs)
        disc_1 = assemble_cnn(**discriminator_kwargs)

        # Create discriminator outputs for real data.
        input_disc_0_seg = Input(input_shape, name='input_disc_0_seg')
        input_disc_1_seg = Input(input_shape, name='input_disc_1_seg')
        input_disc_0 = merge_concatenate([input_disc_0_seg, model_input],
                                         axis=1)
        input_disc_1 = merge_concatenate([input_disc_1_seg, model_input],
                                         axis=1)
        out_disc_0 = disc_0(input_disc_0)
        out_disc_1 = disc_1(input_disc_1)

        # Create untrainable segmentation generator output.
        model_gen = Model(inputs=model_input,
                          outputs=[lesion_output, liver_output])
        make_trainable(model_gen, False)
        outputs_gen = model_gen(model_input)

        # Create discriminator outputs for training the discriminators.
        input_disc_0 = merge_concatenate([outputs_gen[0], model_input], axis=1)
        input_disc_1 = merge_concatenate([outputs_gen[1], model_input], axis=1)
        out_adv_0_d = disc_0(input_disc_0)
        out_adv_1_d = disc_1(input_disc_1)

        # Make discriminators untrainable, generator trainable.
        make_trainable(model_gen, True)
        make_trainable(disc_0, False)
        make_trainable(disc_1, False)

        # Create discriminator outputs for training the generator.
        outputs_gen = model_gen(model_input)
        input_disc_0 = merge_concatenate([outputs_gen[0], model_input], axis=1)
        input_disc_1 = merge_concatenate([outputs_gen[1], model_input], axis=1)
        out_adv_0_g = disc_0(input_disc_0)
        out_adv_1_g = disc_1(input_disc_1)

        # Name the outputs.
        def name_layer(tensor, name):
            return Activation('linear', name=name)(tensor)

        out_adv_0_d = name_layer(out_adv_0_d, 'out_adv_0_d')
        out_adv_1_d = name_layer(out_adv_1_d, 'out_adv_1_d')
        out_adv_0_g = name_layer(out_adv_0_g, 'out_adv_0_g')
        out_adv_1_g = name_layer(out_adv_1_g, 'out_adv_1_g')
        out_disc_0 = name_layer(out_disc_0, 'out_disc_0')
        out_disc_1 = name_layer(out_disc_1, 'out_disc_1')

    # Create aggregate model
    if adversarial:
        model = Model(
            inputs=[model_input,
                    input_disc_0_seg,
                    input_disc_1_seg],
            outputs=[lesion_output, liver_output,
                     out_adv_0_d, out_adv_1_d,
                     out_adv_0_g, out_adv_1_g,
                     out_disc_0, out_disc_1])
    else:
        model = Model(inputs=model_input,
                      outputs=[lesion_output, liver_output])

    return model
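
A minimal usage sketch of the non-adversarial path. This assumes that
assemble_cycled_model accepts the same keyword arguments as assemble_model
above (a plausible but unverified assumption); all values are illustrative:

# Hypothetical usage sketch; values are illustrative only.
model = assemble_model_two_levels(
    adversarial=False,
    input_shape=(1, 256, 256),      # (channels, height, width)
    num_classes=1,
    num_outputs=2,                  # required by the assert above
    num_init_blocks=2,
    num_main_blocks=3,
    main_block_depth=2,
    input_num_filters=32,
    ndim=2,
    weight_decay=0.0001)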