def decoder(task, dwn_tr4, dwn_tr3, dwn_tr2, dwn_tr1, in_tr, nf, bn, dr, net_type,
            io, ao, ds, chn):
    """Build one task-specific V-Net decoder with its output heads.

    The up-sampling path (with short/skip connections to the encoder) is built
    by ``up_series``. This function then attaches the prediction layers on top.

    Args:
        task: Task name. It is used as the prefix of every layer name.
            ``"recon"`` produces a main output with no activation. Any other
            value produces a softmax segmentation output.
        dwn_tr4, dwn_tr3, dwn_tr2, dwn_tr1, in_tr: Encoder feature maps,
            forwarded unchanged to ``up_series``.
        nf, bn, dr, net_type: Forwarded unchanged to ``up_series``.
        io: Container of I/O options. If it contains ``"2_out"``, a second
            softmax head with ``chn`` channels is added.
        ao: If truthy, add a 2-channel softmax auxiliary head.
        ds: If truthy, add two softmax deep-supervision heads. They are fed
            from the two coarser decoder levels, up-sampled by one and two
            ``UpSampling3D((2, 2, 2))`` layers respectively.
        chn: Number of channels of the main, second and deep-supervision heads.

    Returns:
        A list of output tensors: ``[main, (second), (aux)]``. When ``ds`` is
        truthy, that list is nested one level deeper:
        ``[[main, (second), (aux)], d1, d2]``.
    """
    # The deepest decoder level is not used by any head.
    _, up_tr3, up_tr2, up_tr1 = up_series(task, dwn_tr4, dwn_tr3, dwn_tr2, dwn_tr1,
                                          in_tr, nf, bn, dr, net_type)

    # Main head. Compare by value: `is` tests object identity and only matched
    # when both strings happened to be the same interned object.
    if task == "recon":
        out = Conv3D(chn, 1, padding='same', name=task + '_out_recon')(up_tr1)
    else:
        res = Conv3D(chn, 1, padding='same', name=task + '_Conv3D_last')(up_tr1)
        out = Activation('softmax', name=task + '_out_segmentation')(res)

    out = [out]  # list so the optional heads can be appended
    if "2_out" in io:
        res_2 = Conv3D(chn, 1, padding='same', name=task + '_Conv3D_last2')(up_tr1)
        out_2 = Activation('softmax', name=task + '_out_segmentation2')(res_2)
        out.append(out_2)
    if ao:
        # Auxiliary 2-channel output.
        aux_res = Conv3D(2, 1, padding='same', name=task + '_aux_Conv3D_last')(up_tr1)
        aux_out = Activation('softmax', name=task + '_aux')(aux_res)
        out.append(aux_out)
    if ds:
        # NOTE(review): the outputs built so far are nested in their own
        # sub-list here; confirm the loss configuration expects this shape.
        out = [out]
        # Deep supervision #1: one 2x up-sampling from up_tr2.
        deep_1 = UpSampling3D((2, 2, 2), name=task + '_d1_UpSampling3D_0')(up_tr2)
        res = Conv3D(chn, 1, padding='same', name=task + '_d1_Conv3D_last')(deep_1)
        d_out_1 = Activation('softmax', name=task + '_d1')(res)
        out.append(d_out_1)

        # Deep supervision #2: two 2x up-samplings from up_tr3.
        deep_2 = UpSampling3D((2, 2, 2), name=task + '_d2_UpSampling3D_0')(up_tr3)
        deep_2 = UpSampling3D((2, 2, 2), name=task + '_d2_UpSampling3D_1')(deep_2)
        res = Conv3D(chn, 1, padding='same', name=task + '_d2_Conv3D_last')(deep_2)
        d_out_2 = Activation('softmax', name=task + '_d2')(res)
        out.append(d_out_2)

    return out


# Example No. 2 (scraped listing header "Exemplo n.º 2", commented out so the module parses)
# 0

