# Example no. 1
# 0
def xception_se_lstm_singledur(input_shape=(shape_r, shape_c, 3),
                               conv_filters=256,
                               lstm_filters=512,
                               verbose=True,
                               print_shapes=True,
                               n_outs=1,
                               ups=8,
                               freeze_enc=False,
                               return_sequences=False):
    """Build an Xception encoder -> SE-LSTM block -> decoder model.

    Args:
        input_shape: shape of one input image.
        conv_filters: filter count forwarded to ``decoder_block`` as
            ``dec_filt``.
        lstm_filters: filter count forwarded to ``se_lstm_block``.
        verbose: if True, print ``m.summary()``.
        print_shapes: if True, print intermediate tensor shapes.
        n_outs: how many times the decoder output is repeated in the
            model's output list.
        ups: unused by this builder; kept for signature parity with the
            other builders.
        freeze_enc: if True, mark every Xception layer non-trainable.
        return_sequences: forwarded to ``se_lstm_block``.

    Returns:
        An uncompiled Keras ``Model`` from ``inp`` to ``n_outs`` copies of
        the decoder output.
    """
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False, weights='imagenet',
                                input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        # Space-only indentation (the original mixed a tab into this body).
        for layer in xception.layers:
            layer.trainable = False

    ### LSTM over SE representation ###
    # nb_timestep is a module-level setting read at build time.
    x = se_lstm_block(xception.output, nb_timestep,
                      lstm_filters=lstm_filters,
                      return_sequences=return_sequences)

    ### DECODER ###
    outs_dec = decoder_block(x, dil_rate=(2, 2), print_shapes=print_shapes,
                             dec_filt=conv_filters)

    outs_final = [outs_dec] * n_outs
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m
def xception_decoder_timedist(input_shape=(shape_r, shape_c, 3),
                              verbose=True,
                              print_shapes=True,
                              n_outs=1,
                              ups=8,
                              dil_rate=(1, 1)):
    """Xception encoder whose features are tiled in time, then a
    time-distributed decoder.

    The encoder output is given a new axis 1 and copied ``nb_timestep``
    times along it, so ``decoder_block_timedist`` receives a sequence.
    ``ups`` is accepted for signature parity but not used.

    Returns:
        An uncompiled Keras ``Model`` with ``n_outs`` copies of the decoder
        output.
    """
    def _tile_over_time(feats):
        # (batch, ...) -> (batch, nb_timestep, ...)
        expanded = K.expand_dims(feats, axis=1)
        return K.repeat_elements(expanded, nb_timestep, axis=1)

    def _tiled_shape(shape):
        return (shape[0], nb_timestep) + shape[1:]

    image = Input(shape=input_shape)

    # Encoder
    encoder = Xception_wrapper(include_top=False,
                               weights='imagenet',
                               input_tensor=image,
                               pooling=None)
    if print_shapes:
        print('xception:', encoder.output.shape)

    sequence = Lambda(_tile_over_time, _tiled_shape)(encoder.output)

    # Decoder
    decoded = decoder_block_timedist(sequence,
                                     dil_rate=dil_rate,
                                     print_shapes=print_shapes,
                                     dec_filt=512)

    model = Model(image, [decoded] * n_outs)
    if verbose:
        model.summary()
    return model
