def sam_resnet_new(input_shape = (shape_r, shape_c, 3),
                   conv_filters=512,
                   lstm_filters=512,
                   att_filters=512,
                   verbose=True,
                   print_shapes=True,
                   n_outs=1,
                   ups=8,
                   nb_gaussian=nb_gaussian):
    '''Build the SAM-ResNet saliency model (port of the original code).

    Pipeline: ``dcn_resnet`` encoder -> 3x3 conv -> attentive ConvLSTM ->
    two (learned prior + dilated conv) stages -> 1x1 conv heatmap ->
    bilinear upsampling.

    Args:
        input_shape: Shape handed to ``Input``.
        conv_filters: Filters of the post-encoder conv and of both dilated
            convs.
        lstm_filters: ``filters`` of the ``AttentiveConvLSTM2D`` layer.
        att_filters: ``attentive_filters`` of the ``AttentiveConvLSTM2D``
            layer.
        verbose: If true, print ``Model.summary()`` before returning.
        print_shapes: If true, print intermediate tensor shapes while the
            graph is being built.
        n_outs: How many times the single output tensor is repeated in the
            model's output list.
        ups: Upsampling factor applied to both spatial axes of the heatmap.
        nb_gaussian: Forwarded to both ``LearningPrior`` layers.

    Returns:
        The assembled Keras ``Model`` (not compiled here).
    '''

    def _show(label, tensor):
        # Shape logging is opt-in so that building can stay quiet.
        if print_shapes:
            print(label, tensor.shape)

    image = Input(shape=input_shape)
    encoder = dcn_resnet(input_tensor=image)

    features = Conv2D(conv_filters, 3,
                      padding='same',
                      activation='relu')(encoder.output)
    _show('Shape after first conv after dcn_resnet:', features)

    # Attentive ConvLSTM. `repeat` / `repeat_shape` are module-level helpers
    # not visible in this block -- presumably they tile the feature map along
    # a new time axis; TODO confirm.
    sequence = Lambda(repeat, repeat_shape)(features)
    attended = AttentiveConvLSTM2D(filters=lstm_filters,
                                   attentive_filters=att_filters,
                                   kernel_size=(3,3),
                                   attentive_kernel_size=(3,3),
                                   padding='same',
                                   return_sequences=False)(sequence)

    # Learned prior (1): concatenated channel-wise with the ConvLSTM output.
    prior_a = LearningPrior(nb_gaussian=nb_gaussian)(attended)
    merged_a = Concatenate(axis=-1)([attended, prior_a])
    dilated_a = Conv2D(conv_filters, 5,
                       padding='same',
                       activation='relu',
                       dilation_rate=(4, 4))(merged_a)

    # Learned prior (2): also computed from the ConvLSTM output, but
    # concatenated with the first dilated conv's result.
    prior_b = LearningPrior(nb_gaussian=nb_gaussian)(attended)
    merged_b = Concatenate(axis=-1)([dilated_a, prior_b])
    dilated_b = Conv2D(conv_filters, 5,
                       padding='same',
                       activation='relu',
                       dilation_rate=(4, 4))(merged_b)

    # 1x1 conv collapses the channels into a single-channel heatmap.
    heatmap = Conv2D(1, kernel_size=1,
                     padding='same',
                     activation='relu')(dilated_b)
    _show('Shape after 1x1 conv:', heatmap)

    # Bilinear upsampling by `ups` on both spatial axes.
    heatmap_up = UpSampling2D(size=(ups,ups),
                              interpolation='bilinear')(heatmap)
    _show('shape after upsampling', heatmap_up)

    # The same tensor is exposed n_outs times as the model outputs.
    outputs = [heatmap_up] * n_outs

    model = Model(image, outputs)
    if verbose:
        model.summary()
    return model