def _build_seg_up_path(prefix, dwn_tr4, dwn_tr3, dwn_tr2, dwn_tr1, in_tr, nf, bn, dr,
                       net_type):
    """Build the four up-sampling blocks of one segmentation decoder.

    The blocks are named ``<prefix>_block5`` .. ``<prefix>_block8``. They
    receive ``dwn_tr3``, ``dwn_tr2``, ``dwn_tr1`` and ``in_tr`` respectively
    as ``input2`` (the short/skip connection).

    Returns:
        ``(up_tr4, up_tr3, up_tr2, up_tr1)``, in construction order.
    """
    up_tr4 = up_trans(dwn_tr4, nf * 8, 2, bn, dr, ty=net_type, input2=dwn_tr3,
                      name=prefix + '_block5')
    up_tr3 = up_trans(up_tr4, nf * 4, 2, bn, dr, ty=net_type, input2=dwn_tr2,
                      name=prefix + '_block6')
    up_tr2 = up_trans(up_tr3, nf * 2, 2, bn, dr, ty=net_type, input2=dwn_tr1,
                      name=prefix + '_block7')
    up_tr1 = up_trans(up_tr2, nf * 1, 1, bn, dr, ty=net_type, input2=in_tr,
                      name=prefix + '_block8')
    return up_tr4, up_tr3, up_tr2, up_tr1


def _build_seg_heads(prefix, up_tr3, up_tr2, up_tr1, main_chn, chn, io, ao, ds):
    """Attach the softmax output heads of one segmentation decoder.

    Args:
        prefix: Task prefix used in every layer name (e.g. ``'lobe'``).
        up_tr3, up_tr2, up_tr1: Decoder feature maps. ``up_tr1`` feeds the
            main, second and auxiliary heads. ``up_tr2`` / ``up_tr3`` feed the
            deep-supervision heads through one / two ``UpSampling3D((2, 2, 2))``
            layers.
        main_chn: Channel count of the main head.
        chn: Channel count of the second and deep-supervision heads.
        io: I/O option container. ``"2_out"`` in it adds a second head.
        ao: If truthy, add a 2-channel auxiliary head.
        ds: If truthy, add the two deep-supervision heads.

    Returns:
        ``[main, (second), (aux)]``, or ``[[main, (second), (aux)], d1, d2]``
        when ``ds`` is truthy.
    """
    res = Conv3D(main_chn, 1, padding='same', name=prefix + '_Conv3D_last')(up_tr1)
    out = [Activation('softmax', name=prefix + '_out_segmentation')(res)]

    if "2_out" in io:
        res2 = Conv3D(chn, 1, padding='same', name=prefix + '_Conv3D_last2')(up_tr1)
        out.append(Activation('softmax', name=prefix + '_out_segmentation2')(res2))
    if ao:
        aux_res = Conv3D(2, 1, padding='same', name=prefix + '_aux_Conv3D_last')(up_tr1)
        out.append(Activation('softmax', name=prefix + '_aux')(aux_res))
    if ds:
        # NOTE(review): the outputs built so far are nested in their own
        # sub-list; confirm get_loss_weights_optim expects this structure.
        out = [out]
        # Deep supervision #1.
        deep_1 = UpSampling3D((2, 2, 2), name=prefix + '_d1_UpSampling3D_0')(up_tr2)
        res = Conv3D(chn, 1, padding='same', name=prefix + '_d1_Conv3D_last')(deep_1)
        out.append(Activation('softmax', name=prefix + '_d1')(res))
        # Deep supervision #2.
        deep_2 = UpSampling3D((2, 2, 2), name=prefix + '_d2_UpSampling3D_0')(up_tr3)
        deep_2 = UpSampling3D((2, 2, 2), name=prefix + '_d2_UpSampling3D_1')(deep_2)
        res = Conv3D(chn, 1, padding='same', name=prefix + '_d2_Conv3D_last')(deep_2)
        out.append(Activation('softmax', name=prefix + '_d2')(res))
    return out


def _build_seg_metrics(prefix, main_metrics, extra_metrics, aux_metrics, io, ao, ds):
    """Return the ``{output_layer_name: metrics}`` dict for one segmentation model.

    The keys mirror the head names created by ``_build_seg_heads``.
    """
    metrics = {prefix + '_out_segmentation': main_metrics}
    if "2_out" in io:
        metrics[prefix + '_out_segmentation2'] = extra_metrics
    if ao:
        metrics[prefix + '_aux'] = aux_metrics
    # NOTE(review): deep-supervision heads are built whenever `ds` is truthy,
    # but their metrics are only registered when ds == 2 — confirm intended.
    if ds == 2:
        metrics[prefix + '_d1'] = extra_metrics
        metrics[prefix + '_d2'] = extra_metrics
    return metrics