# Example no. 3
# 0
def sam_xception_new(input_shape=(shape_r, shape_c, 3), conv_filters=512, lstm_filters=512, att_filters=512,
                     verbose=True, print_shapes=True, n_outs=1, ups=8, nb_gaussian=nb_gaussian):
    '''SAM with a custom Xception as encoder.

    The encoder is ``Xception`` from the local ``xception_custom`` module,
    wrapped with ``keras.applications.keras_modules_injection``. Encoder
    features pass through a 3x3 conv and are repeated in time by ``repeat``.
    They are then fed to an ``AttentiveConvLSTM2D`` that returns only its
    final output (``return_sequences=False``).

    Two ``LearningPrior`` maps (both computed from the ConvLSTM output) are
    concatenated in ahead of two dilated 5x5 convs. A 1x1 conv then yields
    a single-channel map, which is bilinearly upsampled by ``ups``.

    Returns:
        An uncompiled Keras ``Model`` with ``n_outs`` copies of the
        upsampled map.
    '''
    from xception_custom import Xception
    from keras.applications import keras_modules_injection

    @keras_modules_injection
    def Xception_wrapper(*args, **kwargs):
        return Xception(*args, **kwargs)

    # Single model input (the original also built an earlier, unused Input).
    inp = Input(shape=input_shape)
    dcn = Xception_wrapper(include_top=False, weights='imagenet', input_tensor=inp, pooling=None)
    if print_shapes:
        print('xception:', dcn.output.shape)

    conv_feat = Conv2D(conv_filters, 3, padding='same', activation='relu')(dcn.output)
    if print_shapes:
        print('Shape after first conv after xception:', conv_feat.shape)

    # Attentive ConvLSTM
    att_convlstm = Lambda(repeat, repeat_shape)(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters, attentive_filters=att_filters, kernel_size=(3, 3),
                                       attentive_kernel_size=(3, 3), padding='same',
                                       return_sequences=False)(att_convlstm)

    # Learned Prior (1)
    priors1 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat1 = Concatenate(axis=-1)([att_convlstm, priors1])
    dil_conv1 = Conv2D(conv_filters, 5, padding='same', activation='relu', dilation_rate=(4, 4))(concat1)

    # Learned Prior (2) -- also derived from the ConvLSTM output, not dil_conv1
    priors2 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat2 = Concatenate(axis=-1)([dil_conv1, priors2])
    dil_conv2 = Conv2D(conv_filters, 5, padding='same', activation='relu', dilation_rate=(4, 4))(concat2)

    # Final conv to get to a heatmap
    outs = Conv2D(1, kernel_size=1, padding='same', activation='relu')(dil_conv2)
    if print_shapes:
        print('Shape after 1x1 conv:', outs.shape)

    # Upsampling back to input shape
    outs_up = UpSampling2D(size=(ups, ups), interpolation='bilinear')(outs)
    if print_shapes:
        print('shape after upsampling', outs_up.shape)

    outs_final = [outs_up] * n_outs

    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()

    return m
def xception_cl(input_shape=(None, None, 3),
                verbose=True,
                print_shapes=True,
                n_outs=1,
                ups=8,
                freeze_enc=False,
                dil_rate=(2, 2),
                freeze_cl=True,
                append_classif=True,
                num_classes=5):
    """Xception with classification capabilities.

    The Xception encoder output feeds two heads:

    * a classifier: GAP -> Dense(512) -> Dropout(0.3) -> Dense(num_classes,
      softmax);
    * a saliency decoder (``decoder_block``).

    Args:
        input_shape: input image shape.
        verbose: if True, print ``m.summary()`` (after any freezing, so the
            trainable-parameter counts are accurate).
        print_shapes: if True, print intermediate tensor shapes.
        n_outs: number of copies of the decoder output in the output list.
        ups: unused by this builder.
        freeze_enc: if True, mark every Xception layer non-trainable.
        dil_rate: forwarded to ``decoder_block``.
        freeze_cl: if True and the classifier is part of the model, mark
            ``dense_cl`` and ``dense_cl_out`` non-trainable.
        append_classif: if True, append the classification vector as the
            last model output.
        num_classes: size of the classification softmax.

    Returns:
        An uncompiled Keras ``Model``. Its outputs are ``n_outs`` decoder
        maps, followed by the classification vector when ``append_classif``
        is True.
    """
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False,
                                weights='imagenet',
                                input_tensor=inp,
                                pooling=None)
    if print_shapes: print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False

    ### CLASSIFIER ###
    cl = GlobalAveragePooling2D(name='gap_cl')(xception.output)
    cl = Dense(512, name='dense_cl')(cl)
    cl = Dropout(0.3, name='dropout_cl')(cl)
    cl = Dense(num_classes, activation='softmax', name='dense_cl_out')(cl)

    ## DECODER ##
    outs_dec = decoder_block(xception.output,
                             dil_rate=dil_rate,
                             print_shapes=print_shapes,
                             dec_filt=512,
                             prefix='decoder')

    outs_final = [outs_dec] * n_outs

    if append_classif:
        outs_final.append(cl)

    # Building model
    m = Model(
        inp, outs_final)  # Last element of outs_final is classification vector

    # The classifier layers only belong to the model when their output is
    # appended; otherwise m.get_layer('dense_cl') would raise ValueError.
    if freeze_cl and append_classif:
        print('Freezing classification dense layers')
        m.get_layer('dense_cl').trainable = False
        m.get_layer('dense_cl_out').trainable = False

    # Summarize after freezing so trainable counts reflect the final state.
    if verbose:
        m.summary()

    return m