def sam_xception_new(input_shape = (shape_r, shape_c, 3),
                     conv_filters=512,
                     lstm_filters=512,
                     att_filters=512,
                     verbose=True,
                     print_shapes=True,
                     n_outs=1,
                     ups=8,
                     nb_gaussian=nb_gaussian):
    '''Build SAM with a custom Xception as encoder.

    Pipeline: custom Xception (ImageNet weights, no top) -> 3x3 conv ->
    attentive ConvLSTM -> two (learned prior + dilated conv) stages ->
    1x1 conv heatmap -> bilinear upsampling.

    Args:
        input_shape: Shape handed to ``Input``.
        conv_filters: Filters of the post-encoder conv and of both dilated
            convs.
        lstm_filters: ``filters`` of the ``AttentiveConvLSTM2D`` layer.
        att_filters: ``attentive_filters`` of the ``AttentiveConvLSTM2D``
            layer.
        verbose: If true, print ``Model.summary()`` before returning.
        print_shapes: If true, print intermediate tensor shapes while the
            graph is being built.
        n_outs: How many times the single output tensor is repeated in the
            model's output list.
        ups: Upsampling factor applied to both spatial axes of the heatmap.
        nb_gaussian: Forwarded to both ``LearningPrior`` layers.

    Returns:
        The assembled Keras ``Model`` (not compiled here).
    '''
    # Imported locally, as in the original, so the custom Xception is only
    # required when this builder is actually used.
    from xception_custom import Xception
    from keras.applications import keras_modules_injection

    @keras_modules_injection
    def Xception_wrapper(*args, **kwargs):
        return Xception(*args, **kwargs)

    # A single Input layer is created and shared by the encoder and the
    # final Model.
    inp = Input(shape=input_shape)
    dcn = Xception_wrapper(include_top=False,
                           weights='imagenet',
                           input_tensor=inp,
                           pooling=None)
    if print_shapes:
        print('xception:', dcn.output.shape)

    conv_feat = Conv2D(conv_filters, 3,
                       padding='same',
                       activation='relu')(dcn.output)
    if print_shapes:
        print('Shape after first conv after xception:', conv_feat.shape)

    # Attentive ConvLSTM. `repeat` / `repeat_shape` are module-level helpers
    # not visible in this block -- presumably they tile the feature map along
    # a new time axis; TODO confirm.
    att_convlstm = Lambda(repeat, repeat_shape)(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters,
                                       attentive_filters=att_filters,
                                       kernel_size=(3,3),
                                       attentive_kernel_size=(3,3),
                                       padding='same',
                                       return_sequences=False)(att_convlstm)

    # Learned Prior (1)
    priors1 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat1 = Concatenate(axis=-1)([att_convlstm, priors1])
    dil_conv1 = Conv2D(conv_filters, 5,
                       padding='same',
                       activation='relu',
                       dilation_rate=(4, 4))(concat1)

    # Learned Prior (2): computed from the ConvLSTM output as well, then
    # concatenated with the first dilated conv's result.
    priors2 = LearningPrior(nb_gaussian=nb_gaussian)(att_convlstm)
    concat2 = Concatenate(axis=-1)([dil_conv1, priors2])
    dil_conv2 = Conv2D(conv_filters, 5,
                       padding='same',
                       activation='relu',
                       dilation_rate=(4, 4))(concat2)

    # Final conv to get to a single-channel heatmap
    outs = Conv2D(1, kernel_size=1,
                  padding='same',
                  activation='relu')(dil_conv2)
    if print_shapes:
        print('Shape after 1x1 conv:', outs.shape)

    # Upsampling by `ups` on both spatial axes
    outs_up = UpSampling2D(size=(ups,ups), interpolation='bilinear')(outs)
    if print_shapes:
        print('shape after upsampling', outs_up.shape)

    # The same tensor is exposed n_outs times as the model outputs.
    outs_final = [outs_up] * n_outs

    # Building model
    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m