def load_cp_models(model_names, args):
    """Build the shared-encoder multi-task networks and return the requested ones.

    A single encoder (``intro`` + four ``down_trans`` blocks) is shared by five
    decoders: lobe, lung, vessel, airway (segmentation) and reconstruction.

    Args:
        model_names: Iterable of model keys: ``"net_only_lobe"``,
            ``"net_only_vessel"``, ``"net_only_lung"``, ``"net_only_airway"``,
            ``"net_itgt_lb_rc"``, ``"net_itgt_vs_rc"``, ``"net_itgt_lu_rc"``,
            ``"net_itgt_aw_rc"``, ``"net_no_label"``.
        args: Configuration namespace. The attributes read are
            ``feature_number``, ``batch_norm``, ``dropout``, ``u_v`` and
            ``attention``, plus the per-task ``*_io``, ``ao_*``, ``ds_*`` and
            ``lr_*`` options for ``lb``, ``lu``, ``vs``, ``aw`` and ``rc``.

    Returns:
        A list of models in the order of ``model_names``. Unknown names map to
        ``None``. ``net_only_*`` and ``net_no_label`` are compiled. The
        ``net_itgt_*`` models (segmentation outputs + reconstruction output)
        are returned uncompiled.
    """
    nch = 1  # channels of each input volume
    nf = args.feature_number
    bn = args.batch_norm
    dr = args.dropout
    net_type = args.u_v
    attention = args.attention

    # ---- inputs and shared encoder ----
    input_data = Input((None, None, None, nch), name='input')
    in_tr = intro(input_data, nf, bn)

    if "2_in" in args.lb_io:
        # Optional second input with a different scale. Its intro features are
        # concatenated with the first input's. Only lb_io controls this, for all tasks.
        input_data2 = Input((None, None, None, nch), name='input_2')
        in_tr2 = intro(input_data2, nf, bn, name='2')
        in_tr = concatenate([in_tr, in_tr2])
        input_data = [input_data, input_data2]

    dwn_tr1 = down_trans(in_tr, nf * 2, 2, bn, dr, ty=net_type, name='block1')
    dwn_tr2 = down_trans(dwn_tr1, nf * 4, 2, bn, dr, ty=net_type, name='block2')
    dwn_tr3 = down_trans(dwn_tr2, nf * 8, 2, bn, dr, ty=net_type, name='block3')
    dwn_tr4 = down_trans(dwn_tr3, nf * 16, 2, bn, dr, ty=net_type, name='block4')

    lobe_out_chn = 6
    lung_out_chn = 2
    vessel_out_chn = 2
    airway_out_chn = 2
    rec_out_chn = 1

    def main_chn(task_chn):
        # With attention, every decoder's main head uses the lobe channel count.
        return lobe_out_chn if attention else task_chn

    def seg_decoder(prefix, head_chn, chn, io, ao, ds):
        # Up path with skip connections, then the output heads.
        _, up_tr3, up_tr2, up_tr1 = _build_seg_up_path(prefix, dwn_tr4, dwn_tr3, dwn_tr2,
                                                       dwn_tr1, in_tr, nf, bn, dr, net_type)
        return _build_seg_heads(prefix, up_tr3, up_tr2, up_tr1, head_chn, chn, io, ao, ds)

    # ---- segmentation decoders (construction order: lobe, lung, vessel, airway) ----
    out_lobe = seg_decoder('lobe', lobe_out_chn, lobe_out_chn,
                           args.lb_io, args.ao_lb, args.ds_lb)
    out_lung = seg_decoder('lung', main_chn(lung_out_chn), lung_out_chn,
                           args.lu_io, args.ao_lu, args.ds_lu)
    out_vessel = seg_decoder('vessel', main_chn(vessel_out_chn), vessel_out_chn,
                             args.vs_io, args.ao_vs, args.ds_vs)
    out_airway = seg_decoder('airway', main_chn(airway_out_chn), airway_out_chn,
                             args.aw_io, args.ao_aw, args.ds_aw)

    # ---- reconstruction decoder ----
    # Unlike the segmentation decoders, no `input2` skip input is passed.
    up_tr4_rec = up_trans(dwn_tr4, nf * 8, 2, bn, dr, ty=net_type, name='rec_block5')
    up_tr3_rec = up_trans(up_tr4_rec, nf * 4, 2, bn, dr, ty=net_type, name='rec_block6')
    up_tr2_rec = up_trans(up_tr3_rec, nf * 2, 2, bn, dr, ty=net_type, name='rec_block7')
    up_tr1_rec = up_trans(up_tr2_rec, nf * 1, 1, bn, dr, ty=net_type, name='rec_block8')
    # No activation on the reconstruction output.
    out_recon = Conv3D(main_chn(rec_out_chn), 1, padding='same',
                       name='out_recon')(up_tr1_rec)
    if "2_out" in args.rc_io:
        out_recon2 = Conv3D(rec_out_chn, 1, padding='same', name='out_recon2')(up_tr1_rec)
        out_recon = [out_recon, out_recon2]

    # ---- metrics ----
    metrics_seg_6_classes = [dice_coef_mean, dice_0, dice_1, dice_2, dice_3, dice_4, dice_5]
    metrics_seg_2_classes = [dice_coef_mean, dice_0, dice_1]
    # The main-head metrics follow the main head's channel count. This applies
    # to airway as well: its attention branch previously used 2-class metrics
    # on a 6-channel head.
    main_metrics = metrics_seg_6_classes if attention else metrics_seg_2_classes

    def compile_pair(task, outputs, metrics, ao, ds, lr, io):
        # Compile the task-only model; build (but do not compile) the variant
        # that also outputs the reconstruction.
        loss, loss_weights, _, _, optim = get_loss_weights_optim(ao, ds, lr, io)
        net_only = Model(input_data, outputs, name='net_only_' + task)
        net_only.compile(optimizer=optim,
                         loss=loss,
                         loss_weights=loss_weights,
                         metrics=metrics)
        net_itgt = Model(input_data, outputs + [out_recon],
                         name='net_itgt_' + task + '_recon')
        return net_only, net_itgt

    # Compilation order: lobe, vessel, airway, lung, then the recon-only model.
    net_only_lobe, net_itgt_lobe_recon = compile_pair(
        'lobe', out_lobe,
        _build_seg_metrics('lobe', metrics_seg_6_classes, metrics_seg_6_classes,
                           metrics_seg_2_classes, args.lb_io, args.ao_lb, args.ds_lb),
        args.ao_lb, args.ds_lb, args.lr_lb, args.lb_io)
    net_only_vessel, net_itgt_vessel_recon = compile_pair(
        'vessel', out_vessel,
        _build_seg_metrics('vessel', main_metrics, metrics_seg_2_classes,
                           metrics_seg_2_classes, args.vs_io, args.ao_vs, args.ds_vs),
        args.ao_vs, args.ds_vs, args.lr_vs, args.vs_io)
    net_only_airway, net_itgt_airway_recon = compile_pair(
        'airway', out_airway,
        _build_seg_metrics('airway', main_metrics, metrics_seg_2_classes,
                           metrics_seg_2_classes, args.aw_io, args.ao_aw, args.ds_aw),
        args.ao_aw, args.ds_aw, args.lr_aw, args.aw_io)
    net_only_lung, net_itgt_lung_recon = compile_pair(
        'lung', out_lung,
        _build_seg_metrics('lung', main_metrics, metrics_seg_2_classes,
                           metrics_seg_2_classes, args.lu_io, args.ao_lu, args.ds_lu),
        args.ao_lu, args.ds_lu, args.lr_lu, args.lu_io)

    # Reconstruction-only network. The loss returned by get_loss_weights_optim
    # is ignored; plain MSE is used instead.
    _, rc_loss_weights, _, _, rc_optim = get_loss_weights_optim(
        args.ao_rc, args.ds_rc, args.lr_rc, args.rc_io, task='no_label')
    net_no_label = Model(input_data, out_recon, name='net_no_label')
    net_no_label.compile(optimizer=rc_optim,
                         loss='mse',
                         loss_weights=rc_loss_weights,
                         metrics=['mse'])

    models_dict = {
        "net_itgt_lu_rc": net_itgt_lung_recon,
        "net_itgt_aw_rc": net_itgt_airway_recon,
        "net_itgt_lb_rc": net_itgt_lobe_recon,
        "net_itgt_vs_rc": net_itgt_vessel_recon,
        "net_no_label": net_no_label,
        "net_only_lobe": net_only_lobe,
        "net_only_vessel": net_only_vessel,
        "net_only_lung": net_only_lung,
        "net_only_airway": net_only_airway,
    }

    return [models_dict.get(name) for name in model_names]