def xception_se_lstm_nodecoder(input_shape=(shape_r, shape_c, 3),
                               conv_filters=512,
                               lstm_filters=512,
                               verbose=True,
                               print_shapes=True,
                               n_outs=1,
                               ups=8,
                               freeze_enc=False,
                               return_sequences=True):
    """Xception encoder + time-distributed SE-LSTM with a light per-frame head.

    There is no decoder block. After ``se_lstm_block_timedist``, each time
    step goes through:

    * Dropout(0.3);
    * a 3x3 conv (``conv_filters`` filters, ReLU);
    * Dropout(0.3);
    * a 1x1 conv to one channel (ReLU);
    * bilinear upsampling by ``ups``.

    Returns:
        An uncompiled Keras ``Model`` with ``n_outs`` copies of the
        upsampled sequence.
    """
    image = Input(shape=input_shape)

    # Encoder
    encoder = Xception_wrapper(include_top=False,
                               weights='imagenet',
                               input_tensor=image,
                               pooling=None)
    if print_shapes:
        print('xception output shapes:', encoder.output.shape)
    if freeze_enc:
        for enc_layer in encoder.layers:
            enc_layer.trainable = False

    # LSTM over the SE representation
    seq = se_lstm_block_timedist(encoder.output,
                                 nb_timestep,
                                 lstm_filters=lstm_filters,
                                 return_sequences=return_sequences)

    # Per-frame head, built in the same order as applied so that the
    # auto-generated layer names are unchanged.
    head = [
        TimeDistributed(Dropout(0.3)),
        TimeDistributed(Conv2D(filters=conv_filters, kernel_size=3,
                               padding='same', activation='relu')),
        TimeDistributed(Dropout(0.3)),
        TimeDistributed(Conv2D(1, kernel_size=1, padding='same',
                               activation='relu')),
        TimeDistributed(UpSampling2D(size=(ups, ups),
                                     interpolation='bilinear')),
    ]
    for stage in head:
        seq = stage(seq)

    model = Model(image, [seq] * n_outs)
    if verbose:
        model.summary()
    return model
def xception_3stream(input_shape=(shape_r, shape_c, 3),
                     conv_filters=512,
                     verbose=True,
                     print_shapes=True,
                     n_streams=3,
                     ups=16):
    """Xception encoder shared by ``n_streams`` independent saliency heads.

    Each head has its own layers, applied in this order:

    * a 3x3 conv (``conv_filters`` filters, ReLU);
    * a 1x1 conv to one channel (ReLU);
    * bilinear upsampling by ``ups``;
    * duplication of the map along a new axis 1 of length 2.

    Returns:
        An uncompiled Keras ``Model`` whose outputs are the ``n_streams``
        head outputs.
    """
    def _pair_in_time(t):
        # (batch, H, W, C) -> (batch, 2, H, W, C)
        return K.repeat_elements(K.expand_dims(t, axis=1), 2, axis=1)

    def _pair_shape(s):
        return (s[0], 2) + s[1:]

    def _head(feats):
        h = Conv2D(filters=conv_filters,
                   kernel_size=3,
                   padding='same',
                   activation='relu')(feats)
        h = Conv2D(1, kernel_size=1, padding='same', activation='relu')(h)
        h = UpSampling2D(size=(ups, ups), interpolation='bilinear')(h)
        h = Lambda(_pair_in_time, output_shape=_pair_shape)(h)
        if print_shapes:
            print('Shape after ups:', h.shape)
        return h

    image = Input(shape=input_shape)

    # Encoder
    encoder = Xception_wrapper(include_top=False,
                               weights='imagenet',
                               input_tensor=image,
                               pooling=None)
    if print_shapes:
        print('xception:', encoder.output.shape)

    heads = []
    for _ in range(n_streams):
        heads.append(_head(encoder.output))

    model = Model(image, heads)
    if verbose:
        model.summary()
    return model