def lstm_timedist(input_shape=(224, 224, 3),
                  conv_filters=512,
                  lstm_filters=512,
                  att_filters=512,
                  verbose=True,
                  print_shapes=True,
                  n_outs=1,
                  nb_timestep=3,
                  ups=upsampling_factor):
    '''Build an attentive ConvLSTM model that outputs a map per timestep.

    A small conv / batch-norm / max-pool stem feeds an attentive ConvLSTM
    with ``return_sequences=True``; every timestep is then reduced to one
    channel and upsampled. Needs to be trained with the singlestream loss.

    Author's note: this model does not work -- ImageNet-pretrained encoders
    are essential for this task, and this stem is trained from scratch.

    Args:
        input_shape: Shape handed to ``Input``.
        conv_filters: Not referenced by this builder.
        lstm_filters: ``filters`` of the ``AttentiveConvLSTM2D`` layer.
        att_filters: ``attentive_filters`` of the ``AttentiveConvLSTM2D``
            layer.
        verbose: If true, print ``Model.summary()`` before returning.
        print_shapes: If true, print intermediate tensor shapes while the
            graph is being built.
        n_outs: How many times the single output tensor is repeated in the
            model's output list.
        nb_timestep: Number of copies of the stem output stacked along the
            new time axis (axis 1).
        ups: Upsampling factor applied to both spatial axes of each map.

    Returns:
        The assembled Keras ``Model`` (not compiled here).
    '''
    image = Input(shape=input_shape)

    # Input stem: strided 7x7 conv -> batch norm -> ReLU -> 2x2 max pool.
    stem = Conv2D(64, (7, 7),
                  strides=(2, 2),
                  padding='same',
                  kernel_initializer='he_normal',
                  name='conv1')(image)
    stem = BatchNormalization(axis=3, name='bn_conv1')(stem)
    stem = Activation('relu')(stem)
    stem = MaxPooling2D((2, 2), strides=(2, 2))(stem)
    if print_shapes:
        print('x.shape after input block', stem.shape)

    # Insert a time axis at position 1 and repeat the features nb_timestep
    # times along it; the second lambda reports the resulting shape.
    repeat_over_time = Lambda(
        lambda t: K.repeat_elements(K.expand_dims(t, axis=1),
                                    nb_timestep,
                                    axis=1),
        lambda shp: (shp[0], nb_timestep) + shp[1:])
    sequence = repeat_over_time(stem)

    # Attentive ConvLSTM, keeping the output of every timestep.
    attended = AttentiveConvLSTM2D(filters=lstm_filters,
                                   attentive_filters=att_filters,
                                   kernel_size=(3, 3),
                                   attentive_kernel_size=(3, 3),
                                   padding='same',
                                   return_sequences=True)(sequence)
    if print_shapes:
        print('(att_convlstm.shape', attended.shape)

    # Per-timestep 1x1 conv to a single channel, then per-timestep upsampling.
    per_step_maps = TimeDistributed(
        Conv2D(1, kernel_size=1, padding='same', activation='relu'))(attended)
    per_step_maps = TimeDistributed(
        UpSampling2D(size=(ups, ups)))(per_step_maps)
    if print_shapes:
        print('outs_final shape:', per_step_maps.shape)

    # The same tensor is exposed n_outs times as the model outputs.
    outputs = [per_step_maps] * n_outs

    model = Model(image, outputs)
    if verbose:
        model.summary()
    return model