# Example no. 7
# 0
def xception_decoder(input_shape=(shape_r, shape_c, 3),
                     verbose=True,
                     print_shapes=True,
                     n_outs=1,
                     ups=8,
                     dil_rate=(2, 2)):
    """Plain Xception encoder followed by ``decoder_block`` (512 filters).

    ``dil_rate`` is forwarded to ``decoder_block``; ``ups`` is unused.

    Returns:
        An uncompiled Keras ``Model`` with ``n_outs`` copies of the decoder
        output.
    """
    image = Input(shape=input_shape)

    encoder = Xception_wrapper(include_top=False, weights='imagenet',
                               input_tensor=image, pooling=None)
    if print_shapes:
        print('xception:', encoder.output.shape)

    saliency = decoder_block(encoder.output, dil_rate=dil_rate,
                             print_shapes=print_shapes, dec_filt=512)

    model = Model(image, [saliency] * n_outs)
    if verbose:
        model.summary()
    return model
def xception_lstm_md(input_shape=(shape_r, shape_c, 3),
                     conv_filters=256,
                     lstm_filters=256,
                     att_filters=256,
                     verbose=True,
                     print_shapes=True,
                     n_outs=1,
                     ups=8):
    """Xception + attentive ConvLSTM model emitting a sequence of saliency maps.

    Pipeline:

    * A 3x3 conv is applied to the Xception features.
    * The result is copied ``nb_timestep`` times along a new axis 1 and fed
      to ``AttentiveConvLSTM2D`` with ``return_sequences=True``.
    * Each frame then goes through two dilated 5x5 convs, a 1x1 conv to one
      channel, and bilinear upsampling by ``ups``.

    Returns:
        An uncompiled Keras ``Model`` with ``n_outs`` copies of the
        upsampled sequence.
    """
    inp = Input(shape=input_shape)
    ### ENCODER ###
    xception = Xception_wrapper(include_top=False,
                                weights='imagenet',
                                input_tensor=inp,
                                pooling=None)
    if print_shapes: print('xception:', xception.output.shape)

    conv_feat = Conv2D(conv_filters, 3, padding='same',
                       activation='relu')(xception.output)
    if print_shapes:
        print('Shape after first conv after xception:', conv_feat.shape)

    # Attentive ConvLSTM over nb_timestep copies of the same feature map.
    att_convlstm = Lambda(
        lambda x: K.repeat_elements(
            K.expand_dims(x, axis=1), nb_timestep, axis=1), lambda s:
        (s[0], nb_timestep) + s[1:])(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters,
                                       attentive_filters=att_filters,
                                       kernel_size=(3, 3),
                                       attentive_kernel_size=(3, 3),
                                       padding='same',
                                       return_sequences=True)(att_convlstm)

    # Label fixed: the original printed a stray '(' before the name.
    if print_shapes: print('att_convlstm.shape:', att_convlstm.shape)

    # Two identical dilated convolutions (priors would go here); built and
    # applied in the same order as before, so layer names are unchanged.
    dil_conv = att_convlstm
    for _ in range(2):
        dil_conv = TimeDistributed(
            Conv2D(conv_filters,
                   5,
                   padding='same',
                   activation='relu',
                   dilation_rate=(4, 4)))(dil_conv)

    # Final conv to get to a heatmap
    outs = TimeDistributed(
        Conv2D(1, kernel_size=1, padding='same', activation='relu'))(dil_conv)
    if print_shapes: print('Shape after 1x1 conv:', outs.shape)

    # Upsampling back to input shape
    outs_up = TimeDistributed(
        UpSampling2D(size=(ups, ups), interpolation='bilinear'))(outs)
    if print_shapes: print('shape after upsampling', outs_up.shape)

    outs_final = [outs_up] * n_outs

    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()

    return m
def sam_xception_timedist(input_shape=(shape_r, shape_c, 3),
                          conv_filters=512,
                          lstm_filters=512,
                          att_filters=512,
                          verbose=True,
                          print_shapes=True,
                          n_outs=1,
                          ups=8,
                          nb_gaussian=nb_gaussian):
    '''SAM-style model with an Xception encoder, producing one map per time step.

    Pipeline:

    * A 3x3 conv is applied to the Xception features.
    * The result is repeated in time by ``repeat`` and passed through
      ``AttentiveConvLSTM2D`` with ``return_sequences=True``.
    * Two prior stages follow. Each concatenates a time-distributed
      ``LearningPrior`` (computed from the ConvLSTM output) with its input,
      then applies a time-distributed dilated 5x5 conv.
    * A time-distributed 1x1 conv and bilinear upsampling by ``ups``
      produce the final maps.

    Returns:
        An uncompiled Keras ``Model`` with ``n_outs`` copies of the
        upsampled sequence.
    '''
    image = Input(shape=input_shape)

    encoder = Xception_wrapper(include_top=False,
                               weights='imagenet',
                               input_tensor=image,
                               pooling=None)
    if print_shapes:
        print('xception:', encoder.output.shape)

    features = Conv2D(conv_filters, 3, padding='same',
                      activation='relu')(encoder.output)
    if print_shapes:
        print('Shape after first conv after dcn_resnet:', features.shape)

    att_seq = Lambda(repeat, repeat_shape)(features)
    att_seq = AttentiveConvLSTM2D(filters=lstm_filters,
                                  attentive_filters=att_filters,
                                  kernel_size=(3, 3),
                                  attentive_kernel_size=(3, 3),
                                  padding='same',
                                  return_sequences=True)(att_seq)

    def _prior_stage(stage_input):
        # Prior is always derived from the ConvLSTM output.
        prior = TimeDistributed(LearningPrior(nb_gaussian=nb_gaussian))(att_seq)
        merged = Concatenate(axis=-1)([stage_input, prior])
        return TimeDistributed(
            Conv2D(conv_filters, 5, padding='same', activation='relu',
                   dilation_rate=(4, 4)))(merged)

    stage_one = _prior_stage(att_seq)
    stage_two = _prior_stage(stage_one)

    heatmap = TimeDistributed(
        Conv2D(1, kernel_size=1, padding='same', activation='relu'))(stage_two)
    if print_shapes:
        print('Shape after 1x1 conv:', heatmap.shape)

    upsampled = TimeDistributed(
        UpSampling2D(size=(ups, ups), interpolation='bilinear'))(heatmap)
    if print_shapes:
        print('shape after upsampling', upsampled.shape)

    model = Model(image, [upsampled] * n_outs)
    if verbose:
        model.summary()

    return model