def sam_resnet_nopriors(input_shape = (224, 224, 3),
                        conv_filters=128,
                        lstm_filters=512,
                        att_filters=512,
                        verbose=True,
                        print_shapes=True,
                        n_outs=1,
                        ups=8):
    '''Build SAM-ResNet without the learned-prior branches.

    Pipeline: ``dcn_resnet`` encoder -> 3x3 conv -> attentive ConvLSTM ->
    two dilated 5x5 convs -> 1x1 conv heatmap -> bilinear upsampling.

    Args:
        input_shape: Shape handed to ``Input``.
        conv_filters: Filters of the post-encoder conv and of both dilated
            convs.
        lstm_filters: ``filters`` of the ``AttentiveConvLSTM2D`` layer.
        att_filters: ``attentive_filters`` of the ``AttentiveConvLSTM2D``
            layer.
        verbose: If true, print ``Model.summary()`` before returning.
        print_shapes: If true, print intermediate tensor shapes while the
            graph is being built.
        n_outs: How many times the single output tensor is repeated in the
            model's output list.
        ups: Upsampling factor applied to both spatial axes of the heatmap.

    Returns:
        The assembled Keras ``Model`` (not compiled here).
    '''

    def _show(label, tensor):
        # Shape logging is opt-in so that building can stay quiet.
        if print_shapes:
            print(label, tensor.shape)

    image = Input(shape=input_shape)
    encoder = dcn_resnet(input_tensor=image)

    features = Conv2D(conv_filters, 3,
                      padding='same',
                      activation='relu')(encoder.output)
    _show('Shape after first conv after dcn_resnet:', features)

    # Attentive ConvLSTM. `repeat` / `repeat_shape` are module-level helpers
    # not visible in this block -- presumably they tile the feature map along
    # a new time axis; TODO confirm.
    sequence = Lambda(repeat, repeat_shape)(features)
    attended = AttentiveConvLSTM2D(filters=lstm_filters,
                                   attentive_filters=att_filters,
                                   kernel_size=(3,3),
                                   attentive_kernel_size=(3,3),
                                   padding='same',
                                   return_sequences=False)(sequence)

    # Two stacked dilated convs; the prior-bearing variants concatenate
    # learned priors before each of these.
    dilated = Conv2D(conv_filters, 5,
                     padding='same',
                     activation='relu',
                     dilation_rate=(4, 4))(attended)
    dilated = Conv2D(conv_filters, 5,
                     padding='same',
                     activation='relu',
                     dilation_rate=(4, 4))(dilated)

    # 1x1 conv collapses the channels into a single-channel heatmap.
    heatmap = Conv2D(1, kernel_size=1,
                     padding='same',
                     activation='relu')(dilated)
    _show('Shape after 1x1 conv:', heatmap)

    # Bilinear upsampling by `ups` on both spatial axes.
    heatmap_up = UpSampling2D(size=(ups,ups),
                              interpolation='bilinear')(heatmap)
    _show('shape after upsampling', heatmap_up)

    # The same tensor is exposed n_outs times as the model outputs.
    outputs = [heatmap_up] * n_outs

    model = Model(image, outputs)
    if verbose:
        model.summary()
    return model
def sam_simple(input_shape = (224, 224, 3),
               in_conv_filters=512,
               verbose=True,
               print_shapes=True,
               n_outs=1,
               ups=8):
    '''Build a minimal network: one conv, an attentive ConvLSTM, one conv.

    Pipeline: strided 3x3 conv -> 4x4 max pool -> attentive ConvLSTM
    (512 filters, hard-coded) -> bilinear upsampling -> 3x3 conv to one
    channel. No activation is specified on either conv.

    Args:
        input_shape: Shape handed to ``Input``.
        in_conv_filters: Filters of the first convolution.
        verbose: If true, print ``Model.summary()`` before returning.
        print_shapes: If true, print progress / tensor shapes while the
            graph is being built.
        n_outs: How many times the single output tensor is repeated in the
            model's output list.
        ups: Upsampling factor applied to both spatial axes after the
            ConvLSTM.

    Returns:
        The assembled Keras ``Model`` (not compiled here).
    '''
    image = Input(shape=input_shape)

    # Downsampling front end: stride-2 conv followed by 4x4 max pooling.
    first_conv = Conv2D(filters=in_conv_filters,
                        kernel_size=(3,3),
                        strides=(2, 2),
                        padding='same',
                        data_format=None,
                        dilation_rate=(1,1))
    feats = first_conv(image)
    if print_shapes:
        print('after first conv')

    feats = MaxPooling2D(pool_size=(4,4))(feats)
    if print_shapes:
        print('after maxpool', feats.shape)

    # `repeat` / `repeat_shape` are module-level helpers not visible in this
    # block -- presumably they tile the feature map along a new time axis;
    # TODO confirm.
    feats = Lambda(repeat, repeat_shape)(feats)
    if print_shapes:
        print('after repeat', feats.shape)

    feats = AttentiveConvLSTM2D(filters=512,
                                attentive_filters=512,
                                kernel_size=(3,3),
                                attentive_kernel_size=(3,3),
                                padding='same',
                                return_sequences=False)(feats)
    if print_shapes:
        print('after ACLSTM', feats.shape)

    # Upsample first, then reduce to a single channel.
    feats = UpSampling2D(size=(ups,ups), interpolation='bilinear')(feats)
    last_conv = Conv2D(filters=1,
                       kernel_size=(3,3),
                       strides=(1, 1),
                       padding='same',
                       data_format=None,
                       dilation_rate=(1,1))
    saliency = last_conv(feats)
    if print_shapes:
        print('output shape', saliency.shape)

    # The same tensor is exposed n_outs times as the model outputs.
    outputs = [saliency] * n_outs

    model = Model(inputs=image, outputs=outputs)
    if verbose:
        model.summary()
    return model
def xception_lstm_md(input_shape=(shape_r, shape_c, 3),
                     conv_filters=256,
                     lstm_filters=256,
                     att_filters=256,
                     verbose=True,
                     print_shapes=True,
                     n_outs=1,
                     ups=8):
    '''Build an Xception + attentive ConvLSTM model with per-timestep maps.

    Pipeline: Xception encoder (ImageNet weights, no top) -> 3x3 conv ->
    features repeated along a new time axis -> attentive ConvLSTM with
    ``return_sequences=True`` -> two time-distributed dilated convs ->
    time-distributed 1x1 conv -> time-distributed bilinear upsampling.

    NOTE(review): ``Xception_wrapper`` and ``nb_timestep`` are not parameters
    or locals of this function; they are presumably defined at module level
    elsewhere in the file -- confirm.

    Args:
        input_shape: Shape handed to ``Input``.
        conv_filters: Filters of the post-encoder conv and of both dilated
            convs.
        lstm_filters: ``filters`` of the ``AttentiveConvLSTM2D`` layer.
        att_filters: ``attentive_filters`` of the ``AttentiveConvLSTM2D``
            layer.
        verbose: If true, print ``Model.summary()`` before returning.
        print_shapes: If true, print intermediate tensor shapes while the
            graph is being built.
        n_outs: How many times the single output tensor is repeated in the
            model's output list.
        ups: Upsampling factor applied to both spatial axes of each map.

    Returns:
        The assembled Keras ``Model`` (not compiled here).
    '''

    def _show(label, tensor):
        # Shape logging is opt-in so that building can stay quiet.
        if print_shapes:
            print(label, tensor.shape)

    image = Input(shape=input_shape)

    # Encoder
    encoder = Xception_wrapper(include_top=False,
                               weights='imagenet',
                               input_tensor=image,
                               pooling=None)
    _show('xception:', encoder.output)

    features = Conv2D(conv_filters, 3,
                      padding='same',
                      activation='relu')(encoder.output)
    _show('Shape after first conv after xception:', features)

    # Insert a time axis at position 1 and repeat the features nb_timestep
    # times along it; the second lambda reports the resulting shape.
    repeat_over_time = Lambda(
        lambda t: K.repeat_elements(K.expand_dims(t, axis=1),
                                    nb_timestep,
                                    axis=1),
        lambda shp: (shp[0], nb_timestep) + shp[1:])
    sequence = repeat_over_time(features)

    # Attentive ConvLSTM, keeping the output of every timestep.
    attended = AttentiveConvLSTM2D(filters=lstm_filters,
                                   attentive_filters=att_filters,
                                   kernel_size=(3, 3),
                                   attentive_kernel_size=(3, 3),
                                   padding='same',
                                   return_sequences=True)(sequence)
    _show('(att_convlstm.shape', attended)

    # Two dilated convs applied independently to each timestep.
    dilated = TimeDistributed(
        Conv2D(conv_filters, 5,
               padding='same',
               activation='relu',
               dilation_rate=(4, 4)))(attended)
    dilated = TimeDistributed(
        Conv2D(conv_filters, 5,
               padding='same',
               activation='relu',
               dilation_rate=(4, 4)))(dilated)

    # Per-timestep 1x1 conv down to a single-channel heatmap.
    heatmaps = TimeDistributed(
        Conv2D(1, kernel_size=1, padding='same', activation='relu'))(dilated)
    _show('Shape after 1x1 conv:', heatmaps)

    # Per-timestep bilinear upsampling by `ups` on both spatial axes.
    heatmaps_up = TimeDistributed(
        UpSampling2D(size=(ups, ups), interpolation='bilinear'))(heatmaps)
    _show('shape after upsampling', heatmaps_up)

    # The same tensor is exposed n_outs times as the model outputs.
    outputs = [heatmaps_up] * n_outs

    model = Model(image, outputs)
    if verbose:
        model.summary()
    return model