def xception_cl_fus(input_shape=(None, None, 3),
                    verbose=True,
                    print_shapes=True,
                    n_outs=1,
                    ups=8,
                    dil_rate=(2, 2),
                    freeze_enc=False,
                    freeze_cl=True,
                    internal_filts=256,
                    num_classes=5,
                    dp=0.3):
    """Xception with classification capabilities that fuses representations from both tasks.

    Pipeline:

    * Xception features feed ``global_net`` (global features -> softmax
      classifier ``out_classif``) and ``aspp`` (mid-level features).
    * The global features are projected by ``dense_fusion`` and tiled over
      the ASPP spatial grid.
    * The tiled features are concatenated with the ASPP output, projected
      by a 1x1 conv + BN + ReLU + Dropout, and decoded by ``decoder_block``.

    Args:
        input_shape: input image shape. NOTE(review): the fusion Lambda uses
            ``K.int_shape(aspp_out)[1:3]`` as repeat counts, so presumably
            concrete spatial dims are needed despite the ``None`` default.
        ups: unused by this builder.
        freeze_cl: if True, mark ``out_classif`` non-trainable.

    Returns:
        An uncompiled Keras ``Model`` whose outputs are ``n_outs`` saliency
        maps followed by the classification vector.
    """
    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False,
                                weights='imagenet',
                                input_tensor=inp,
                                pooling=None)
    if print_shapes: print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False

    ### GLOBAL FEATURES ###
    g_n = global_net(xception.output, nfilts=internal_filts, dp=dp)
    if print_shapes: print('g_n shapes:', g_n.shape)

    ### CLASSIFIER ###
    # We potentially need another layer here
    out_classif = Dense(num_classes, activation='softmax',
                        name='out_classif')(g_n)

    ### ASPP (MID LEVEL FEATURES) ###
    # Fixed typo: was `app(...)`; the ASPP helper is `aspp` (as in `umsi`).
    aspp_out = aspp(xception.output, internal_filts)
    if print_shapes: print('aspp out shapes:', aspp_out.shape)

    ### FUSION ###
    dense_f = Dense(internal_filts, name='dense_fusion')(g_n)
    if print_shapes: print('dense_f shapes:', dense_f.shape)
    # Tile the (batch, C) vector to (batch, H, W, C) matching aspp_out.
    reshap = Lambda(
        lambda x: K.repeat_elements(K.expand_dims(K.repeat_elements(
            K.expand_dims(x, axis=1), K.int_shape(aspp_out)[2], axis=1),
                                                  axis=1),
                                    K.int_shape(aspp_out)[1],
                                    axis=1), lambda s:
        (s[0], K.int_shape(aspp_out)[1], K.int_shape(aspp_out)[2], s[1]))(
            dense_f)
    if print_shapes: print('after lambda shapes:', reshap.shape)

    conc = Concatenate()([aspp_out, reshap])

    ### Projection ###
    x = Conv2D(internal_filts, (1, 1),
               padding='same',
               use_bias=False,
               name='concat_projection')(conc)
    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(dp)(x)

    ### DECODER ###
    outs_dec = decoder_block(x,
                             dil_rate=dil_rate,
                             print_shapes=print_shapes,
                             dec_filt=internal_filts,
                             dp=dp)

    outs_final = [outs_dec] * n_outs
    outs_final.append(out_classif)

    # Building model
    m = Model(
        inp, outs_final)  # Last element of outs_final is classification vector

    if freeze_cl:
        m.get_layer('out_classif').trainable = False
        # for l in g_n.layers:
        #     l.trainable=False

    if verbose:
        m.summary()

    return m
def umsi(input_shape=(None, None, 3),
         verbose=True,
         print_shapes=True,
         n_outs=1,
         ups=8,
         dil_rate=(2, 2),
         freeze_enc=False,
         freeze_cl=True,
         internal_filts=256,
         num_classes=6,
         dp=0.3,
         lambda_layer_for_save=False):
    """Build the UMSI model: a saliency decoder with skips plus a classifier.

    Pipeline:

    * Xception features feed ``global_net`` (global features -> softmax
      classifier ``out_classif``) and ``aspp`` (mid-level features).
    * The global features are projected by ``dense_fusion`` and tiled over
      the ASPP spatial grid.
    * The tiled features are concatenated with the ASPP output and projected
      by a 1x1 conv + BN + ReLU + Dropout.
    * ``decoder_with_skip`` decodes the result, using two intermediate
      Xception feature maps as skip connections.

    Args:
        input_shape: input image shape. NOTE(review): the default fusion
            Lambda reads ``K.int_shape(aspp_out)[1:3]`` as repeat counts, so
            presumably concrete spatial dims are required in practice despite
            the ``None`` default -- confirm.
        verbose: if True, print ``m.summary()`` at the end.
        print_shapes: if True, print intermediate tensor shapes.
        n_outs: number of copies of the saliency output placed before the
            classification output.
        ups: unused by this builder.
        dil_rate: unused here (only referenced by the commented-out
            ``decoder_block`` call).
        freeze_enc: if True, mark every Xception layer non-trainable.
        freeze_cl: if True, mark the ``out_classif`` Dense layer
            non-trainable.
        internal_filts: channel width passed to ``global_net``, ``aspp``,
            ``dense_fusion``, the projection conv and the decoder.
        num_classes: size of the classification softmax.
        dp: dropout rate passed to ``global_net``, the projection Dropout and
            ``decoder_with_skip``.
        lambda_layer_for_save: if True, use a fusion Lambda with a hard-coded
            30x40 tiling. Per the inline comment, this variant is intended to
            work with ``model.save()``. It presumably only matches when
            ``aspp_out`` is 30x40 spatially, since the two tensors are
            concatenated.

    Returns:
        An uncompiled Keras ``Model`` named ``'umsi'``. Its outputs are
        ``n_outs`` saliency maps followed by the classification vector.
    """

    inp = Input(shape=input_shape)

    ### ENCODER ###
    xception = Xception_wrapper(include_top=False,
                                weights='imagenet',
                                input_tensor=inp,
                                pooling=None)
    if print_shapes: print('xception output shapes:', xception.output.shape)
    if freeze_enc:
        for layer in xception.layers:
            layer.trainable = False