def sam_resnet_3d(input_shape=(224, 224, 3),
                  filt_3d=128,
                  conv_filters=128,
                  lstm_filters=128,
                  att_filters=128,
                  verbose=True,
                  print_shapes=True,
                  n_outs=1,
                  nb_timestep=3,
                  ups=upsampling_factor):
    '''Build a SAM-ResNet variant with a 3D-convolutional output head.

    Pipeline: ``dcn_resnet`` encoder -> 3x3 conv -> features repeated along
    a new time axis -> attentive ConvLSTM with ``return_sequences=True`` ->
    two (Conv3D + batch norm + ReLU) stages -> 1x1x1 Conv3D to one channel ->
    time-distributed bilinear upsampling.

    Args:
        input_shape: Shape handed to ``Input``.
        filt_3d: Filters of the two intermediate ``Conv3D`` layers.
        conv_filters: Filters of the post-encoder 2D conv.
        lstm_filters: ``filters`` of the ``AttentiveConvLSTM2D`` layer.
        att_filters: ``attentive_filters`` of the ``AttentiveConvLSTM2D``
            layer.
        verbose: If true, print ``Model.summary()`` before returning.
        print_shapes: If true, print intermediate tensor shapes while the
            graph is being built.
        n_outs: How many times the single output tensor is repeated in the
            model's output list.
        nb_timestep: Number of copies of the encoder features stacked along
            the new time axis (axis 1).
        ups: Upsampling factor applied to both spatial axes of each map.

    Returns:
        The assembled Keras ``Model`` (not compiled here).
    '''
    inp = Input(shape=input_shape)

    # Input CNN
    dcn = dcn_resnet(input_tensor=inp)
    conv_feat = Conv2D(conv_filters, 3,
                       padding='same',
                       activation='relu')(dcn.output)
    if print_shapes:
        print('Shape after first conv after dcn_resnet:', conv_feat.shape)

    # Attentive ConvLSTM: insert a time axis at position 1 and repeat the
    # encoder features nb_timestep times along it. The layer must consume
    # `conv_feat`; `x` is not assigned until the output head below.
    att_convlstm = Lambda(
        lambda t: K.repeat_elements(K.expand_dims(t, axis=1),
                                    nb_timestep,
                                    axis=1),
        lambda s: (s[0], nb_timestep) + s[1:])(conv_feat)
    att_convlstm = AttentiveConvLSTM2D(filters=lstm_filters,
                                       attentive_filters=att_filters,
                                       kernel_size=(3, 3),
                                       attentive_kernel_size=(3, 3),
                                       padding='same',
                                       return_sequences=True)(att_convlstm)
    if print_shapes:
        print('att_convlstm output shape', att_convlstm.shape)

    # Output flow
    # NOTE(review): with channels-last 5D input the three Conv3D dimensions
    # are presumably (time, height, width), so dilation_rate=(4, 4, 1) and
    # the (3, 3, 1) kernel would act on time/height rather than height/width
    # -- confirm this is intended before changing it.
    x = Conv3D(filt_3d, (3, 3, 3),
               strides=(1, 1, 1),
               padding='same',
               dilation_rate=(4, 4, 1),
               activation=None,
               kernel_initializer='he_normal')(att_convlstm)
    x = BatchNormalization(axis=4)(x)
    x = Activation('relu')(x)

    x = Conv3D(filt_3d, (3, 3, 1),
               strides=(1, 1, 1),
               padding='same',
               dilation_rate=(4, 4, 1),
               activation=None,
               kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=4)(x)
    x = Activation('relu')(x)

    # 1x1x1 conv collapses the channels into a single-channel map per step.
    x = Conv3D(1, (1, 1, 1),
               strides=(1, 1, 1),
               padding='same',
               activation='relu',
               kernel_initializer='he_normal')(x)

    # Per-timestep bilinear upsampling. The result is bound to `outs_final`,
    # the name read by the print and the output list below.
    outs_final = TimeDistributed(
        UpSampling2D(size=(ups, ups), interpolation='bilinear'))(x)
    if print_shapes:
        print('outs_final shape:', outs_final.shape)

    # The same tensor is exposed n_outs times as the model outputs.
    outs_final = [outs_final] * n_outs

    m = Model(inp, outs_final)
    if verbose:
        m.summary()
    return m