#     xception.summary()

    # Intermediate encoder activations reused as decoder skip connections.
    skip_layers = ['block3_sepconv2_bn', 'block1_conv1_act']
    # sizes: 119x159x32, 59x79x256
    # NOTE(review): the sizes above appear to be listed in the reverse order
    # of skip_layers (block1_conv1 has 32 filters in stock Xception) -- confirm.
    skip_feature_maps = [xception.get_layer(n).output for n in skip_layers]

    ### GLOBAL FEATURES ###
    g_n = global_net(xception.output, nfilts=internal_filts, dp=dp)
    if print_shapes: print('g_n shapes:', g_n.shape)

    ### CLASSIFIER ###
    # We potentially need another layer here
    out_classif = Dense(num_classes, activation='softmax',
                        name='out_classif')(g_n)

    ### ASPP (MID LEVEL FEATURES) ###
    aspp_out = aspp(xception.output, internal_filts)
    if print_shapes: print('aspp out shapes:', aspp_out.shape)

    ### FUSION ###
    dense_f = Dense(internal_filts, name='dense_fusion')(g_n)
    if print_shapes: print('dense_f shapes:', dense_f.shape)

    # Both branches tile dense_f from (batch, C) to (batch, H, W, C): repeat
    # along a new axis 1 W times, then along another new axis 1 H times.
    if not lambda_layer_for_save:
        # H and W are read from aspp_out at build time; the lambdas close over
        # the aspp_out tensor itself (presumably why this variant does not
        # survive model.save(), per the comment on the else branch).
        reshap = Lambda(
            lambda x: K.repeat_elements(K.expand_dims(K.repeat_elements(
                K.expand_dims(x, axis=1), K.int_shape(aspp_out)[2], axis=1),
                                                      axis=1),
                                        K.int_shape(aspp_out)[1],
                                        axis=1), lambda s:
            (s[0], K.int_shape(aspp_out)[1], K.int_shape(aspp_out)[2], s[1]))(
                dense_f)
    else:  # Use this lambda layer if you want to be able to use model.save() (set lambda_layer_for_save to True)
        # Hard-coded H=30, W=40: only consistent with an aspp_out of that
        # spatial size (TODO confirm which input resolution this assumes).
        print("Using lambda layer adapted to model.save()")
        reshap = Lambda(
            lambda x: K.repeat_elements(K.expand_dims(K.repeat_elements(
                K.expand_dims(x, axis=1), 40, axis=1),
                                                      axis=1),
                                        30,
                                        axis=1), lambda s:
            (s[0], 30, 40, s[1]))(dense_f)
        # reshap = FusionReshape()(dense_f)

    if print_shapes: print('after lambda shapes:', reshap.shape)

    conc = Concatenate()([aspp_out, reshap])

    ### Projection ###
    x = Conv2D(internal_filts, (1, 1),
               padding='same',
               use_bias=False,
               name='concat_projection')(conc)
    x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
    x = Activation('relu')(x)
    x = Dropout(dp)(x)

    ### DECODER ###
    #     outs_dec = decoder_block(x, dil_rate=dil_rate, print_shapes=print_shapes, dec_filt=internal_filts, dp=dp)

    outs_dec = decoder_with_skip(x,
                                 skip_feature_maps,
                                 print_shapes=print_shapes,
                                 dec_filt=internal_filts,
                                 dp=dp)

    outs_final = [outs_dec] * n_outs
    outs_final.append(out_classif)

    # Building model
    m = Model(
        inp, outs_final,
        name='umsi')  # Last element of outs_final is classification vector

    # Freezing happens before summary(), so reported trainable counts match.
    if freeze_cl:
        m.get_layer('out_classif').trainable = False
        # for l in g_n.layers:
        #     l.trainable=False

    if verbose:
        m.summary()

    return m